Coverage for python/felis/datamodel.py: 53% (300 statements)

coverage.py v7.4.4, created at 2024-04-20 02:40 -0700

# This file is part of felis.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import logging
import re
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
from typing import Annotated, Any, Literal, TypeAlias

from astropy import units as units  # type: ignore
from astropy.io.votable import ucd  # type: ignore
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
from sqlalchemy import dialects
from sqlalchemy import types as sqa_types
from sqlalchemy.engine import create_mock_engine
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import TypeEngine

from .db.sqltypes import get_type_func
from .types import FelisType

logger = logging.getLogger(__name__)

__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Define a type for a description string, which must be three or more
characters long. Stripping of whitespace is done globally on all str fields."""
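
# Editor's note (illustrative sketch, not part of the original module): how
# the DescriptionStr constraint behaves on its own, assuming pydantic v2 is
# installed and felis is importable.
#
#   >>> from pydantic import TypeAdapter, ValidationError
#   >>> adapter = TypeAdapter(DescriptionStr)
#   >>> adapter.validate_python("Right ascension")
#   'Right ascension'
#   >>> try:
#   ...     adapter.validate_python("ra")  # shorter than DESCR_MIN_LENGTH
#   ... except ValidationError:
#   ...     print("too short")
#   too short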


class BaseObject(BaseModel):
    """Base class for all Felis objects."""

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """The name of the database object.

    All Felis database objects must have a name.
    """

    id: str = Field(alias="@id")
    """The unique identifier of the database object.

    All Felis database objects must have a unique identifier.
    """

    description: DescriptionStr | None = None
    """A description of the database object."""

    @model_validator(mode="after")  # type: ignore[arg-type]
    @classmethod
    def check_description(cls, object: BaseObject, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required."""
        context = info.context
        if not context or not context.get("require_description", False):
            return object
        if object.description is None or object.description == "":
            raise ValueError("Description is required and must be non-empty")
        if len(object.description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return object
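
# Editor's note (illustrative sketch, not part of the original module): the
# ``require_description`` flag is read from pydantic's validation context, so
# enforcing it might look like this; the names and IDs are hypothetical.
#
#   >>> obj = BaseObject.model_validate({"name": "ra", "@id": "#ra"})
#   >>> obj.description is None
#   True
#   >>> try:
#   ...     BaseObject.model_validate(
#   ...         {"name": "ra", "@id": "#ra"},
#   ...         context={"require_description": True},
#   ...     )
#   ... except ValueError:  # pydantic wraps this in a ValidationError
#   ...     print("description required")
#   description required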


class DataType(StrEnum):
    """`Enum` representing the data types supported by Felis."""

    boolean = auto()
    byte = auto()
    short = auto()
    int = auto()
    long = auto()
    float = auto()
    double = auto()
    char = auto()
    string = auto()
    unicode = auto()
    text = auto()
    binary = auto()
    timestamp = auto()
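
# Editor's note (illustrative, not part of the original module): because
# DataType is a StrEnum, members compare equal to their lowercase string
# values, which is how the YAML ``datatype`` field maps onto the enum.
#
#   >>> DataType("double") is DataType.double
#   True
#   >>> DataType.double == "double"
#   True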


_DIALECTS = {
    "mysql": create_mock_engine("mysql://", executor=None).dialect,
    "postgresql": create_mock_engine("postgresql://", executor=None).dialect,
}
"""Dictionary of dialect names to SQLAlchemy dialects."""

_DIALECT_MODULES = {"mysql": getattr(dialects, "mysql"), "postgresql": getattr(dialects, "postgresql")}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form "type(length)"."""


def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    if dialect is None:
        type_class = getattr(sqa_types, type_name.upper(), None)
    else:
        try:
            dialect_module = _DIALECT_MODULES[dialect.name]
        except KeyError:
            raise ValueError(f"Unsupported dialect: {dialect}")
        type_class = getattr(dialect_module, type_name.upper(), None)

    if not type_class:
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        params = [int(param) if param.isdigit() else param for param in params.split(",")]
        type_obj = type_class(*params)
    else:
        type_obj = type_class()

    if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
        type_obj.length = length

    return type_obj
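
# Editor's note (illustrative sketch, not part of the original module): how
# string_to_typeengine() resolves a type string, first against the generic
# SQLAlchemy types and then against a dialect module, with the ``length``
# argument as a fallback. Values shown are assumptions about typical output.
#
#   >>> t = string_to_typeengine("VARCHAR(128)")
#   >>> t.length
#   128
#   >>> t2 = string_to_typeengine("VARCHAR", _DIALECTS["mysql"], length=64)
#   >>> t2.length  # no explicit parameter, so the ``length`` argument is used
#   64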


class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool | None = None
    """Whether the column can be ``NULL``.

    If `None`, this value was not set explicitly in the YAML data. In this
    case, it will be set to `False` for columns with numeric types and `True`
    otherwise.
    """

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1."""

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the column."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid."""
        if ivoa_ucd is not None:
            try:
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}")
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid."""
        fits_unit = values.get("fits:tunit")
        ivoa_unit = values.get("ivoa:unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}")

        return values

    @model_validator(mode="after")  # type: ignore[arg-type]
    @classmethod
    def validate_datatypes(cls, col: Column, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns."""
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return col
        if all(getattr(col, f"{dialect}_datatype", None) is not None for dialect in _DIALECTS.keys()):
            return col

        datatype = col.datatype
        length: int | None = col.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            if length is not None:
                datatype_obj = datatype_func(length)
            else:
                raise ValueError(f"Length must be provided for sized type '{datatype}' in column '{col.id}'")
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in _DIALECTS.items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := col.model_dump().get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            col.datatype,
                            col.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        "Type override of 'datatype: {}' with '{}: {}' in column '{}' "
                        "compiled to '{}' and '{}'".format(
                            col.datatype,
                            db_annotation,
                            datatype_string,
                            col.id,
                            datatype_obj.compile(dialect),
                            db_datatype_obj.compile(dialect),
                        )
                    )
        return col
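
# Editor's note (illustrative sketch, not part of the original module):
# validating a single column definition as it might appear in a Felis YAML
# file. The column name, ID, UCD, and unit values are hypothetical.
#
#   >>> col = Column.model_validate(
#   ...     {
#   ...         "name": "ra",
#   ...         "@id": "#my_table.ra",
#   ...         "datatype": "double",
#   ...         "description": "Right ascension of the object",
#   ...         "ivoa:ucd": "pos.eq.ra",
#   ...         "fits:tunit": "deg",
#   ...     }
#   ... )
#   >>> col.datatype is DataType.double
#   True
#   >>> col.fits_tunit
#   'deg'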


class Constraint(BaseObject):
    """A database table constraint."""

    deferrable: bool = False
    """If `True` then this constraint will be declared as deferrable."""

    initially: str | None = None
    """Value for ``INITIALLY`` clause, only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint."""

    type: str | None = Field(None, alias="@type")
    """The type of the constraint."""


class CheckConstraint(Constraint):
    """A check constraint on a table."""

    expression: str
    """The expression for the check constraint."""


class UniqueConstraint(Constraint):
    """A unique constraint on a table."""

    columns: list[str]
    """The columns in the unique constraint."""


class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both."""
        if "columns" in values and "expressions" in values:
            raise ValueError("Defining columns and expressions is not valid")
        elif "columns" not in values and "expressions" not in values:
            raise ValueError("Must define columns or expressions")
        return values
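
# Editor's note (illustrative sketch, not part of the original module): an
# Index may name columns or expressions, but not both. Names and IDs below
# are hypothetical.
#
#   >>> idx = Index.model_validate(
#   ...     {"name": "idx_ra", "@id": "#idx_ra", "columns": ["#my_table.ra"]}
#   ... )
#   >>> try:
#   ...     Index.model_validate(
#   ...         {
#   ...             "name": "idx_bad",
#   ...             "@id": "#idx_bad",
#   ...             "columns": ["#my_table.ra"],
#   ...             "expressions": ["ra + 1"],
#   ...         }
#   ...     )
#   ... except ValueError:
#   ...     print("columns and expressions are mutually exclusive")
#   columns and expressions are mutually exclusive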


class ForeignKeyConstraint(Constraint):
    """A foreign key constraint on a table.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns")
    """The columns referenced by the foreign key."""


class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Create constraints from the ``constraints`` field."""
        if "constraints" in values:
            new_constraints: list[Constraint] = []
            for item in values["constraints"]:
                if item["@type"] == "ForeignKey":
                    new_constraints.append(ForeignKeyConstraint(**item))
                elif item["@type"] == "Unique":
                    new_constraints.append(UniqueConstraint(**item))
                elif item["@type"] == "Check":
                    new_constraints.append(CheckConstraint(**item))
                else:
                    raise ValueError(f"Unknown constraint type: {item['@type']}")
            values["constraints"] = new_constraints
        return values

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique."""
        if len(columns) != len(set(column.name for column in columns)):
            raise ValueError("Column names must be unique")
        return columns
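
# Editor's note (illustrative sketch, not part of the original module):
# create_constraints() runs before field validation and turns each
# ``@type``-tagged mapping into the matching Constraint subclass. The table,
# column, and constraint names below are hypothetical.
#
#   >>> tbl = Table.model_validate(
#   ...     {
#   ...         "name": "my_table",
#   ...         "@id": "#my_table",
#   ...         "primaryKey": "#my_table.id",
#   ...         "columns": [
#   ...             {"name": "id", "@id": "#my_table.id", "datatype": "long"},
#   ...             {"name": "ra", "@id": "#my_table.ra", "datatype": "double"},
#   ...         ],
#   ...         "constraints": [
#   ...             {
#   ...                 "name": "uniq_ra",
#   ...                 "@id": "#my_table.uniq_ra",
#   ...                 "@type": "Unique",
#   ...                 "columns": ["#my_table.ra"],
#   ...             }
#   ...         ],
#   ...     }
#   ... )
#   >>> type(tbl.constraints[0]).__name__
#   'UniqueConstraint'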


class SchemaVersion(BaseModel):
    """The version of the schema."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""
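
# Editor's note (illustrative, not part of the original module): SchemaVersion
# is a plain pydantic model; the ``Schema.version`` field below accepts either
# this model or a bare string. The version numbers are hypothetical.
#
#   >>> v = SchemaVersion(current="2.0.0", compatible=["1.0.0"], read_compatible=["0.9.0"])
#   >>> v.current
#   '2.0.0'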


class SchemaIdVisitor:
    """Visitor to build a Schema object's map of IDs to objects.

    Duplicates are added to a set when they are encountered, which can be
    accessed via the `duplicates` attribute. The presence of duplicates will
    not throw an error here. Only the first object with a given ID will be
    added to the map, but this should not matter, since an error will be
    raised when the `Schema` builds its ID map if any duplicates are found in
    the schema.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new SchemaIdVisitor."""
        self.schema: "Schema" | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map."""
        if hasattr(obj, "id"):
            obj_id = getattr(obj, "id")
            if self.schema is not None:
                if obj_id in self.schema.id_map:
                    self.duplicates.add(obj_id)
                else:
                    self.schema.id_map[obj_id] = obj

    def visit_schema(self, schema: "Schema") -> None:
        """Visit a schema object and build its ID map.

        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(self.schema)
        for table in self.schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object."""
        self.add(constraint)


class Schema(BaseObject):
    """The database schema containing the tables."""

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique."""
        if len(tables) != len(set(table.name for table in tables)):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.
        """
        if len(self.id_map):
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug(f"Created schema ID map with {len(self.id_map.keys())} objects")
        if len(visitor.duplicates):
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model."""
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID."""
        if id not in self:
            raise KeyError(f"Object with ID '{id}' not found in schema")
        return self.id_map[id]

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map

557 return id in self.id_map