Coverage for python/lsst/dax/apdb/sql/modelToSql.py: 13%

116 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-27 03:01 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ModelToSql", "GUID"] 

25 

26import uuid 

27from collections.abc import Iterable, Mapping 

28from typing import Any 

29 

30import felis.datamodel 

31import sqlalchemy 

32from sqlalchemy.dialects.postgresql import UUID 

33 

34from .. import schema_model 

35 

36 

37# 

38# Copied from daf_butler. 

39# 

class GUID(sqlalchemy.TypeDecorator):
    """Platform-independent GUID type.

    PostgreSQL stores values in its native ``UUID`` column type; every other
    backend uses a ``CHAR(32)`` column holding the UUID as a zero-padded,
    32-digit lowercase hex string. Values read back are `uuid.UUID`.
    """

    impl = sqlalchemy.CHAR

    cache_ok = True

    def load_dialect_impl(self, dialect: sqlalchemy.engine.Dialect) -> sqlalchemy.types.TypeEngine:
        """Return the concrete column type used for ``dialect``."""
        column_type = UUID() if dialect.name == "postgresql" else sqlalchemy.CHAR(32)
        return dialect.type_descriptor(column_type)

    def process_bind_param(self, value: Any, dialect: sqlalchemy.engine.Dialect) -> str | None:
        """Convert a Python value into the representation sent to the database.

        ``None`` passes through unchanged. `int`, `bytes` (16 raw bytes),
        hex `str` and `uuid.UUID` inputs are accepted; any other type raises
        `TypeError`.
        """
        if value is None:
            return None

        # Normalize everything to uuid.UUID first. UUID is the preferred
        # input type, but some existing callers still pass plain ints.
        if isinstance(value, int):
            guid = uuid.UUID(int=value)
        elif isinstance(value, bytes):
            guid = uuid.UUID(bytes=value)
        elif isinstance(value, str):
            # Hex string form.
            guid = uuid.UUID(hex=value)
        elif isinstance(value, uuid.UUID):
            guid = value
        else:
            raise TypeError(f"Unexpected type of a bind value: {type(value)}")

        if dialect.name == "postgresql":
            return str(guid)
        # Non-PostgreSQL backends: 32 hex digits, zero-padded, no dashes.
        return f"{guid.int:032x}"

    def process_result_value(
        self, value: str | uuid.UUID | None, dialect: sqlalchemy.engine.Dialect
    ) -> uuid.UUID | None:
        """Convert a database value back into `uuid.UUID` (or ``None``)."""
        # SQLAlchemy 2 may already hand back a uuid.UUID for native columns.
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(hex=value)

88 

89 

