Coverage for python/felis/datamodel.py: 48%

319 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 02:49 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import logging 

25import re 

26from collections.abc import Mapping, Sequence 

27from enum import StrEnum, auto 

28from typing import Annotated, Any, Literal, TypeAlias 

29 

30from astropy import units as units # type: ignore 

31from astropy.io.votable import ucd # type: ignore 

32from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator 

33from sqlalchemy import dialects 

34from sqlalchemy import types as sqa_types 

35from sqlalchemy.engine import create_mock_engine 

36from sqlalchemy.engine.interfaces import Dialect 

37from sqlalchemy.types import TypeEngine 

38 

39from .db.sqltypes import get_type_func 

40from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode 

41 

logger = logging.getLogger(__name__)

# Public names of this module.
__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

# Settings shared by the Felis models:
#   populate_by_name      -- attributes may be populated by name
#   extra="forbid"        -- extra fields are not allowed
#   str_strip_whitespace  -- whitespace is stripped from string fields
CONFIG = ConfigDict(
    populate_by_name=True,
    extra="forbid",
    str_strip_whitespace=True,
)
"""Pydantic model configuration; the available options are described at
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Smallest number of characters accepted for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""String type for descriptions; values shorter than `DESCR_MIN_LENGTH`
characters are rejected. Whitespace stripping is configured globally for all
str fields through `CONFIG`."""

73 

74 

