Coverage for python/felis/datamodel.py: 53%

301 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-01 15:16 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import logging 

25import re 

26from collections.abc import Mapping, Sequence 

27from enum import StrEnum, auto 

28from typing import Annotated, Any, Literal, TypeAlias 

29 

30from astropy import units as units # type: ignore 

31from astropy.io.votable import ucd # type: ignore 

32from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator 

33from sqlalchemy import dialects 

34from sqlalchemy import types as sqa_types 

35from sqlalchemy.engine import create_mock_engine 

36from sqlalchemy.engine.interfaces import Dialect 

37from sqlalchemy.types import TypeEngine 

38 

39from .db.sqltypes import get_type_func 

40from .types import FelisType 

41 

# Module-level logger for this module, following the stdlib convention of
# naming the logger after the module.
logger = logging.getLogger(__name__)

43 

44__all__ = ( 

45 "BaseObject", 

46 "Column", 

47 "CheckConstraint", 

48 "Constraint", 

49 "DescriptionStr", 

50 "ForeignKeyConstraint", 

51 "Index", 

52 "Schema", 

53 "SchemaVersion", 

54 "Table", 

55 "UniqueConstraint", 

56) 

57 

# Shared Pydantic configuration; every Felis model picks this up through
# ``BaseObject.model_config``.
CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

