Coverage for python/felis/datamodel.py: 48%

330 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 02:49 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import logging 

25import re 

26from collections.abc import Mapping, Sequence 

27from enum import StrEnum, auto 

28from typing import Annotated, Any, Literal, TypeAlias 

29 

30from astropy import units as units # type: ignore 

31from astropy.io.votable import ucd # type: ignore 

32from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator 

33from sqlalchemy import dialects 

34from sqlalchemy import types as sqa_types 

35from sqlalchemy.engine import create_mock_engine 

36from sqlalchemy.engine.interfaces import Dialect 

37from sqlalchemy.types import TypeEngine 

38 

39from .db.sqltypes import get_type_func 

40from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode 

41 

logger = logging.getLogger(__name__)

__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    # Populate attributes by name.
    populate_by_name=True,
    # Do not allow extra fields.
    extra="forbid",
    # Strip whitespace from string fields.
    str_strip_whitespace=True,
)
"""Pydantic model configuration used as ``model_config`` by `BaseObject`.

The available options are described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum number of characters allowed in a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type of a description string, which must contain at least
``DESCR_MIN_LENGTH`` characters. Whitespace stripping is configured globally
for all str fields through ``CONFIG``."""

73 

74 

class BaseObject(BaseModel):
    """Common base class of the Felis data model objects.

    It declares the fields shared by every object (``name``, ``id``,
    ``description`` and ``votable_utype``) and a validator which can make
    the description mandatory.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object; required for every Felis object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object, given as ``@id`` in the
    input data; required for every Felis object.
    """

    description: DescriptionStr | None = None
    """Optional description of the database object."""

    votable_utype: str | None = Field(default=None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Require a description when the validation context asks for one.

        The check is only performed when the validation context contains a
        truthy ``require_description`` entry; otherwise the object is
        returned unchanged.

        Raises
        ------
        ValueError
            Raised if the description is required but missing, empty, or
            shorter than ``DESCR_MIN_LENGTH`` characters.
        """
        ctx = info.context
        required = bool(ctx) and ctx.get("require_description", False)
        if not required:
            return self

        text = self.description
        if text is None or text == "":
            raise ValueError("Description is required and must be non-empty")
        if len(text) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self

110 

111 

112class DataType(StrEnum): 

113 """`Enum` representing the data types supported by Felis.""" 

114 

115 boolean = auto() 

116 byte = auto() 

117 short = auto() 

118 int = auto() 

119 long = auto() 

120 float = auto() 

121 double = auto() 

122 char = auto() 

123 string = auto() 

124 unicode = auto() 

125 text = auto() 

126 binary = auto() 

127 timestamp = auto() 

128 

129 

_DIALECTS = {
    "mysql": create_mock_engine("mysql://", executor=None).dialect,
    "postgresql": create_mock_engine("postgresql://", executor=None).dialect,
}
"""Dictionary of dialect names to SQLAlchemy dialects, each taken from a
mock engine created for that dialect's URL."""

# Built after ``_DIALECTS`` so that both dictionaries cover the same names.
_DIALECT_MODULES = {dialect_name: getattr(dialects, dialect_name) for dialect_name in _DIALECTS}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form ``type(length)``; the
parenthesized part is optional."""

141 

142 

def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a type string such as ``VARCHAR(64)`` into a SQLAlchemy type.

    Parameters
    ----------
    type_string
        The type to create, as a name optionally followed by a
        parenthesized, comma-separated parameter list.
    dialect
        The dialect whose module is searched for the type. If `None`, the
        generic ``sqlalchemy.types`` module is searched instead.
    length
        Length applied to the created type object if it has a ``length``
        attribute which is still `None` after construction.

    Returns
    -------
    `sqlalchemy.types.TypeEngine`
        An instance of the type class found for the upper-cased type name.

    Raises
    ------
    ValueError
        Raised if the type string cannot be parsed, the dialect has no
        entry in ``_DIALECT_MODULES``, or no type with the given name
        exists in the searched module.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    if dialect is None:
        type_module: Any = sqa_types
    else:
        try:
            type_module = _DIALECT_MODULES[dialect.name]
        except KeyError:
            # The KeyError carries no extra information for the caller.
            raise ValueError(f"Unsupported dialect: {dialect}") from None
    type_class = getattr(type_module, type_name.upper(), None)

    if not type_class:
        # Report the requested name; ``type_class`` is always None here.
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        # Strip each parameter so that "DECIMAL(10, 2)" yields the ints 10
        # and 2 rather than 10 and the string " 2".
        stripped = (param.strip() for param in params.split(","))
        type_args = [int(param) if param.isdigit() else param for param in stripped]
        type_obj = type_class(*type_args)
    else:
        type_obj = type_class()

    # Only fill in the length when the type string did not already set one.
    if length is not None and hasattr(type_obj, "length") and type_obj.length is None:
        type_obj.length = length

    return type_obj

173 

174 

class Column(BaseObject):
    """A column in a table.

    In addition to the field declarations, the validators on this class
    check the default value, the IVOA UCD, the units, the length and,
    optionally, whether a dialect-specific datatype override is redundant.
    """

    datatype: DataType
    """The datatype of the column."""

    length: int | None = Field(None, gt=0)
    """The length of the column."""

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """The VOTable datatype of the column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid for the column.

        Nothing is checked when no default value is set.

        Raises
        ------
        ValueError
            Raised if a default value is combined with ``autoincrement``,
            or if the Python type of the value does not fit the column's
            datatype (an ``int`` which is not a ``bool`` for integer types,
            a ``float`` for float and double, a non-empty ``str`` for the
            string types, a ``bool`` for boolean).
        """
        value = self.value
        if value is None:
            return self
        if self.autoincrement is True:
            raise ValueError("Column cannot have both a default value and be autoincremented")

        felis_type = FelisType.felis_type(self.datatype)
        if felis_type.is_numeric:
            # bool is a subclass of int, so True/False would otherwise be
            # accepted as the default of an integer column.
            if felis_type in (Byte, Short, Int, Long) and (
                isinstance(value, bool) or not isinstance(value, int)
            ):
                raise ValueError("Default value must be an int for integer type columns")
            elif felis_type in (Float, Double) and not isinstance(value, float):
                raise ValueError("Default value must be a decimal number for float and double columns")
        elif felis_type in (String, Char, Unicode, Text):
            if not isinstance(value, str):
                raise ValueError("Default value must be a string for string columns")
            if not value:
                raise ValueError("Default value must be a non-empty string for string columns")
        elif felis_type is Boolean and not isinstance(value, bool):
            raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        Raises
        ------
        ValueError
            Raised if ``astropy.io.votable.ucd.parse_ucd`` rejects the UCD.
        """
        if ivoa_ucd is not None:
            try:
                # NOTE(review): ``has_colon`` is derived from the presence of
                # a semicolon, not a colon - confirm this is intended.
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid.

        The raw input is inspected, so the units are looked up under their
        aliases first and then under their field names, because the model
        configuration allows population by name.

        Raises
        ------
        ValueError
            Raised if both a FITS and an IVOA unit are given, or if
            ``astropy.units.Unit`` cannot parse the unit.
        """
        fits_unit = values.get("fits:tunit", values.get("fits_tunit"))
        ivoa_unit = values.get("ivoa:unit", values.get("ivoa_unit"))

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        A missing length on a sized type is an error. A length given for a
        type which is not sized only produces a warning in the log.

        Raises
        ------
        ValueError
            Raised if the datatype is sized and no length is given.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values

        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)

        # The ID is only used to make the messages easier to locate; it may
        # be given under its alias or under its field name.
        column_id = values.get("@id", values.get("id"))
        location = f" in column '{column_id}'" if column_id is not None else ""

        if felis_type.is_sized and length is None:
            raise ValueError(f"Length must be provided for type '{datatype}'" + location)
        elif not felis_type.is_sized and length is not None:
            logger.warning(f"The datatype '{datatype}' does not support a specified length" + location)
        return values

    @model_validator(mode="after")
    def check_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        The check only runs when the validation context contains a truthy
        ``check_redundant_datatypes`` entry. A dialect-specific datatype is
        redundant if it compiles, for that dialect, to the same string as
        the type derived from ``datatype``.

        Raises
        ------
        ValueError
            Raised if a dialect-specific datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self

        # The overrides live in attributes named ``<dialect>_datatype``; the
        # ``<dialect>:datatype`` spelling is only an input alias.
        overrides = {
            dialect_name: getattr(self, f"{dialect_name}_datatype", None) for dialect_name in _DIALECTS
        }
        if not any(overrides.values()):
            # Nothing to compare against.
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in _DIALECTS.items():
            datatype_string = overrides[dialect_name]
            if not datatype_string:
                continue
            db_annotation = f"{dialect_name}_datatype"
            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            compiled = datatype_obj.compile(dialect)
            db_compiled = db_datatype_obj.compile(dialect)
            if compiled == db_compiled:
                raise ValueError(
                    "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                        db_annotation,
                        datatype_string,
                        self.datatype,
                        self.id,
                        "" if length is None else f" with length {length}",
                    )
                )
            logger.debug(
                f"Type override of 'datatype: {self.datatype}' "
                f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
                f"compiled to '{compiled}' and "
                f"'{db_compiled}'"
            )
        return self

342 

343 

class Constraint(BaseObject):
    """A database table constraint.

    Holds the fields common to the specific constraint classes.
    """

    deferrable: bool = False
    """Declare this constraint as deferrable if `True`."""

    initially: str | None = None
    """Value for the ``INITIALLY`` clause, only used if ``deferrable`` is
    True.
    """

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint; empty by default."""

    type: str | None = Field(default=None, alias="@type")
    """The type of the constraint, given as ``@type`` in the input data."""

358 

359 

class CheckConstraint(Constraint):
    """A check constraint on a table, defined by an expression."""

    expression: str
    """The expression of the check constraint; required."""

365 

366 

class UniqueConstraint(Constraint):
    """A unique constraint on a table, defined over a list of columns."""

    columns: list[str]
    """The columns in the unique constraint; required."""

372 

373 

class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        A key which is present with the value `None` counts as not
        specified, since `None` is also the default of both fields.

        Raises
        ------
        ValueError
            Raised if both or neither of ``columns`` and ``expressions``
            are specified.
        """
        has_columns = values.get("columns") is not None
        has_expressions = values.get("expressions") is not None
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values

395 

396 

class ForeignKeyConstraint(Constraint):
    """A foreign key constraint on a table.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """The columns which make up the foreign key; required."""

    referenced_columns: list[str] = Field(..., alias="referencedColumns")
    """The columns referenced by the foreign key, given as
    ``referencedColumns`` in the input data; required.
    """

408 

409 

class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create constraints from the ``constraints`` field.

        Each mapping in ``constraints`` is turned into the constraint class
        selected by its ``@type`` entry. Items which already are
        `Constraint` objects are kept as they are. The input mapping is not
        modified; an updated copy is returned instead.

        Raises
        ------
        ValueError
            Raised if an item is neither a mapping nor a `Constraint`, has
            no ``@type`` entry, or has an unknown ``@type``.
        """
        if "constraints" not in values:
            return values

        constraint_classes: dict[str, type[Constraint]] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }
        new_constraints: list[Constraint] = []
        for item in values["constraints"]:
            if isinstance(item, Constraint):
                # Already built, e.g. when a table is created from objects.
                new_constraints.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(f"Invalid constraint definition: {item!r}")
            if "@type" not in item:
                raise ValueError(f"Constraint is missing the required '@type' field: {dict(item)!r}")
            constraint_type = item["@type"]
            # The isinstance check keeps an unhashable value from raising a
            # TypeError in the dictionary lookup.
            if not isinstance(constraint_type, str) or constraint_type not in constraint_classes:
                raise ValueError(f"Unknown constraint type: {constraint_type}")
            new_constraints.append(constraint_classes[constraint_type](**item))

        # Return a copy so that the caller's data is left untouched.
        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Raises
        ------
        ValueError
            Raised if two columns of the table have the same name.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns

467 

468 

class SchemaVersion(BaseModel):
    """The version of the schema.

    Only ``current`` is required; both compatibility lists are empty by
    default.
    """

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The versions of the schema this one is compatible with."""

    read_compatible: list[str] = Field(default_factory=list)
    """The versions of the schema this one is read compatible with."""

480 

481 

class SchemaIdVisitor:
    """Visitor which fills a Schema object's map of IDs to objects.

    An ID which is seen more than once is recorded in the `duplicates` set
    instead of raising an error, and only the first object with that ID is
    kept in the map. This should not matter, since a ValidationError will
    be thrown by the `model_validator` method if any duplicates are found in
    the schema.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a visitor with no schema and no recorded duplicates."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Register an object in the ID map of the current schema.

        Nothing happens if the object has no ``id`` attribute or if no
        schema has been visited yet.
        """
        if not hasattr(obj, "id"):
            return
        key = obj.id
        target = self.schema
        if target is None:
            return
        if key in target.id_map:
            # Keep the first object and remember the clash.
            self.duplicates.add(key)
            return
        target.id_map[key] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit a schema and everything it contains.

        The schema becomes the visitor's current schema and duplicates
        recorded by an earlier visit are discarded.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table, then its columns, then its constraints."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)

535 

536 

class Schema(BaseObject):
    """The database schema containing the tables.

    After initialization the schema holds a map of IDs to objects, which is
    used to look objects up with ``schema[id]`` and ``id in schema``.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            Raised if two tables of the schema have the same name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.

        Raises
        ------
        ValueError
            Raised if the same ID is used by more than one object. The
            duplicate IDs are listed in sorted order.
        """
        if self.id_map:
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug("Created schema ID map with %d objects", len(self.id_map))
        if visitor.duplicates:
            # Sort the set so that the message is the same on every run.
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model; builds the ID map."""
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Raises
        ------
        KeyError
            Raised if no object with the given ID is in the schema.
        """
        try:
            return self.id_map[id]
        except KeyError:
            raise KeyError(f"Object with ID '{id}' not found in schema") from None

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map