Coverage for python / lsst / dax / apdb / sql / modelToSql.py: 12%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:48 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["GUID", "ModelToSql"] 

25 

26import uuid 

27from collections.abc import Iterable, Mapping 

28from typing import Any 

29 

30import felis.datamodel 

31import sqlalchemy 

32from sqlalchemy.dialects.postgresql import UUID 

33 

34from .. import schema_model 

35 

36 

37# 

38# Copied from daf_butler. 

39# 

class GUID(sqlalchemy.TypeDecorator):
    """UUID column type that works on any database backend.

    On PostgreSQL the native ``UUID`` type is used. Every other backend gets
    a ``CHAR(32)`` column holding the UUID as 32 lowercase hexadecimal
    digits without dashes.

    Notes
    -----
    This implementation was copied from daf_butler.
    """

    impl = sqlalchemy.CHAR

    cache_ok = True

    def load_dialect_impl(self, dialect: sqlalchemy.engine.Dialect) -> sqlalchemy.types.TypeEngine:
        """Return the concrete column type to use for ``dialect``."""
        backend_type = UUID() if dialect.name == "postgresql" else sqlalchemy.CHAR(32)
        return dialect.type_descriptor(backend_type)

    def process_bind_param(self, value: Any, dialect: sqlalchemy.engine.Dialect) -> str | None:
        """Convert a Python value into the form stored in the database.

        `uuid.UUID` is the expected input, but for compatibility with
        existing callers `int`, `bytes` (16 raw bytes) and hex `str` values
        are converted to `uuid.UUID` first.

        Returns
        -------
        value : `str` or `None`
            `None` for `None` input; otherwise the dashed string form on
            PostgreSQL and 32 lowercase hex digits on other backends.

        Raises
        ------
        TypeError
            Raised if ``value`` has any other type.
        """
        if value is None:
            return None

        # Normalise every accepted input type to a uuid.UUID instance.
        if isinstance(value, int):
            guid = uuid.UUID(int=value)
        elif isinstance(value, bytes):
            guid = uuid.UUID(bytes=value)
        elif isinstance(value, str):
            guid = uuid.UUID(hex=value)
        elif isinstance(value, uuid.UUID):
            guid = value
        else:
            raise TypeError(f"Unexpected type of a bind value: {type(value)}")

        return str(guid) if dialect.name == "postgresql" else f"{guid.int:032x}"

    def process_result_value(
        self, value: str | uuid.UUID | None, dialect: sqlalchemy.engine.Dialect
    ) -> uuid.UUID | None:
        """Convert a value read from the database back into `uuid.UUID`."""
        # NULL passes through unchanged; SQLAlchemy 2 may already return a
        # uuid.UUID for native UUID columns.
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(hex=value)

88 

89 