# Annotated alias so the length constraint is enforced anywhere the type is
# used as a field annotation (e.g. ``BaseObject.description``).
DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Define a type for a description string, which must be three or more
characters long. Stripping of whitespace is done globally on all str fields."""

73 

74 

class BaseObject(BaseModel):
    """Base class for all Felis objects."""

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """The name of the database object.

    All Felis database objects must have a name.
    """

    id: str = Field(alias="@id")
    """The unique identifier of the database object.

    All Felis database objects must have a unique identifier.
    """

    description: DescriptionStr | None = None
    """A description of the database object."""

    @model_validator(mode="after")  # type: ignore[arg-type]
    @classmethod
    def check_description(cls, obj: BaseObject, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required.

        The check runs only when the validation context sets
        ``require_description`` to a truthy value; otherwise the object is
        returned untouched.
        """
        ctx = info.context
        if not ctx or not ctx.get("require_description", False):
            return obj
        description = obj.description
        # Covers both a missing (None) and an empty-string description.
        if not description:
            raise ValueError("Description is required and must be non-empty")
        if len(description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return obj

108 

109 

class DataType(StrEnum):
    """`Enum` representing the data types supported by Felis.

    As a `StrEnum` with `auto()` values, each member's value is its
    lower-cased name (e.g. ``DataType.double == "double"``), so members
    compare equal to the plain datatype strings used in schema data.
    """

    # Boolean and integer types.
    boolean = auto()
    byte = auto()
    short = auto()
    int = auto()
    long = auto()
    # Floating-point types.
    float = auto()
    double = auto()
    # Character and string types.
    char = auto()
    string = auto()
    unicode = auto()
    text = auto()
    # Binary and temporal types.
    binary = auto()
    timestamp = auto()

126 

127 

# Mock engines are used only to obtain Dialect objects for type compilation;
# executor=None means no statement is ever executed and no connection made.
_DIALECTS = {
    "mysql": create_mock_engine("mysql://", executor=None).dialect,
    "postgresql": create_mock_engine("postgresql://", executor=None).dialect,
}
"""Dictionary of dialect names to SQLAlchemy dialects."""

# Idiom fix: direct attribute access instead of getattr() with a constant
# string literal (flake8-bugbear B009). Semantics are identical: both forms
# resolve the same submodule attribute on ``sqlalchemy.dialects``.
_DIALECT_MODULES = {"mysql": dialects.mysql, "postgresql": dialects.postgresql}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form ``type(length)``."""

139 

140 

def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a type string such as ``"VARCHAR(128)"`` into a SQLAlchemy
    type object.

    Parameters
    ----------
    type_string : `str`
        The type string to parse, in the form ``name`` or ``name(params)``.
    dialect : `Dialect`, optional
        Dialect whose module is used to look up the type class; if `None`,
        the generic ``sqlalchemy.types`` module is used.
    length : `int`, optional
        Fallback length, applied only when the resulting type object has a
        ``length`` attribute that the parameters did not already set.

    Returns
    -------
    `TypeEngine`
        The instantiated SQLAlchemy type object.

    Raises
    ------
    ValueError
        If the type string cannot be parsed, the dialect is unsupported, or
        the type name cannot be resolved.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    if dialect is None:
        type_class = getattr(sqa_types, type_name.upper(), None)
    else:
        try:
            dialect_module = _DIALECT_MODULES[dialect.name]
        except KeyError as e:
            # Chain the cause so the original lookup failure is preserved.
            raise ValueError(f"Unsupported dialect: {dialect}") from e
        type_class = getattr(dialect_module, type_name.upper(), None)

    if not type_class:
        # Bug fix: the original message interpolated ``type_class``, which is
        # always None here; report the name that failed to resolve instead.
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        # Convert purely-numeric parameters (lengths, precision) to int.
        args = [int(param) if param.isdigit() else param for param in params.split(",")]
        type_obj = type_class(*args)
    else:
        type_obj = type_class()

    # Apply the fallback length only when the type supports one and it was
    # not already set via the parsed parameters.
    if length is not None and hasattr(type_obj, "length") and type_obj.length is None:
        type_obj.length = length

    return type_obj

171 

172 

class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool | None = None
    """Whether the column can be ``NULL``.

    If `None`, this value was not set explicitly in the YAML data. In this
    case, it will be set to `False` for columns with numeric types and `True`
    otherwise.
    """

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_utype: str | None = Field(None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the column."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        The annotation is ``str | None`` (fixed): the field is optional and
        the body already guards against `None`.
        """
        if ivoa_ucd is not None:
            try:
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid.

        A column may carry a FITS unit or an IVOA unit, but not both; either
        one must parse as a valid astropy unit.
        """
        fits_unit = values.get("fits:tunit")
        ivoa_unit = values.get("ivoa:unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values

    @model_validator(mode="after")  # type: ignore[arg-type]
    @classmethod
    def validate_datatypes(cls, col: Column, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Only runs when the validation context sets
        ``check_redundant_datatypes``; each dialect override is compiled and
        compared against the compiled generic ``datatype``.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return col
        # Bug fix: the original called getattr() with the field *alias*
        # (e.g. "mysql:datatype"), which never matches an attribute name, so
        # this early exit could never fire. Use the real field names and skip
        # the check when no dialect override is present, since there is then
        # nothing to compare against.
        if all(getattr(col, f"{dialect}_datatype") is None for dialect in _DIALECTS):
            return col

        datatype = col.datatype
        length: int | None = col.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            if length is not None:
                datatype_obj = datatype_func(length)
            else:
                raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{col.id}'")
        else:
            datatype_obj = datatype_func()

        # Hoisted out of the loop: model_dump() is invariant per column.
        dumped = col.model_dump()
        for dialect_name, dialect in _DIALECTS.items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := dumped.get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is the same as 'datatype: {}' in column '{}'".format(
                            db_annotation, datatype_string, col.datatype, col.id
                        )
                    )
                else:
                    logger.debug(
                        "Valid type override of 'datatype: {}' with '{}: {}' in column '{}'".format(
                            col.datatype, db_annotation, datatype_string, col.id
                        )
                    )
                    logger.debug(
                        "Compiled datatype '{}' with {} compiled override '{}'".format(
                            datatype_obj.compile(dialect), dialect_name, db_datatype_obj.compile(dialect)
                        )
                    )

        return col

307 

308 

class Constraint(BaseObject):
    """A database table constraint.

    Base class for the concrete constraint kinds; `Table.create_constraints`
    dispatches raw constraint data to a subclass based on the ``@type`` key.
    """

    deferrable: bool = False
    """If `True` then this constraint will be declared as deferrable."""

    initially: str | None = None
    """Value for ``INITIALLY`` clause, only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint."""

    type: str | None = Field(None, alias="@type")
    """The type of the constraint."""

323 

324 

class CheckConstraint(Constraint):
    """A check constraint on a table."""

    expression: str
    """The expression for the check constraint."""

330 

331 

class UniqueConstraint(Constraint):
    """A unique constraint on a table."""

    columns: list[str]
    """The columns in the unique constraint."""

337 

338 

class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both."""
        has_columns = "columns" in values
        has_expressions = "expressions" in values
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values

360 

361 

class ForeignKeyConstraint(Constraint):
    """A foreign key constraint on a table.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """The columns referenced by the foreign key."""

373 

374 

class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create constraints from the ``constraints`` field.

        Each raw constraint dictionary is converted to the concrete
        `Constraint` subclass named by its ``@type`` key.
        """
        if "constraints" in values:
            # Dispatch table mapping the "@type" discriminator to a class.
            type_map: dict[str, type[Constraint]] = {
                "ForeignKey": ForeignKeyConstraint,
                "Unique": UniqueConstraint,
                "Check": CheckConstraint,
            }
            built: list[Constraint] = []
            for item in values["constraints"]:
                constraint_class = type_map.get(item["@type"])
                if constraint_class is None:
                    raise ValueError(f"Unknown constraint type: {item['@type']}")
                built.append(constraint_class(**item))
            values["constraints"] = built
        return values

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique."""
        names = [column.name for column in columns]
        if len(set(names)) != len(names):
            raise ValueError("Column names must be unique")
        return columns

432 

433 

class SchemaVersion(BaseModel):
    """The version of the schema."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""

445 

446 

class SchemaIdVisitor:
    """Visitor to build a Schema object's map of IDs to objects.

    Duplicate IDs are recorded in the `duplicates` set as they are
    encountered; seeing one does not raise. Only the first object with a
    given ID is stored in the map, which is acceptable because a
    ValidationError will be thrown by the schema's `model_validator` method
    whenever any duplicates exist.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new SchemaVisitor."""
        self.schema: "Schema" | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map."""
        # Guard clauses: nothing to record without a schema or an ID.
        if self.schema is None or not hasattr(obj, "id"):
            return
        obj_id = obj.id
        if obj_id in self.schema.id_map:
            self.duplicates.add(obj_id)
        else:
            self.schema.id_map[obj_id] = obj

    def visit_schema(self, schema: "Schema") -> None:
        """Visit the schema object that was added during initialization.

        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)

500 

501 

class Schema(BaseObject):
    """The database schema containing the tables."""

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            If two or more tables share a name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.

        Raises
        ------
        ValueError
            If duplicate IDs were found while building the map.
        """
        # Idiom: truthiness test instead of len(); the map is built only once.
        if self.id_map:
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        # Idiom: lazy %-style logging args and len(dict) over len(dict.keys()).
        logger.debug("Created schema ID map with %d objects", len(self.id_map))
        if visitor.duplicates:
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model.

        Builds the ID-to-object map once validation has completed.
        """
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Raises
        ------
        KeyError
            If no object with the given ID exists in the schema.
        """
        if id not in self:
            raise KeyError(f"Object with ID '{id}' not found in schema")
        return self.id_map[id]

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map