Coverage for python/felis/datamodel.py: 62%

214 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-23 10:44 +0000

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23from collections.abc import Mapping 

24from enum import Enum 

25from typing import Any, Literal 

26 

27from astropy import units as units # type: ignore 

28from astropy.io.votable import ucd # type: ignore 

29from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator 

30 

# Module-level logger named after this module, so applications can configure
# it through the standard ``logging`` hierarchy.
logger = logging.getLogger(__name__)

# Public API exported by a star-import of this module. ``DataType`` is listed
# because it is the declared type of ``Column.datatype``.
__all__ = (
    "BaseObject",
    "Column",
    "Constraint",
    "CheckConstraint",
    "UniqueConstraint",
    "DataType",
    "Index",
    "ForeignKeyConstraint",
    "Table",
    "SchemaVersion",
    "Schema",
)

46 

47 

class BaseObject(BaseModel):
    """Common base for every Felis model that carries a name and an ``@id``."""

    model_config = ConfigDict(
        populate_by_name=True,
        extra="forbid",
        use_enum_values=True,
    )
    """Pydantic settings applied to this class.

    Attributes may be populated by field name as well as by alias, and
    attributes that are not declared on the model are rejected.
    """

    name: str
    """Name of the database object (required for all Felis objects)."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object (required).

    Supplied under the ``@id`` key in input data.
    """

    description: str | None = None
    """Optional description of the database object."""

74 

75 

class DataType(Enum):
    """Enumeration of the column data types supported by Felis.

    Each member's value is the lower-case spelling of its name.
    """

    # Boolean and numeric types.
    BOOLEAN = "boolean"
    BYTE = "byte"
    SHORT = "short"
    INT = "int"
    LONG = "long"
    FLOAT = "float"
    DOUBLE = "double"

    # Character and text types.
    CHAR = "char"
    STRING = "string"
    UNICODE = "unicode"
    TEXT = "text"

    # Remaining types.
    BINARY = "binary"
    TIMESTAMP = "timestamp"

92 

93 

