Coverage for python/felis/datamodel.py: 62%

247 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-02 02:10 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import logging 

25from collections.abc import Mapping, Sequence 

26from enum import Enum 

27from typing import Annotated, Any, Literal, TypeAlias 

28 

29from astropy import units as units # type: ignore 

30from astropy.io.votable import ucd # type: ignore 

31from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator 

32 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
# NOTE(review): ``DataType`` is the declared type of ``Column.datatype`` but is
# not listed here - confirm whether leaving it out is intentional.
__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

48 

CONFIG = ConfigDict(
    # Populate attributes by name.
    populate_by_name=True,
    # Do not allow extra fields.
    extra="forbid",
    # Validate assignments after model is created.
    validate_assignment=True,
    # Strip whitespace from string fields.
    str_strip_whitespace=True,
)
"""Shared Pydantic model configuration; the available options are described
at https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Smallest number of characters accepted in a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""String type for descriptions, constrained to at least `DESCR_MIN_LENGTH`
characters. Whitespace stripping is not done here; it is enabled for all
string fields through `CONFIG`."""

65 

66 

class BaseObject(BaseModel):
    """Base class for all Felis objects.

    Every object has a ``name``, an ``id`` (populated from the ``@id`` key)
    and an optional ``description``.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """The name of the database object.

    All Felis database objects must have a name.
    """

    id: str = Field(alias="@id")
    """The unique identifier of the database object.

    All Felis database objects must have a unique identifier.
    """

    description: DescriptionStr | None = None
    """A description of the database object.

    By default, the description is optional but will be required if the
    requirement has been switched on with `Schema.require_description`.
    """

    @model_validator(mode="before")
    @classmethod
    def check_description(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that the description is present if required.

        Nothing is checked unless `Schema.is_description_required` returns
        `True`.

        Parameters
        ----------
        values
            The raw input data for the model, before field validation.

        Returns
        -------
        The input data, unchanged.

        Raises
        ------
        ValueError
            If a description is required and it is missing, empty, or a
            string shorter than `DESCR_MIN_LENGTH` once stripped.
        """
        if not Schema.is_description_required():
            return values
        if not isinstance(values, Mapping):
            # Not key/value input; leave it to the regular model validation
            # to report the problem instead of failing on a key lookup here.
            return values

        description = values.get("description")
        if not description:
            raise ValueError("Description is required and must be non-empty")
        # Only strings can be stripped and measured; any other type is left
        # for the field's own type validation to reject.
        if isinstance(description, str) and len(description.strip()) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return values

102 

103 

class DataType(Enum):
    """Enumeration of the data types supported by Felis.

    The value of each member is its name written in lowercase.
    """

    BOOLEAN = "boolean"
    BYTE = "byte"
    SHORT = "short"
    INT = "int"
    LONG = "long"
    FLOAT = "float"
    DOUBLE = "double"
    CHAR = "char"
    STRING = "string"
    UNICODE = "unicode"
    TEXT = "text"
    BINARY = "binary"
    TIMESTAMP = "timestamp"

120 

121 

