Coverage for python/felis/datamodel.py: 52%

301 statements  

coverage.py v7.5.0, created at 2024-04-24 02:46 -0700

# This file is part of felis.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import logging
import re
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, TypeAlias

from astropy import units as units  # type: ignore
from astropy.io.votable import ucd  # type: ignore
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
from sqlalchemy import dialects
from sqlalchemy import types as sqa_types
from sqlalchemy.engine import create_mock_engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import TypeEngine

from .db.sqltypes import get_type_func
from .types import FelisType

logger = logging.getLogger(__name__)

__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Define a type for a description string, which must be three or more
characters long. Stripping of whitespace is done globally on all str fields."""


class BaseObject(BaseModel):
    """Base class for all Felis objects."""

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """The name of the database object.

    All Felis database objects must have a name.
    """

    id: str = Field(alias="@id")
    """The unique identifier of the database object.

    All Felis database objects must have a unique identifier.
    """

    description: DescriptionStr | None = None
    """A description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required."""
        context = info.context
        if not context or not context.get("require_description", False):
            return self
        if self.description is None or self.description == "":
            raise ValueError("Description is required and must be non-empty")
        if len(self.description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self

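# Illustrative usage sketch (not part of the original module): the
# ``require_description`` flag is read from the Pydantic validation context,
# so callers opt in per call via ``model_validate(..., context=...)``. The
# object data and the function name below are hypothetical.
def _example_require_description() -> None:
    data = {"name": "ra", "@id": "#object.ra"}
    BaseObject.model_validate(data)  # Accepted; descriptions are optional by default.
    try:
        BaseObject.model_validate(data, context={"require_description": True})
    except ValueError as e:
        print(e)  # Rejected: description is required and must be non-empty.
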

class DataType(StrEnum):
    """`Enum` representing the data types supported by Felis."""

    boolean = auto()
    byte = auto()
    short = auto()
    int = auto()
    long = auto()
    float = auto()
    double = auto()
    char = auto()
    string = auto()
    unicode = auto()
    text = auto()
    binary = auto()
    timestamp = auto()

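# Illustrative usage sketch (not part of the original module): because
# ``DataType`` is a ``StrEnum`` whose members use ``auto()``, each member's
# value is its lower-case name, so members compare equal to the plain strings
# that appear in YAML schema files. The function name below is hypothetical.
def _example_datatype_members() -> None:
    assert DataType.double == "double"
    assert DataType("timestamp") is DataType.timestamp
    print(", ".join(member.value for member in DataType))
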

_DIALECTS = {
    "mysql": create_mock_engine("mysql://", executor=None).dialect,
    "postgresql": create_mock_engine("postgresql://", executor=None).dialect,
}
"""Dictionary of dialect names to SQLAlchemy dialects."""

_DIALECT_MODULES = {"mysql": getattr(dialects, "mysql"), "postgresql": getattr(dialects, "postgresql")}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form "type(length)"."""


def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a type string such as ``VARCHAR(128)`` into a SQLAlchemy type
    object, optionally resolved against a specific dialect."""
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    if dialect is None:
        type_class = getattr(sqa_types, type_name.upper(), None)
    else:
        try:
            dialect_module = _DIALECT_MODULES[dialect.name]
        except KeyError:
            raise ValueError(f"Unsupported dialect: {dialect}")
        type_class = getattr(dialect_module, type_name.upper(), None)

    if not type_class:
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        params = [int(param) if param.isdigit() else param for param in params.split(",")]
        type_obj = type_class(*params)
    else:
        type_obj = type_class()

    if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
        type_obj.length = length

    return type_obj

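# Illustrative usage sketch (not part of the original module):
# ``string_to_typeengine`` parses strings such as "VARCHAR(128)" into
# SQLAlchemy type objects, optionally resolved against one of the dialects in
# ``_DIALECTS`` and given a fallback length. The function name is hypothetical.
def _example_string_to_typeengine() -> None:
    generic = string_to_typeengine("VARCHAR(128)")
    print(type(generic).__name__, generic.length)  # VARCHAR 128

    # Resolve against the MySQL dialect module instead of the generic types.
    mysql_double = string_to_typeengine("DOUBLE", dialect=_DIALECTS["mysql"])
    print(mysql_double.compile(dialect=_DIALECTS["mysql"]))  # DOUBLE

    # The fallback length is applied only when the type string has no length.
    fallback = string_to_typeengine("VARCHAR", length=64)
    print(fallback.length)  # 64
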

class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool | None = None
    """Whether the column can be ``NULL``.

    If `None`, this value was not set explicitly in the YAML data. In this
    case, it will be set to `False` for columns with numeric types and `True`
    otherwise.
    """

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1."""

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """The VOTable datatype of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid."""
        if ivoa_ucd is not None:
            try:
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}")
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid."""
        fits_unit = values.get("fits:tunit")
        ivoa_unit = values.get("ivoa:unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}")

        return values

    @model_validator(mode="after")  # type: ignore[arg-type]
    @classmethod
    def validate_datatypes(cls, col: Column, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns."""
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return col
        if all(getattr(col, f"{dialect}:datatype", None) is not None for dialect in _DIALECTS.keys()):
            return col

        datatype = col.datatype
        length: int | None = col.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            if length is not None:
                datatype_obj = datatype_func(length)
            else:
                raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{col.id}'")
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in _DIALECTS.items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := col.model_dump().get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            col.datatype,
                            col.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        "Type override of 'datatype: {}' with '{}: {}' in column '{}' "
                        "compiled to '{}' and '{}'".format(
                            col.datatype,
                            db_annotation,
                            datatype_string,
                            col.id,
                            datatype_obj.compile(dialect),
                            db_datatype_obj.compile(dialect),
                        )
                    )
        return col

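# Illustrative usage sketch (not part of the original module): the
# ``check_redundant_datatypes`` flag is read from the validation context; a
# dialect override that compiles to the same SQL as the generic ``datatype``
# is rejected with a ValueError. The column data and function name below are
# hypothetical, and whether a particular pair is actually redundant depends on
# how ``felis.db.sqltypes`` maps the generic type for that dialect.
def _example_redundant_datatype_check() -> None:
    column = {
        "name": "ra",
        "@id": "#object.ra",
        "datatype": "double",
        "mysql:datatype": "DOUBLE",
    }
    Column.model_validate(column)  # Accepted; the extra check is off by default.
    try:
        Column.model_validate(column, context={"check_redundant_datatypes": True})
        print("override is not redundant for this dialect")
    except ValueError as e:
        print(f"redundant override rejected: {e}")
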

class Constraint(BaseObject):
    """A database table constraint."""

    deferrable: bool = False
    """If `True` then this constraint will be declared as deferrable."""

    initially: str | None = None
    """Value for ``INITIALLY`` clause, only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint."""

    type: str | None = Field(None, alias="@type")
    """The type of the constraint."""


class CheckConstraint(Constraint):
    """A check constraint on a table."""

    expression: str
    """The expression for the check constraint."""


class UniqueConstraint(Constraint):
    """A unique constraint on a table."""

    columns: list[str]
    """The columns in the unique constraint."""


class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both."""
        if "columns" in values and "expressions" in values:
            raise ValueError("Defining columns and expressions is not valid")
        elif "columns" not in values and "expressions" not in values:
            raise ValueError("Must define columns or expressions")
        return values

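# Illustrative usage sketch (not part of the original module): an index must
# name either columns or expressions, never both and never neither. The index
# data and function name below are hypothetical.
def _example_index_validation() -> None:
    Index.model_validate({"name": "idx_ra", "@id": "#idx_ra", "columns": ["#object.ra"]})
    try:
        Index.model_validate(
            {
                "name": "idx_bad",
                "@id": "#idx_bad",
                "columns": ["#object.ra"],
                "expressions": ["ra + dec"],
            }
        )
    except ValueError as e:
        print(e)  # Defining columns and expressions is not valid.
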

class ForeignKeyConstraint(Constraint):
    """A foreign key constraint on a table.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """The columns referenced by the foreign key."""


class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The MySQL engine to use for the table.

    For now this is a freeform string, but it could be constrained to a list
    of known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The MySQL charset to use for the table.

    For now this is a freeform string, but it could be constrained to a list
    of known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create constraints from the ``constraints`` field."""
        if "constraints" in values:
            new_constraints: list[Constraint] = []
            for item in values["constraints"]:
                if item["@type"] == "ForeignKey":
                    new_constraints.append(ForeignKeyConstraint(**item))
                elif item["@type"] == "Unique":
                    new_constraints.append(UniqueConstraint(**item))
                elif item["@type"] == "Check":
                    new_constraints.append(CheckConstraint(**item))
                else:
                    raise ValueError(f"Unknown constraint type: {item['@type']}")
            values["constraints"] = new_constraints
        return values

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique."""
        if len(columns) != len(set(column.name for column in columns)):
            raise ValueError("Column names must be unique")
        return columns

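# Illustrative usage sketch (not part of the original module): the ``@type``
# key in each constraint mapping selects the concrete ``Constraint`` subclass
# before normal field validation runs. The table data and function name below
# are hypothetical.
def _example_table_constraints() -> None:
    table = Table.model_validate(
        {
            "name": "object",
            "@id": "#object",
            "columns": [{"name": "id", "@id": "#object.id", "datatype": "long"}],
            "primaryKey": "#object.id",
            "constraints": [
                {
                    "name": "uq_object_id",
                    "@id": "#uq_object_id",
                    "@type": "Unique",
                    "columns": ["#object.id"],
                }
            ],
        }
    )
    print(type(table.constraints[0]).__name__)  # UniqueConstraint
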

class SchemaVersion(BaseModel):
    """The version of the schema."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""


class SchemaIdVisitor:
    """Visitor to build a Schema object's map of IDs to objects.

    Duplicates are added to a set when they are encountered, which can be
    accessed via the `duplicates` attribute; their presence does not raise an
    error here. Only the first object with a given ID is added to the map,
    but this should not matter in practice, because an error is raised when
    the schema is constructed if any duplicate IDs are found.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new SchemaIdVisitor."""
        self.schema: "Schema" | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map."""
        if hasattr(obj, "id"):
            obj_id = getattr(obj, "id")
            if self.schema is not None:
                if obj_id in self.schema.id_map:
                    self.duplicates.add(obj_id)
                else:
                    self.schema.id_map[obj_id] = obj

    def visit_schema(self, schema: "Schema") -> None:
        """Visit a schema object and build its ID map.

        This also sets an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(self.schema)
        for table in self.schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)


class Schema(BaseObject):
    """The database schema containing the tables."""

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique."""
        if len(tables) != len(set(table.name for table in tables)):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.
        """
        if len(self.id_map):
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug(f"Created schema ID map with {len(self.id_map.keys())} objects")
        if len(visitor.duplicates):
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model."""
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID."""
        if id not in self:
            raise KeyError(f"Object with ID '{id}' not found in schema")
        return self.id_map[id]

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map

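# Illustrative usage sketch (not part of the original module): every object
# carries a unique "@id", and the ``id_map`` built by ``SchemaIdVisitor``
# during ``model_post_init()`` makes those IDs addressable with ``in`` and
# subscript access. The schema data and function name below are hypothetical.
def _example_schema_lookup() -> None:
    schema = Schema.model_validate(
        {
            "name": "sdqa",
            "@id": "#sdqa",
            "tables": [
                {
                    "name": "sdqa_metric",
                    "@id": "#sdqa_metric",
                    "columns": [
                        {"name": "metric_id", "@id": "#sdqa_metric.metric_id", "datatype": "int"}
                    ],
                }
            ],
        }
    )
    assert "#sdqa_metric.metric_id" in schema
    column = schema["#sdqa_metric.metric_id"]
    print(type(column).__name__, column.name)  # Column metric_id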