class BaseObject(BaseModel):
    """Common base model shared by all Felis objects.

    Provides the ``name``, ``@id``, ``description`` and ``votable:utype``
    fields plus an optional check that a description has been supplied.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object; every Felis object must have one."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object (``@id`` in the input);
    every Felis object must have one.
    """

    description: DescriptionStr | None = None
    """Optional description of the database object."""

    votable_utype: str | None = Field(default=None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Enforce a description when the validation context asks for one.

        The check only runs if the validation context contains a truthy
        ``require_description`` entry; otherwise the object is returned
        untouched.

        Raises
        ------
        ValueError
            If a description is required but missing, empty, or shorter than
            `DESCR_MIN_LENGTH` characters.
        """
        ctx = info.context
        if ctx and ctx.get("require_description", False):
            text = self.description
            if not text:
                raise ValueError("Description is required and must be non-empty")
            if len(text) < DESCR_MIN_LENGTH:
                raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self

110 

111 

112class DataType(StrEnum): 

113 """`Enum` representing the data types supported by Felis.""" 

114 

115 boolean = auto() 

116 byte = auto() 

117 short = auto() 

118 int = auto() 

119 long = auto() 

120 float = auto() 

121 double = auto() 

122 char = auto() 

123 string = auto() 

124 unicode = auto() 

125 text = auto() 

126 binary = auto() 

127 timestamp = auto() 

128 

129 

_DIALECTS = {
    dialect_name: create_mock_engine(f"{dialect_name}://", executor=None).dialect
    for dialect_name in ("mysql", "postgresql")
}
"""Dictionary of dialect names to SQLAlchemy dialects, obtained from mock
engines."""

_DIALECT_MODULES = {dialect_name: getattr(dialects, dialect_name) for dialect_name in ("mysql", "postgresql")}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression matching a type name optionally followed by a
parenthesized parameter list, e.g. ``type(length)``."""

141 

142 

def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Build a SQLAlchemy type object from a type string such as ``VARCHAR(32)``.

    Parameters
    ----------
    type_string
        Type name, optionally followed by a parenthesized, comma-separated
        parameter list. The name is upper-cased before lookup.
    dialect
        Dialect whose module is searched for the type. If `None`, the generic
        ``sqlalchemy.types`` module is searched instead.
    length
        Length applied to the resulting type when the type has a ``length``
        attribute that the type string left unset.

    Returns
    -------
    TypeEngine
        The instantiated SQLAlchemy type.

    Raises
    ------
    ValueError
        If the string cannot be parsed, the dialect is not one of the
        supported dialect modules, or the type name is unknown.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()

    # Pick the namespace to search: generic types or a dialect module.
    if dialect is None:
        namespace = sqa_types
    else:
        try:
            namespace = _DIALECT_MODULES[dialect.name]
        except KeyError:
            raise ValueError(f"Unsupported dialect: {dialect.name}") from None
    type_class = getattr(namespace, type_name.upper(), None)

    if not type_class:
        # Report the name that was requested; type_class is always None here.
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        # Strip each parameter so that "DECIMAL(10, 2)" yields the ints 10 and 2
        # rather than 10 and the string " 2".
        stripped = (param.strip() for param in params.split(","))
        args = [int(param) if param.isdigit() else param for param in stripped]
        type_obj = type_class(*args)
    else:
        type_obj = type_class()

    # Only fill in the length when the type supports one and none was parsed.
    if length is not None and hasattr(type_obj, "length") and getattr(type_obj, "length") is None:
        type_obj.length = length

    return type_obj

173 

174 

class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """The VOTable datatype of the column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is compatible with the column.

        Raises
        ------
        ValueError
            If a default value is combined with ``autoincrement``, or if the
            Python type of the default does not fit the column datatype.
        """
        value = self.value
        if value is None:
            return self
        if self.autoincrement is True:
            raise ValueError("Column cannot have both a default value and be autoincremented")

        felis_type = FelisType.felis_type(self.datatype)
        if felis_type.is_numeric:
            if felis_type in (Byte, Short, Int, Long):
                # bool is a subclass of int, so it has to be rejected explicitly;
                # otherwise ``value: true`` would pass for an integer column.
                if isinstance(value, bool) or not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
            elif felis_type in (Float, Double) and not isinstance(value, float):
                raise ValueError("Default value must be a decimal number for float and double columns")
        elif felis_type in (String, Char, Unicode, Text):
            if not isinstance(value, str):
                raise ValueError("Default value must be a string for string columns")
            if not value:
                raise ValueError("Default value must be a non-empty string for string columns")
        elif felis_type is Boolean and not isinstance(value, bool):
            raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that an IVOA UCD value parses against the controlled vocabulary.

        Raises
        ------
        ValueError
            If the UCD parser rejects the value.
        """
        if ivoa_ucd is not None:
            try:
                # NOTE(review): ``has_colon`` is driven by the presence of ";"
                # (the UCD word separator), not ":" -- possibly a typo; confirm
                # against the astropy ``parse_ucd`` documentation before changing.
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that at most one unit is given and that it is parseable.

        Both the alias keys (``fits:tunit``, ``ivoa:unit``) and the field
        names (``fits_tunit``, ``ivoa_unit``) are inspected, because the model
        configuration allows population by name.

        Raises
        ------
        ValueError
            If both FITS and IVOA units are given, or the unit is invalid.
        """
        if not isinstance(values, dict):
            # Nothing to inspect for non-mapping input; leave it to pydantic.
            return values

        fits_unit = values.get("fits:tunit") or values.get("fits_tunit")
        ivoa_unit = values.get("ivoa:unit") or values.get("ivoa_unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except (ValueError, TypeError) as e:
                # TypeError covers non-string input, which would otherwise
                # escape the validator as a non-validation error.
                raise ValueError(f"Invalid unit: {e}") from e

        return values

    @model_validator(mode="after")
    def check_datatypes(self, info: ValidationInfo) -> Column:
        """Check for dialect datatype overrides that duplicate ``datatype``.

        The check only runs if the validation context contains a truthy
        ``check_redundant_datatypes`` entry. For each supported dialect with an
        override, the override and the generic datatype are both compiled for
        that dialect; identical output means the override is redundant.

        Raises
        ------
        ValueError
            If a sized datatype has no length, or an override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self

        # No early exit when overrides are absent: the length requirement for
        # sized types below applies regardless of overrides.
        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            if length is None:
                raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{self.id}'")
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in _DIALECTS.items():
            db_annotation = f"{dialect_name}_datatype"
            datatype_string = getattr(self, db_annotation, None)
            if not datatype_string:
                continue

            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            generic_sql = datatype_obj.compile(dialect)
            override_sql = db_datatype_obj.compile(dialect)
            if generic_sql == override_sql:
                length_note = "" if length is None else f" with length {length}"
                raise ValueError(
                    f"'{db_annotation}: {datatype_string}' is a redundant override of "
                    f"'datatype: {self.datatype}' in column '{self.id}'{length_note}"
                )
            logger.debug(
                "Type override of 'datatype: %s' with '%s: %s' in column '%s' compiled to '%s' and '%s'",
                self.datatype,
                db_annotation,
                datatype_string,
                self.id,
                generic_sql,
                override_sql,
            )
        return self

323 

324 

class Constraint(BaseObject):
    """Base model for constraints attached to a database table."""

    deferrable: bool = False
    """Declare the constraint as deferrable when `True`."""

    initially: str | None = None
    """Value of the ``INITIALLY`` clause; only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Extra annotations attached to this constraint."""

    type: str | None = Field(default=None, alias="@type")
    """Constraint type (``@type`` in the input)."""

339 

340 

class CheckConstraint(Constraint):
    """Table constraint defined by a check expression."""

    expression: str
    """Expression of the check constraint."""

346 

347 

class UniqueConstraint(Constraint):
    """Table constraint requiring uniqueness over a set of columns."""

    columns: list[str]
    """Columns covered by the unique constraint."""

353 

354 

class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        A key whose value is `None` counts as not specified, so an explicit
        ``columns: null`` cannot satisfy the "must define" rule or trigger the
        "both defined" error on its own.

        Raises
        ------
        ValueError
            If both or neither of ``columns`` and ``expressions`` are given.
        """
        has_columns = values.get("columns") is not None
        has_expressions = values.get("expressions") is not None
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not (has_columns or has_expressions):
            raise ValueError("Must define columns or expressions")
        return values

376 

377 

class ForeignKeyConstraint(Constraint):
    """Foreign key constraint on a table.

    Foreign keys will be reflected in the TAP_SCHEMA keys and key_columns
    data.
    """

    columns: list[str]
    """Columns that make up the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """Columns the foreign key refers to (``referencedColumns`` in the input)."""

389 

390 

class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create typed constraint objects from the ``constraints`` field.

        Each mapping in ``constraints`` is turned into the constraint class
        selected by its ``@type`` entry (``ForeignKey``, ``Unique`` or
        ``Check``). Items that are already `Constraint` instances are kept as
        they are. The input mapping itself is not modified; a copy with the
        converted constraints is returned.

        Raises
        ------
        ValueError
            If an item is not a mapping, lacks a type, or has an unknown type.
        """
        if not isinstance(values, dict) or values.get("constraints") is None:
            # Nothing to convert; a null value is left for field validation.
            return values

        constraint_classes: dict[str, type[Constraint]] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }

        new_constraints: list[Constraint] = []
        for item in values["constraints"]:
            if isinstance(item, Constraint):
                new_constraints.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(f"Constraint must be a mapping, got: {type(item).__name__}")
            # Raise ValueError (reported as a validation error) instead of
            # letting a KeyError escape when the type is missing.
            constraint_type = item.get("@type", item.get("type"))
            if constraint_type is None:
                raise ValueError("Constraint is missing the '@type' field")
            try:
                constraint_class = constraint_classes[constraint_type]
            except (KeyError, TypeError):
                raise ValueError(f"Unknown constraint type: {constraint_type}") from None
            new_constraints.append(constraint_class(**item))

        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Raises
        ------
        ValueError
            If two columns of the table share a name.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns

448 

449 

class SchemaVersion(BaseModel):
    """Version information for a schema."""

    current: str
    """Current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """Versions the schema is compatible with."""

    read_compatible: list[str] = Field(default_factory=list)
    """Versions the schema is read compatible with."""

461 

462 

class SchemaIdVisitor:
    """Visitor to build a Schema object's map of IDs to objects.

    The visitor walks the schema, its tables, and each table's columns,
    constraints and indexes, registering every object under its ``id`` in the
    schema's ``id_map``.

    Duplicates are added to a set when they are encountered, which can be
    accessed via the `duplicates` attribute. The presence of duplicates does
    not raise an error here; only the first object with a given ID is added
    to the map.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new SchemaIdVisitor."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map.

        Objects without an ``id`` attribute are ignored, as is any call made
        before a schema has been set by `visit_schema`.
        """
        if self.schema is None or not hasattr(obj, "id"):
            return
        obj_id = getattr(obj, "id")
        if obj_id in self.schema.id_map:
            self.duplicates.add(obj_id)
        else:
            self.schema.id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit a schema object and everything it contains.

        The schema is remembered as the target of subsequent `add` calls and
        the set of duplicates is reset.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(self.schema)
        for table in self.schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object and its columns, constraints and indexes."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)
        # Indexes carry IDs like every other object, so they take part in the
        # ID map and in duplicate detection as well.
        for index in table.indexes:
            self.visit_index(index)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)

    def visit_index(self, index: Index) -> None:
        """Visit an index object."""
        self.add(index)

516 

517 

class Schema(BaseObject):
    """The database schema containing the tables."""

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            If two tables of the schema share a name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.

        Raises
        ------
        ValueError
            If the same ID is used by more than one object in the schema.
        """
        if self.id_map:
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug(f"Created schema ID map with {len(self.id_map)} objects")
        if visitor.duplicates:
            # Sort so the error message is stable from run to run.
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model; builds the ID map."""
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Raises
        ------
        KeyError
            If no object with the given ID is in the schema.
        """
        try:
            return self.id_map[id]
        except KeyError:
            raise KeyError(f"Object with ID '{id}' not found in schema") from None

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map