Coverage for python/felis/datamodel.py: 52%

301 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-25 10:20 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import logging 

25import re 

26from collections.abc import Mapping, Sequence 

27from enum import StrEnum, auto 

28from typing import Annotated, Any, Literal, TypeAlias 

29 

30from astropy import units as units # type: ignore 

31from astropy.io.votable import ucd # type: ignore 

32from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator 

33from sqlalchemy import dialects 

34from sqlalchemy import types as sqa_types 

35from sqlalchemy.engine import create_mock_engine 

36from sqlalchemy.engine.interfaces import Dialect 

37from sqlalchemy.types import TypeEngine 

38 

39from .db.sqltypes import get_type_func 

40from .types import FelisType 

41 

logger = logging.getLogger(__name__)

# Public API of this module. ``DataType`` is part of the public interface
# (it is the type of ``Column.datatype``) and is therefore exported as well.
__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    "Constraint",
    "DataType",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

57 

CONFIG = ConfigDict(
    # Fields may be populated by their Python attribute name as well as by
    # their alias (e.g. ``id`` as well as ``@id``).
    populate_by_name=True,
    # Reject any input key that does not correspond to a declared field.
    extra="forbid",
    # Strip leading and trailing whitespace from every string field.
    str_strip_whitespace=True,
)
"""Pydantic model configuration for Felis objects, as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum number of characters allowed in a description."""

DescriptionStr: TypeAlias = Annotated[
    str,
    Field(min_length=DESCR_MIN_LENGTH),
]
"""String type for descriptions, which must contain at least
``DESCR_MIN_LENGTH`` (three) characters. Whitespace stripping is configured
globally for all `str` fields through ``CONFIG``."""

73 

74 

class BaseObject(BaseModel):
    """Common base class for every named Felis database object.

    Provides the ``name``, ``@id``, ``description`` and ``votable:utype``
    attributes, and an optional context-driven check that a description is
    present.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """The name of the database object.

    All Felis database objects must have a name.
    """

    id: str = Field(alias="@id")
    """The unique identifier of the database object.

    All Felis database objects must have a unique identifier.
    """

    description: DescriptionStr | None = None
    """A description of the database object."""

    votable_utype: str | None = Field(default=None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Enforce a description when the validation context asks for one.

        The check only runs when ``info.context`` is truthy and its
        ``require_description`` entry is truthy; otherwise the object is
        returned untouched.

        Raises
        ------
        ValueError
            If a description is required but is missing or empty, or is
            shorter than ``DESCR_MIN_LENGTH`` characters.
        """
        ctx = info.context
        if ctx and ctx.get("require_description", False):
            # ``description`` is either None or a str, so falsiness covers
            # both the missing and the empty case.
            if not self.description:
                raise ValueError("Description is required and must be non-empty")
            if len(self.description) < DESCR_MIN_LENGTH:
                raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self

110 

111 

112class DataType(StrEnum): 

113 """`Enum` representing the data types supported by Felis.""" 

114 

115 boolean = auto() 

116 byte = auto() 

117 short = auto() 

118 int = auto() 

119 long = auto() 

120 float = auto() 

121 double = auto() 

122 char = auto() 

123 string = auto() 

124 unicode = auto() 

125 text = auto() 

126 binary = auto() 

127 timestamp = auto() 

128 

129 

_DIALECTS = {
    name: create_mock_engine(f"{name}://", executor=None).dialect
    for name in ("mysql", "postgresql")
}
"""Mapping of supported dialect names to SQLAlchemy dialect instances,
obtained from mock engines."""

# Built after ``_DIALECTS``; presumably this relies on engine creation having
# imported the ``sqlalchemy.dialects`` submodules looked up here — TODO confirm.
_DIALECT_MODULES = {name: getattr(dialects, name) for name in _DIALECTS}
"""Mapping of supported dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression matching type strings such as ``VARCHAR`` or
``VARCHAR(64)``: group 1 is the type name and group 3 is the text inside the
parentheses, if present."""

141 

142 