class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool | None = None
    """Whether the column can be ``NULL``.

    If `None`, this value was not set explicitly in the YAML data. In this
    case, it will be set to `False` for columns with numeric types and `True`
    otherwise.
    """

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_utype: str | None = Field(None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the column."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            The UCD string to check, or `None` if no UCD was given.

        Returns
        -------
        The value that was passed in, unchanged.

        Raises
        ------
        ValueError
            If ``parse_ucd`` rejects the string.
        """
        if ivoa_ucd is None:
            return None
        try:
            # NOTE(review): ``has_colon`` is switched on by the presence of a
            # semicolon, not a colon - confirm that this is what is intended.
            ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
        except ValueError as e:
            # Chain the original error so the parser's traceback is kept.
            raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid.

        Each unit is looked up under its alias (``fits:tunit``,
        ``ivoa:unit``) and, failing that, under its field name
        (``fits_tunit``, ``ivoa_unit``), because the model configuration
        allows fields to be populated by name.

        Parameters
        ----------
        values
            The raw input data for the model, before field validation.

        Returns
        -------
        The input data, unchanged.

        Raises
        ------
        ValueError
            If both a FITS and an IVOA unit are given, or if the given unit
            is rejected by ``units.Unit``.
        """
        if not isinstance(values, Mapping):
            # Not key/value input; let the regular model validation report it.
            return values

        fits_unit = values.get("fits:tunit")
        if fits_unit is None:
            fits_unit = values.get("fits_tunit")
        ivoa_unit = values.get("ivoa:unit")
        if ivoa_unit is None:
            ivoa_unit = values.get("ivoa_unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values

209 

210 

class Constraint(BaseObject):
    """Common base for the constraints that can be declared on a table."""

    deferrable: bool = False
    """Declare the constraint as deferrable when `True`; defaults to `False`."""

    initially: str | None = None
    """Value for the ``INITIALLY`` clause; only used if ``deferrable`` is
    `True`.
    """

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint; empty by default."""

    type: str | None = Field(None, alias="@type")
    """The type of the constraint, read from the ``@type`` key."""

225 

226 

class CheckConstraint(Constraint):
    """A table constraint defined by a check expression."""

    expression: str
    """The expression of the check constraint; this field is required."""

232 

233 

class UniqueConstraint(Constraint):
    """A table constraint declaring a set of columns as unique."""

    columns: list[str]
    """The columns covered by the unique constraint; this field is required."""

239 

240 

class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        A key whose value is `None` is treated the same as a missing key, so
        that ``{"columns": [...], "expressions": None}`` is accepted and
        ``{"columns": None}`` is rejected.

        Parameters
        ----------
        values
            The raw input data for the model, before field validation.

        Returns
        -------
        The input data, unchanged.

        Raises
        ------
        ValueError
            If both or neither of ``columns`` and ``expressions`` are given.
        """
        if not isinstance(values, Mapping):
            # Not key/value input; let the regular model validation report it.
            return values

        has_columns = values.get("columns") is not None
        has_expressions = values.get("expressions") is not None

        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values

262 

263 

class ForeignKeyConstraint(Constraint):
    """A table constraint describing a foreign key.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """The columns that make up the foreign key; this field is required."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """The columns the foreign key refers to, read from the
    ``referencedColumns`` key; this field is required.
    """

275 

276 

class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create constraints from the ``constraints`` field.

        Each mapping in ``constraints`` is turned into the constraint class
        selected by its ``@type`` key (``ForeignKey``, ``Unique`` or
        ``Check``). Items that are already `Constraint` instances are kept
        as they are.

        Parameters
        ----------
        values
            The raw input data for the model, before field validation.

        Returns
        -------
        The input data if there is nothing to convert, otherwise a shallow
        copy of it with ``constraints`` replaced by the constraint objects.
        The caller's data is not modified.

        Raises
        ------
        ValueError
            If an item is neither a mapping nor a `Constraint`, or if its
            ``@type`` is missing or not one of the known types.
        """
        if not isinstance(values, Mapping) or "constraints" not in values:
            return values
        raw_constraints = values["constraints"]
        if raw_constraints is None:
            # Nothing to iterate; field validation will report the bad value.
            return values

        constraint_classes: dict[str, type[Constraint]] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }

        new_constraints: list[Constraint] = []
        for item in raw_constraints:
            if isinstance(item, Constraint):
                # Already built, so there is no ``@type`` key to dispatch on.
                new_constraints.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(f"Constraint must be a mapping, got {type(item).__name__}")
            constraint_type = item.get("@type")
            # Guard the lookup so an unhashable ``@type`` cannot raise TypeError.
            constraint_class = (
                constraint_classes.get(constraint_type) if isinstance(constraint_type, str) else None
            )
            if constraint_class is None:
                raise ValueError(f"Unknown constraint type: {constraint_type}")
            new_constraints.append(constraint_class(**item))

        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Parameters
        ----------
        columns
            The validated columns of the table.

        Returns
        -------
        The columns, unchanged.

        Raises
        ------
        ValueError
            If two or more columns share the same name.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns

334 

335 

class SchemaVersion(BaseModel):
    """Version information for a schema."""

    current: str
    """The schema's current version; this field is required."""

    compatible: list[str] = Field(default_factory=list)
    """The versions the schema is compatible with; empty by default."""

    read_compatible: list[str] = Field(default_factory=list)
    """The versions the schema is read compatible with; empty by default."""

347 

348 

class SchemaIdVisitor:
    """Visitor that fills in a schema's map of IDs to objects.

    The schema itself, its tables, and each table's columns and constraints
    are visited. The first object seen with a given ID is stored in the
    schema's ``id_map``; any ID seen again is recorded in the `duplicates`
    set instead, and no error is raised here. Reacting to duplicates is left
    to the caller.

    NOTE(review): the entries of ``table.indexes`` are not visited, so
    indexes never appear in the map - confirm that this is intended.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a visitor with no schema attached and no duplicates."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Register an object in the ID map of the current schema.

        Objects without an ``id`` attribute are ignored, and nothing is
        stored while no schema is attached to the visitor.
        """
        if not hasattr(obj, "id"):
            return
        identifier = obj.id
        target = self.schema
        if target is None:
            return
        if identifier in target.id_map:
            # Keep the first object; only remember that the ID was repeated.
            self.duplicates.add(identifier)
            return
        target.id_map[identifier] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit a schema and all of the tables it contains.

        The schema becomes the visitor's current schema and the set of
        duplicates is emptied before anything is added.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table, then its columns, then its constraints."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)

402 

403 

class Schema(BaseObject):
    """The database schema containing the tables."""

    class ValidationConfig:
        """Validation configuration which is specific to Felis."""

        _require_description = False
        """Flag to require a description for all objects.

        This is set by the `require_description` class method.
        """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Parameters
        ----------
        tables
            The validated tables of the schema.

        Returns
        -------
        The tables, unchanged.

        Raises
        ------
        ValueError
            If two or more tables share the same name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    @model_validator(mode="after")
    def create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        The map is only built while it is empty; if it already has entries
        this method returns without touching it.

        Returns
        -------
        This schema.

        Raises
        ------
        ValueError
            If the same ID is used by more than one visited object. The
            duplicated IDs are listed in sorted order so that the message
            is the same from run to run.
        """
        if self.id_map:
            logger.debug("ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug("ID map contains %d objects", len(self.id_map))
        if visitor.duplicates:
            # ``duplicates`` is a set; sort it to get a stable message.
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Raises
        ------
        KeyError
            If no object with the given ID is in the ID map.
        """
        if id not in self:
            raise KeyError(f"Object with ID '{id}' not found in schema")
        return self.id_map[id]

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map

    @classmethod
    def require_description(cls, rd: bool = True) -> None:
        """Set whether a description is required for all objects.

        This includes the schema, tables, columns, and constraints.

        Users should call this method to set the requirement for a description
        when validating schemas, rather than change the flag value directly.

        Parameters
        ----------
        rd
            The new value of the flag; defaults to `True`.
        """
        logger.debug("Setting description requirement to '%s'", rd)
        cls.ValidationConfig._require_description = rd

    @classmethod
    def is_description_required(cls) -> bool:
        """Return whether a description is required for all objects."""
        return cls.ValidationConfig._require_description