class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool = True
    """Whether the column can be `NULL`."""

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.

    This could be a boolean instead of 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_utype: str | None = Field(None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the column."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            The UCD string to check, or `None` when no UCD was given.

        Returns
        -------
        The unchanged input value.

        Raises
        ------
        ValueError
            If the UCD parser rejects the value.
        """
        if ivoa_ucd is not None:
            try:
                # NOTE(review): ``has_colon`` is enabled only when the UCD
                # contains ";" -- confirm that ";" (rather than ":") is the
                # intended trigger.
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                # Chain the parser error so the root cause stays visible.
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: Any) -> Any:
        """Check that units are valid.

        A column may carry a FITS unit or an IVOA unit, but not both, and
        whichever one is present must be accepted by ``units.Unit``.

        Parameters
        ----------
        values
            The raw input data for the model.

        Returns
        -------
        The unchanged input data.

        Raises
        ------
        ValueError
            If both units are given or if the unit cannot be parsed.
        """
        if not isinstance(values, dict):
            # Not a raw mapping; leave it to the regular model validation.
            return values

        # The model allows population by field name as well as by alias, so
        # both spellings of each key have to be inspected.
        fits_unit = values.get("fits:tunit") or values.get("fits_tunit")
        ivoa_unit = values.get("ivoa:unit") or values.get("ivoa_unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values

175 

176 

class Constraint(BaseObject):
    """Base model for constraints defined on a database table."""

    deferrable: bool = False
    """Declare the constraint as deferrable when `True`."""

    initially: str | None = None
    """Value of the ``INITIALLY`` clause; only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Extra annotations attached to this constraint (empty by default)."""

    type: str | None = Field(default=None, alias="@type")
    """Type of the constraint, supplied under the ``@type`` key."""

191 

192 

class CheckConstraint(Constraint):
    """Table constraint defined by a check expression."""

    expression: str
    """Expression evaluated by the check constraint (required)."""

198 

199 

class UniqueConstraint(Constraint):
    """Table constraint requiring uniqueness over a set of columns."""

    columns: list[str]
    """Columns covered by the unique constraint (required)."""

205 

206 

class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: Any) -> Any:
        """Check that columns or expressions are specified, but not both.

        A key whose value is `None` counts as not specified, so data that
        carries an explicit ``None`` for the unused field is handled the same
        way as data that omits the key.

        Parameters
        ----------
        values
            The raw input data for the model.

        Returns
        -------
        The unchanged input data.

        Raises
        ------
        ValueError
            If both or neither of ``columns`` and ``expressions`` are given.
        """
        if not isinstance(values, dict):
            # Not a raw mapping; leave it to the regular model validation.
            return values

        has_columns = values.get("columns") is not None
        has_expressions = values.get("expressions") is not None

        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values

228 

229 

class ForeignKeyConstraint(Constraint):
    """Foreign key constraint defined on a table.

    Foreign keys are reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """Columns that make up the foreign key (required)."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """Columns the foreign key refers to; supplied as ``referencedColumns``."""

241 

242 

class Table(BaseObject):
    """A database table."""

    columns: list[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primaryKey: str | list[str] | None = None
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: Any) -> Any:
        """Create constraints from the ``constraints`` field.

        Each raw constraint mapping is converted into the `Constraint`
        subclass selected by its ``@type`` key (or ``type`` when populated by
        field name). Items that already are `Constraint` instances are kept
        as they are.

        Parameters
        ----------
        values
            The raw input data for the model.

        Returns
        -------
        A copy of the input data with ``constraints`` replaced by constraint
        objects, or the input itself when there is nothing to convert.

        Raises
        ------
        ValueError
            If a constraint is not a mapping or has a missing or unknown type.
        """
        if not isinstance(values, dict):
            # Not a raw mapping; leave it to the regular model validation.
            return values

        raw_constraints = values.get("constraints")
        if not isinstance(raw_constraints, (list, tuple)):
            # Absent or malformed: field validation reports the problem.
            return values

        constraint_types: dict[str, type[Constraint]] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }

        new_constraints: list[Constraint] = []
        for item in raw_constraints:
            if isinstance(item, Constraint):
                new_constraints.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(f"Constraint must be a mapping, got {type(item).__name__}")
            constraint_type = item.get("@type", item.get("type"))
            constraint_cls = (
                constraint_types.get(constraint_type) if isinstance(constraint_type, str) else None
            )
            if constraint_cls is None:
                raise ValueError(f"Unknown constraint type: {constraint_type}")
            new_constraints.append(constraint_cls(**item))

        # Return a shallow copy so the caller's input data is not modified.
        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Raises
        ------
        ValueError
            If two or more columns share a name.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns

300 

301 

class SchemaVersion(BaseModel):
    """Version information attached to a schema."""

    current: str
    """Current version of the schema (required)."""

    compatible: list[str] | None = None
    """Optional list of versions the schema is compatible with."""

    read_compatible: list[str] | None = None
    """Optional list of versions the schema is read compatible with."""

313 

314 

class SchemaVisitor:
    """Visitor to build a Schema object's map of IDs to objects.

    The schema itself, its tables, and each table's columns, constraints and
    indexes are added to the schema's ``id_map``.

    Duplicates are added to a set when they are encountered, which can be
    accessed via the `duplicates` attribute. The presence of duplicates will
    not throw an error. Only the first object with a given ID will be added to
    the map, but this should not matter, since a ValidationError will be thrown
    by the `model_validator` method if any duplicates are found in the schema.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new SchemaVisitor."""
        self.schema: "Schema | None" = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map.

        Nothing happens if no schema has been visited yet or if the object
        has no ``id`` attribute. An ID that is already in the map is recorded
        in `duplicates` and the existing entry is kept.
        """
        if self.schema is None or not hasattr(obj, "id"):
            return
        obj_id = obj.id
        if obj_id in self.schema.id_map:
            self.duplicates.add(obj_id)
        else:
            self.schema.id_map[obj_id] = obj

    def visit_schema(self, schema: "Schema") -> None:
        """Visit a schema object and everything it contains.

        This sets the internal reference to the schema, resets the recorded
        duplicates, and then adds the schema and all of its tables.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(self.schema)
        for table in self.schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object and its columns, constraints and indexes."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)
        # Indexes carry IDs like every other object, so they are registered
        # as well.
        for index in table.indexes:
            self.visit_index(index)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)

    def visit_index(self, index: Index) -> None:
        """Visit an index object."""
        self.add(index)

368 

369 

class Schema(BaseObject):
    """The database schema."""

    version: SchemaVersion | None = None
    """The version of the schema."""

    tables: list[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects.

    The map is rebuilt from the schema contents during validation and is
    excluded from serialization.
    """

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            If two or more tables share a name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    @model_validator(mode="after")
    def create_id_map(self) -> "Schema":
        """Create a map of IDs to objects.

        Raises
        ------
        ValueError
            If the same ID is used by more than one object in the schema.
        """
        # Start from an empty map so that entries supplied with the input
        # data are not reported as duplicates of the schema's own objects.
        self.id_map.clear()
        visitor = SchemaVisitor()
        visitor.visit_schema(self)
        logger.debug("ID map contains %d objects", len(self.id_map))
        if visitor.duplicates:
            # Sort the IDs so the error message is the same on every run.
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def get_object_by_id(self, id: str) -> BaseObject:
        """Get an object by its unique "@id" field value.

        Parameters
        ----------
        id
            The identifier to look up in the ID map.

        Returns
        -------
        The object registered under the given ID.

        Raises
        ------
        ValueError
            If no object with the given ID is found.
        """
        if id not in self.id_map:
            raise ValueError(f"Object with ID {id} not found in schema")
        return self.id_map[id]