def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a type string such as ``"VARCHAR(64)"`` into a SQLAlchemy type.

    Parameters
    ----------
    type_string
        The type string. The leading word is upper-cased and looked up as an
        attribute of ``sqlalchemy.types`` (or of the dialect module when
        ``dialect`` is given). Comma-separated parameters inside parentheses
        are passed positionally to the type's constructor; parameters made
        only of digits are converted to `int`, others are passed as strings.
        Surrounding whitespace on each parameter is ignored.
    dialect
        Optional dialect whose module is searched for the type name. Only the
        dialects present in ``_DIALECT_MODULES`` are supported.
    length
        Length to apply when the constructed type has a ``length`` attribute
        that was not set from ``type_string``.

    Returns
    -------
    TypeEngine
        The instantiated SQLAlchemy type.

    Raises
    ------
    ValueError
        If the type string cannot be parsed, the dialect is unsupported, or
        the type name is not found.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params_string = match.groups()
    if dialect is None:
        type_module: Any = sqa_types
    else:
        try:
            type_module = _DIALECT_MODULES[dialect.name]
        except KeyError:
            raise ValueError(f"Unsupported dialect: {dialect}") from None
    type_class = getattr(type_module, type_name.upper(), None)

    if not type_class:
        # Report the name that was looked up; ``type_class`` is None here.
        raise ValueError(f"Unsupported type: {type_name}")

    if params_string:
        # Strip each parameter so that e.g. "DECIMAL(10, 2)" yields (10, 2)
        # rather than (10, " 2").
        params = [param.strip() for param in params_string.split(",")]
        type_obj = type_class(*(int(param) if param.isdigit() else param for param in params))
    else:
        type_obj = type_class()

    # Only fill in the length when the type supports one and it was not
    # already given in ``type_string``.
    if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
        type_obj.length = length

    return type_obj

173 

174 