class ModelToSql:
    """Class which implements schema model conversion to SQLAlchemy format.

    Parameters
    ----------
    metadata : `sqlalchemy.schema.MetaData`
        Metadata object for created tables.
    prefix : `str`, optional
        Prefix to add to all schema elements (table, primary key, index and
        constraint names).
    """

    def __init__(
        self,
        metadata: sqlalchemy.schema.MetaData,
        prefix: str = "",
    ):
        self._metadata = metadata
        self._prefix = prefix

        # Map model column types to SQLAlchemy type classes. Classes (not
        # instances) are stored; `_table_columns` instantiates them with a
        # length when the column model specifies one.
        self._type_map: dict[felis.datamodel.DataType | schema_model.ExtraDataTypes, type] = {
            felis.datamodel.DataType.double: sqlalchemy.types.Double,
            felis.datamodel.DataType.float: sqlalchemy.types.Float,
            felis.datamodel.DataType.timestamp: sqlalchemy.types.TIMESTAMP,
            felis.datamodel.DataType.long: sqlalchemy.types.BigInteger,
            felis.datamodel.DataType.int: sqlalchemy.types.Integer,
            felis.datamodel.DataType.short: sqlalchemy.types.Integer,
            felis.datamodel.DataType.byte: sqlalchemy.types.Integer,
            felis.datamodel.DataType.binary: sqlalchemy.types.LargeBinary,
            felis.datamodel.DataType.text: sqlalchemy.types.Text,
            felis.datamodel.DataType.string: sqlalchemy.types.CHAR,
            felis.datamodel.DataType.char: sqlalchemy.types.CHAR,
            felis.datamodel.DataType.unicode: sqlalchemy.types.CHAR,
            felis.datamodel.DataType.boolean: sqlalchemy.types.Boolean,
            schema_model.ExtraDataTypes.UUID: GUID,
        }

    def make_tables(self, tables: Iterable[schema_model.Table]) -> Mapping[str, sqlalchemy.schema.Table]:
        """Generate sqlalchemy table schema from the list of models.

        Parameters
        ----------
        tables : `~collections.abc.Iterable` [`schema_model.Table`]
            List of table models.

        Returns
        -------
        tables : `~collections.abc.Mapping` [`str`, `sqlalchemy.schema.Table`]
            SQLAlchemy table definitions indexed by identifier of the table
            model.

        Raises
        ------
        ValueError
            Raised if foreign keys reference a table that is not in
            ``tables`` or form a dependency cycle.
        """
        # Order tables so that each one comes after every table its foreign
        # keys reference; `_table_constraints` needs those already built.
        ordered = self._topo_sort(tables)

        table_map: dict[str, sqlalchemy.schema.Table] = {}
        for table in ordered:
            columns = self._table_columns(table)
            constraints = self._table_constraints(table, table_map)
            sa_table = sqlalchemy.schema.Table(
                self._prefix + table.name,
                self._metadata,
                *columns,
                *constraints,
                schema=self._metadata.schema,
            )
            table_map[table.id] = sa_table

        return table_map

    def _table_columns(self, table: schema_model.Table) -> list[sqlalchemy.schema.Column]:
        """Return list of columns in a table.

        Parameters
        ----------
        table : `schema_model.Table`
            Table model.

        Returns
        -------
        column_defs : `list` [`sqlalchemy.schema.Column`]
            List of columns.
        """
        column_defs: list[sqlalchemy.schema.Column] = []
        for column in table.columns:
            kwargs: dict[str, Any] = dict(nullable=column.nullable)
            if column.value is not None:
                # The model's default value becomes a server-side DEFAULT.
                kwargs.update(server_default=str(column.value))
            if column in table.primary_key and column.autoincrement is None:
                # Primary key columns are not autoincrement unless the model
                # explicitly asks for it (SQLAlchemy's own default could
                # otherwise treat an integer PK as autoincrement).
                kwargs.update(autoincrement=False)
            else:
                kwargs.update(autoincrement=column.autoincrement)
            ctype = self._type_map[column.datatype]
            if column.length is not None:
                # Text and TIMESTAMP do not take a length argument.
                if ctype not in (sqlalchemy.types.Text, sqlalchemy.types.TIMESTAMP):
                    ctype = ctype(length=column.length)
            column_defs.append(sqlalchemy.schema.Column(column.name, ctype, **kwargs))

        return column_defs

    def _table_constraints(
        self,
        table: schema_model.Table,
        table_map: Mapping[str, sqlalchemy.schema.Table],
    ) -> list[sqlalchemy.schema.SchemaItem]:
        """Return set of constraints/indices in a table.

        Parameters
        ----------
        table : `schema_model.Table`
            Table model.
        table_map : `~collections.abc.Mapping`
            Mapping of table ID to sqlalchemy table definition for tables
            that already exist, this must include all tables referenced by
            foreign keys in ``table``.

        Returns
        -------
        constraints : `list` [`sqlalchemy.schema.SchemaItem`]
            List of SQLAlchemy index/constraint objects.

        Raises
        ------
        TypeError
            Raised for expression indices or unknown constraint types.
        """
        constraints: list[sqlalchemy.schema.SchemaItem] = []
        if table.primary_key:
            # It is very useful to have named PK.
            pk_name = self._prefix + table.name + "_pk"
            constraints.append(
                sqlalchemy.schema.PrimaryKeyConstraint(
                    *[column.name for column in table.primary_key], name=pk_name
                )
            )
        for index in table.indexes:
            if index.expressions:
                raise TypeError(f"Expression indices are not supported: {table}")
            # None (rather than an empty string) lets SQLAlchemy apply the
            # metadata naming convention to unnamed indices.
            index_name = self._prefix + index.name if index.name else None
            constraints.append(
                sqlalchemy.schema.Index(index_name, *[column.name for column in index.columns])
            )
        for constraint in table.constraints:
            constr_name: str | None = None
            if constraint.name:
                constr_name = self._prefix + constraint.name
            if isinstance(constraint, schema_model.UniqueConstraint):
                constraints.append(
                    sqlalchemy.schema.UniqueConstraint(
                        *[column.name for column in constraint.columns], name=constr_name
                    )
                )
            elif isinstance(constraint, schema_model.ForeignKeyConstraint):
                column_names = [col.name for col in constraint.columns]
                # Referenced table must already be built (see make_tables).
                foreign_table = table_map[constraint.referenced_table.id]
                refcolumns = [foreign_table.columns[col.name] for col in constraint.referenced_columns]
                constraints.append(
                    sqlalchemy.schema.ForeignKeyConstraint(
                        columns=column_names,
                        refcolumns=refcolumns,
                        name=constr_name,
                        deferrable=constraint.deferrable,
                        initially=constraint.initially,
                        onupdate=constraint.onupdate,
                        ondelete=constraint.ondelete,
                    )
                )
            elif isinstance(constraint, schema_model.CheckConstraint):
                constraints.append(
                    sqlalchemy.schema.CheckConstraint(
                        constraint.expression,
                        name=constr_name,
                        deferrable=constraint.deferrable,
                        initially=constraint.initially,
                    )
                )
            else:
                raise TypeError(f"Unknown constraint type: {constraint}")

        return constraints

    @staticmethod
    def _topo_sort(table_iter: Iterable[schema_model.Table]) -> list[schema_model.Table]:
        """Topologically sort tables by their foreign key dependencies.

        Parameters
        ----------
        table_iter : `~collections.abc.Iterable` [`schema_model.Table`]
            Table models to sort.

        Returns
        -------
        tables : `list` [`schema_model.Table`]
            Tables ordered so that each one appears after every table it
            references via a foreign key.

        Raises
        ------
        ValueError
            Raised if a foreign key references a table that is not among the
            inputs, or if the foreign keys form a dependency cycle.
        """
        remaining = list(table_iter)
        known_ids = {table.id for table in remaining}

        # Map of table ID to the IDs of tables its foreign keys reference.
        referenced_tables: dict[str, set[str]] = {}
        for table in remaining:
            refs = {
                constraint.referenced_table.id
                for constraint in table.constraints
                if isinstance(constraint, schema_model.ForeignKeyConstraint)
            }
            # Report unknown targets explicitly instead of letting them show
            # up later as a bogus "dependency cycle".
            missing = refs - known_ids
            if missing:
                raise ValueError(
                    f"Table {table.name!r} has foreign keys to unknown tables: {sorted(missing)}"
                )
            referenced_tables[table.id] = refs

        result: list[schema_model.Table] = []
        result_ids: set[str] = set()
        while remaining:
            pending: list[schema_model.Table] = []
            for table in remaining:
                # Ready once all referenced tables are placed; tables placed
                # earlier in this same pass count as placed.
                if referenced_tables[table.id] <= result_ids:
                    result.append(table)
                    result_ids.add(table.id)
                else:
                    pending.append(table)
            if len(pending) == len(remaining):
                # Nothing could be placed in a full pass, which means a cycle.
                names = [table.name for table in pending]
                raise ValueError(f"Dependency cycle in foreign keys: {names}")
            remaining = pending

        return result