class ModelToSql:
    """Class which implements schema model conversion to SQLAlchemy format.

    Parameters
    ----------
    metadata : `sqlalchemy.schema.MetaData`
        Metadata object for created tables.
    prefix : `str`, optional
        Prefix to add to all schema elements.
    """

    def __init__(
        self,
        metadata: sqlalchemy.schema.MetaData,
        prefix: str = "",
    ):
        self._metadata = metadata
        self._prefix = prefix

        # Map model column types to SQLAlchemy types.
        self._type_map: dict[felis.datamodel.DataType | schema_model.ExtraDataTypes, type] = {
            felis.datamodel.DataType.double: sqlalchemy.types.Double,
            felis.datamodel.DataType.float: sqlalchemy.types.REAL,
            felis.datamodel.DataType.timestamp: sqlalchemy.types.TIMESTAMP,
            felis.datamodel.DataType.long: sqlalchemy.types.BigInteger,
            felis.datamodel.DataType.int: sqlalchemy.types.Integer,
            felis.datamodel.DataType.short: sqlalchemy.types.SmallInteger,
            felis.datamodel.DataType.byte: sqlalchemy.types.SmallInteger,  # Byte types are not very portable
            felis.datamodel.DataType.binary: sqlalchemy.types.LargeBinary,
            felis.datamodel.DataType.text: sqlalchemy.types.Text,
            felis.datamodel.DataType.string: sqlalchemy.types.VARCHAR,
            felis.datamodel.DataType.char: sqlalchemy.types.CHAR,
            felis.datamodel.DataType.unicode: sqlalchemy.types.NVARCHAR,
            felis.datamodel.DataType.boolean: sqlalchemy.types.Boolean,
            schema_model.ExtraDataTypes.UUID: GUID,
        }

    def make_tables(self, tables: Iterable[schema_model.Table]) -> Mapping[str, sqlalchemy.schema.Table]:
        """Generate SQLAlchemy table schema from a list of table models.

        Parameters
        ----------
        tables : `~collections.abc.Iterable` [`schema_model.Table`]
            List of table models.

        Returns
        -------
        tables : `~collections.abc.Mapping` [`str`, `sqlalchemy.schema.Table`]
            SQLAlchemy table definitions indexed by identifier of the table
            model.

        Raises
        ------
        ValueError
            Raised if foreign keys reference a table that is not in
            ``tables`` or form a dependency cycle.
        TypeError
            Raised if a table has an expression index or an unknown
            constraint type.
        """
        table_map: dict[str, sqlalchemy.schema.Table] = {}
        # Tables are processed in foreign-key dependency order so that every
        # referenced table is already in ``table_map`` when a foreign key to
        # it is built.
        for table in self._topo_sort(tables):
            columns = self._table_columns(table)
            constraints = self._table_constraints(table, table_map)
            sa_table = sqlalchemy.schema.Table(
                self._prefix + table.name,
                self._metadata,
                *columns,
                *constraints,
                schema=self._metadata.schema,
            )
            table_map[table.id] = sa_table

        return table_map

    def _table_columns(self, table: schema_model.Table) -> list[sqlalchemy.schema.Column]:
        """Return the list of columns in a table.

        Parameters
        ----------
        table : `schema_model.Table`
            Table model.

        Returns
        -------
        column_defs : `list` [`sqlalchemy.schema.Column`]
            List of columns.
        """
        column_defs: list[sqlalchemy.schema.Column] = []
        for column in table.columns:
            kwargs: dict[str, Any] = {"nullable": column.nullable}
            if column.value is not None:
                # Model default value becomes a server-side default.
                kwargs.update(server_default=str(column.value))
            if column in table.primary_key and column.autoincrement is None:
                # Do not leave autoincrement unset on PK columns, where
                # SQLAlchemy's "auto" default could enable it implicitly.
                kwargs.update(autoincrement=False)
            else:
                kwargs.update(autoincrement=column.autoincrement)
            ctype = self._type_map[column.datatype]
            if column.length is not None:
                # Length is not applied to Text and TIMESTAMP columns.
                if ctype not in (sqlalchemy.types.Text, sqlalchemy.types.TIMESTAMP):
                    ctype = ctype(length=column.length)
            if ctype is sqlalchemy.types.TIMESTAMP:
                # Use TIMESTAMP WITH TIMEZONE.
                ctype = ctype(timezone=True)
            column_defs.append(sqlalchemy.schema.Column(column.name, ctype, **kwargs))

        return column_defs

    def _table_constraints(
        self,
        table: schema_model.Table,
        table_map: Mapping[str, sqlalchemy.schema.Table],
    ) -> list[sqlalchemy.schema.SchemaItem]:
        """Return the list of constraints/indices in a table.

        Parameters
        ----------
        table : `schema_model.Table`
            Table model.
        table_map : `~collections.abc.Mapping`
            Mapping of table ID to SQLAlchemy table definition for tables
            that already exist; this must include all tables referenced by
            foreign keys in ``table``.

        Returns
        -------
        constraints : `list` [`sqlalchemy.schema.SchemaItem`]
            List of SQLAlchemy index/constraint objects.

        Raises
        ------
        TypeError
            Raised for expression indices or unknown constraint types.
        """
        constraints: list[sqlalchemy.schema.SchemaItem] = []
        if table.primary_key:
            # It is very useful to have named PK.
            pk_name = self._prefix + table.name + "_pk"
            constraints.append(
                sqlalchemy.schema.PrimaryKeyConstraint(
                    *[column.name for column in table.primary_key], name=pk_name
                )
            )
        for index in table.indexes:
            if index.expressions:
                raise TypeError(f"Expression indices are not supported: {table}")
            # None (not "") is SQLAlchemy's "no name" value, consistent with
            # how constraint names are handled below.
            index_name = self._prefix + index.name if index.name else None
            constraints.append(sqlalchemy.schema.Index(index_name, *[column.name for column in index.columns]))
        for constraint in table.constraints:
            constr_name: str | None = None
            if constraint.name:
                constr_name = self._prefix + constraint.name
            if isinstance(constraint, schema_model.UniqueConstraint):
                constraints.append(
                    sqlalchemy.schema.UniqueConstraint(
                        *[column.name for column in constraint.columns], name=constr_name
                    )
                )
            elif isinstance(constraint, schema_model.ForeignKeyConstraint):
                column_names = [col.name for col in constraint.columns]
                foreign_table = table_map[constraint.referenced_table.id]
                refcolumns = [foreign_table.columns[col.name] for col in constraint.referenced_columns]
                constraints.append(
                    sqlalchemy.schema.ForeignKeyConstraint(
                        columns=column_names,
                        refcolumns=refcolumns,
                        name=constr_name,
                        deferrable=constraint.deferrable,
                        initially=constraint.initially,
                        onupdate=constraint.onupdate,
                        ondelete=constraint.ondelete,
                    )
                )
            elif isinstance(constraint, schema_model.CheckConstraint):
                constraints.append(
                    sqlalchemy.schema.CheckConstraint(
                        constraint.expression,
                        name=constr_name,
                        deferrable=constraint.deferrable,
                        initially=constraint.initially,
                    )
                )
            else:
                raise TypeError(f"Unknown constraint type: {constraint}")

        return constraints

    @staticmethod
    def _topo_sort(table_iter: Iterable[schema_model.Table]) -> list[schema_model.Table]:
        """Topologically sort tables by their foreign key dependencies.

        Parameters
        ----------
        table_iter : `~collections.abc.Iterable` [`schema_model.Table`]
            Table models to sort.

        Returns
        -------
        tables : `list` [`schema_model.Table`]
            The same tables, ordered so that each table comes after every
            table referenced by its foreign keys.

        Raises
        ------
        ValueError
            Raised if a foreign key references a table that is not in
            ``table_iter``, or if foreign keys form a dependency cycle
            (including a table that references itself).
        """
        tables = list(table_iter)

        # Map of table ID to IDs of the tables its foreign keys reference.
        referenced_tables: dict[str, set[str]] = {
            table.id: {
                constraint.referenced_table.id
                for constraint in table.constraints
                if isinstance(constraint, schema_model.ForeignKeyConstraint)
            }
            for table in tables
        }

        # References to tables outside the input set can never be satisfied;
        # report them explicitly rather than as a (bogus) cycle.
        known_ids = set(referenced_tables)
        missing = {
            table_id: sorted(refs - known_ids)
            for table_id, refs in referenced_tables.items()
            if not refs <= known_ids
        }
        if missing:
            raise ValueError(f"Foreign keys reference tables that are not being created: {missing}")

        result: list[schema_model.Table] = []
        result_ids: set[str] = set()
        while tables:
            remaining: list[schema_model.Table] = []
            for table in tables:
                if referenced_tables[table.id] <= result_ids:
                    result.append(table)
                    result_ids.add(table.id)
                else:
                    remaining.append(table)
            # If nothing could be placed in this pass, the rest form a cycle.
            if len(remaining) == len(tables):
                raise ValueError(f"Dependency cycle in foreign keys: {remaining}")
            tables = remaining

        return result