class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """The VOTable datatype of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            The UCD string, or `None` when the column has no UCD.

        Returns
        -------
        str or None
            The value, unchanged.

        Raises
        ------
        ValueError
            If ``astropy.io.votable.ucd.parse_ucd`` rejects the UCD.
        """
        if ivoa_ucd is not None:
            try:
                # NOTE(review): ``has_colon`` is keyed on the presence of ";"
                # (the UCD word separator); per astropy's docs ``has_colon``
                # controls whether ":" characters are accepted — confirm this
                # is intended.
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid.

        A column may carry either a FITS unit (``fits:tunit``) or an IVOA unit
        (``ivoa:unit``) but not both. Because ``populate_by_name`` is enabled,
        each may be supplied by its alias or by its field name
        (``fits_tunit`` / ``ivoa_unit``). The unit is validated by
        constructing an ``astropy.units.Unit`` from it.

        Raises
        ------
        ValueError
            If both units are given, or astropy rejects the unit string.
        """
        if not isinstance(values, dict):
            # Not raw mapping input; there are no unit keys to inspect.
            return values

        fits_unit = values.get("fits:tunit", values.get("fits_tunit"))
        ivoa_unit = values.get("ivoa:unit", values.get("ivoa_unit"))

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values

    @model_validator(mode="after")  # type: ignore[arg-type]
    @classmethod
    def validate_datatypes(cls, col: Column, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Only active when the validation context sets
        ``check_redundant_datatypes``. The base ``datatype`` is instantiated
        via ``get_type_func``; for every dialect in ``_DIALECTS`` that has a
        ``<dialect>_datatype`` override, both types are compiled for that
        dialect and identical output marks the override as redundant.

        Raises
        ------
        ValueError
            If a sized datatype has no length, or an override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return col

        # NOTE(review): an earlier early return here was meant to skip the
        # check when ``getattr(col, f"{dialect}:datatype", None)`` was set for
        # every dialect. The attributes are named ``<dialect>_datatype`` (the
        # colon form is only the alias), so that lookup always returned None
        # and the early return never fired. It was removed without changing
        # behavior — confirm whether skipping was actually intended.

        datatype = col.datatype
        # A length of 0 is treated the same as no length.
        length: int | None = col.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            if length is None:
                raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{col.id}'")
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in _DIALECTS.items():
            db_annotation = f"{dialect_name}_datatype"
            datatype_string = getattr(col, db_annotation, None)
            if not datatype_string:
                continue
            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            # Compile each type once and reuse the SQL for comparison and logging.
            datatype_sql = datatype_obj.compile(dialect)
            db_datatype_sql = db_datatype_obj.compile(dialect)
            if datatype_sql == db_datatype_sql:
                raise ValueError(
                    "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                        db_annotation,
                        datatype_string,
                        col.datatype,
                        col.id,
                        "" if length is None else f" with length {length}",
                    )
                )
            logger.debug(
                "Type override of 'datatype: %s' with '%s: %s' in column '%s' compiled to '%s' and '%s'",
                col.datatype,
                db_annotation,
                datatype_string,
                col.id,
                datatype_sql,
                db_datatype_sql,
            )
        return col

308 

309 

class Constraint(BaseObject):
    """A database table constraint.

    Base class of `CheckConstraint`, `UniqueConstraint` and
    `ForeignKeyConstraint`.
    """

    deferrable: bool = False
    """If `True` then this constraint will be declared as deferrable."""

    initially: str | None = None
    """Value for ``INITIALLY`` clause, only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint."""

    # Populated from the ``@type`` key of the input; ``Table.create_constraints``
    # reads the same key to decide which subclass to build.
    type: str | None = Field(None, alias="@type")
    """The type of the constraint."""

324 

325 

class CheckConstraint(Constraint):
    """A check constraint on a table.

    Built by ``Table.create_constraints`` for items whose ``@type`` is
    ``Check``.
    """

    expression: str
    """The expression for the check constraint."""

331 

332 

class UniqueConstraint(Constraint):
    """A unique constraint on a table.

    Built by ``Table.create_constraints`` for items whose ``@type`` is
    ``Unique``.
    """

    # NOTE(review): this model does not check that the list is non-empty or
    # that the entries refer to existing columns — confirm this is validated
    # elsewhere.
    columns: list[str]
    """The columns in the unique constraint."""

338 

339 

class Index(BaseObject):
    """A database table index.

    An index is defined either on a list of columns or on a list of
    expressions; exactly one of the two keys must be present in the input.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Require exactly one of ``columns`` or ``expressions`` in the input.

        Only key presence is checked, so an explicit ``None`` value still
        counts as present.

        Raises
        ------
        ValueError
            If both keys are present, or neither is.
        """
        has_columns = "columns" in values
        has_expressions = "expressions" in values
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not (has_columns or has_expressions):
            raise ValueError("Must define columns or expressions")
        return values

361 

362 

class ForeignKeyConstraint(Constraint):
    """A foreign key constraint on a table.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    Built by ``Table.create_constraints`` for items whose ``@type`` is
    ``ForeignKey``.
    """

    columns: list[str]
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """The columns referenced by the foreign key."""

    @model_validator(mode="after")
    def check_column_counts(self) -> ForeignKeyConstraint:
        """Check that the foreign key and referenced column lists pair up.

        In SQL a foreign key pairs each of its columns with the referenced
        column at the same position, so both lists must have equal length.

        Raises
        ------
        ValueError
            If ``columns`` and ``referenced_columns`` differ in length.
        """
        if len(self.columns) != len(self.referenced_columns):
            raise ValueError(
                f"Foreign key '{self.id}' has {len(self.columns)} column(s) but "
                f"{len(self.referenced_columns)} referenced column(s)"
            )
        return self

374 

375 

class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create constraint objects from the ``constraints`` field.

        Each raw constraint mapping is converted into the `Constraint`
        subclass selected by its ``@type`` entry (or ``type``, since
        ``populate_by_name`` is enabled): ``ForeignKey``, ``Unique`` or
        ``Check``. Items that are already `Constraint` instances are kept
        as-is, so the same input can safely be validated more than once.

        The caller's mapping is not modified; a shallow copy holding the
        converted constraints is returned instead.

        Raises
        ------
        ValueError
            If a constraint has a missing or unknown type.
        """
        if not isinstance(values, dict) or "constraints" not in values:
            return values

        constraint_classes: dict[str, type[Constraint]] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }
        new_constraints: list[Constraint] = []
        for item in values["constraints"]:
            if isinstance(item, Constraint):
                new_constraints.append(item)
                continue
            constraint_type = item.get("@type", item.get("type"))
            constraint_class = constraint_classes.get(constraint_type)
            if constraint_class is None:
                raise ValueError(f"Unknown constraint type: {constraint_type}")
            # NOTE(review): the validation context (e.g. ``require_description``)
            # is not forwarded when constraints are built here, so
            # context-dependent checks do not apply to them.
            new_constraints.append(constraint_class(**item))
        # Return a copy rather than mutating the caller's input mapping.
        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Raises
        ------
        ValueError
            If two columns share the same ``name``.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns

433 

434 

class SchemaVersion(BaseModel):
    """The version of the schema.

    Unlike the `BaseObject` subclasses, this model derives directly from
    `BaseModel` and does not set ``model_config = CONFIG``.
    """

    # NOTE(review): without ``CONFIG`` this model falls back to Pydantic's
    # defaults (unknown keys are not rejected, strings are not stripped) —
    # confirm the difference from the other models is intended.
    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""

443 

444 read_compatible: list[str] = Field(default_factory=list) 

445 """The read compatible versions of the schema.""" 

446 

447 

class SchemaIdVisitor:
    """Walk a schema and fill its ``id_map`` with every object that has an ID.

    The schema itself, its tables, and each table's columns and constraints
    are visited in that order. The first object seen with a given ID is
    stored in the map; any later object with the same ID is only recorded in
    `duplicates`. No error is raised here — the caller decides how to react
    to duplicates.

    NOTE(review): ``table.indexes`` are not visited, so index IDs are neither
    mapped nor checked for duplicates — confirm this is intended.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a visitor that is not yet bound to any schema."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Record ``obj`` in the bound schema's ID map, or note a duplicate.

        Objects without an ``id`` attribute are ignored, as are all objects
        while no schema is bound.
        """
        if not hasattr(obj, "id"):
            return
        object_id = getattr(obj, "id")
        if self.schema is None:
            return
        id_map = self.schema.id_map
        if object_id in id_map:
            self.duplicates.add(object_id)
        else:
            id_map[object_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Bind ``schema``, reset the duplicate set, and visit the whole tree."""
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for tbl in schema.tables:
            self.visit_table(tbl)

    def visit_table(self, table: Table) -> None:
        """Record a table, then its columns, then its constraints."""
        self.add(table)
        for col in table.columns:
            self.visit_column(col)
        for con in table.constraints:
            self.visit_constraint(con)

    def visit_column(self, column: Column) -> None:
        """Record a column."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Record a constraint."""
        self.add(constraint)

501 

502 

class Schema(BaseObject):
    """The database schema containing the tables.

    After initialization (see ``model_post_init``) the ``id_map`` is filled
    by `SchemaIdVisitor`, and objects can be looked up by ID with
    ``schema[id]`` and tested with ``id in schema``.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            If two tables share the same ``name``.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.

        Raises
        ------
        ValueError
            If any ID occurs more than once in the schema. The duplicate IDs
            are listed in sorted order so the message is deterministic.
        """
        if self.id_map:
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug("Created schema ID map with %d objects", len(self.id_map))
        if visitor.duplicates:
            # ``duplicates`` is a set; sort it so the message order is stable.
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model; builds the ID map.

        ``ctx`` is accepted for Pydantic's hook signature and is unused.
        """
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Raises
        ------
        KeyError
            If no object with the given ID exists in the schema.
        """
        try:
            return self.id_map[id]
        except KeyError:
            raise KeyError(f"Object with ID '{id}' not found in schema") from None

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map