Coverage for python/felis/datamodel.py: 48%
330 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-14 09:10 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-14 09:10 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24import logging
25import re
26from collections.abc import Mapping, Sequence
27from enum import StrEnum, auto
28from typing import Annotated, Any, Literal, TypeAlias
30from astropy import units as units # type: ignore
31from astropy.io.votable import ucd # type: ignore
32from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
33from sqlalchemy import dialects
34from sqlalchemy import types as sqa_types
35from sqlalchemy.engine import create_mock_engine
36from sqlalchemy.engine.interfaces import Dialect
37from sqlalchemy.types import TypeEngine
39from .db.sqltypes import get_type_func
40from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
# Logger shared by every validator in this module.
logger = logging.getLogger(__name__)

# Names exported by ``from felis.datamodel import *``.
__all__ = (
    "BaseObject",
    "Column",
    "CheckConstraint",
    "Constraint",
    "DescriptionStr",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

# Shared configuration for the Felis models:
#   * fields may be populated by name as well as by alias,
#   * unknown input fields are rejected,
#   * leading/trailing whitespace is stripped from every ``str`` field.
CONFIG = ConfigDict(
    populate_by_name=True,
    extra="forbid",
    str_strip_whitespace=True,
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

# Whitespace has already been stripped (see ``CONFIG``) by the time the
# minimum-length constraint is checked.
DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Define a type for a description string, which must be three or more
characters long. Stripping of whitespace is done globally on all str fields."""
class BaseObject(BaseModel):
    """Base class for all Felis objects.

    Every Felis object carries a name, a unique ``@id``, an optional
    description and an optional VOTable utype.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """The name of the database object.

    All Felis database objects must have a name.
    """

    id: str = Field(..., alias="@id")
    """The unique identifier of the database object.

    All Felis database objects must have a unique identifier.
    """

    description: DescriptionStr | None = None
    """A description of the database object."""

    votable_utype: str | None = Field(default=None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Enforce a description when the validation context requests it.

        The check is active only if the validation context has a truthy
        ``require_description`` entry.

        Raises
        ------
        ValueError
            Raised if the description is missing or empty, or is shorter
            than ``DESCR_MIN_LENGTH`` characters.
        """
        # A missing or empty context means "no extra requirements".
        settings = info.context or {}
        if not settings.get("require_description", False):
            return self

        text = self.description
        # ``None`` and ``""`` are both rejected here.
        if not text:
            raise ValueError("Description is required and must be non-empty")
        if len(text) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self
112class DataType(StrEnum):
113 """`Enum` representing the data types supported by Felis."""
115 boolean = auto()
116 byte = auto()
117 short = auto()
118 int = auto()
119 long = auto()
120 float = auto()
121 double = auto()
122 char = auto()
123 string = auto()
124 unicode = auto()
125 text = auto()
126 binary = auto()
127 timestamp = auto()
# The mock engines never connect to a database; they exist only so that each
# dialect object can be used to compile SQLAlchemy types.
_DIALECTS = {
    dialect_name: create_mock_engine(f"{dialect_name}://", executor=None).dialect
    for dialect_name in ("mysql", "postgresql")
}
"""Dictionary of dialect names to SQLAlchemy dialects."""

# One SQLAlchemy dialect module per supported dialect name, looked up by name
# on ``sqlalchemy.dialects``.
_DIALECT_MODULES = {dialect_name: getattr(dialects, dialect_name) for dialect_name in _DIALECTS}
"""Dictionary of dialect names to SQLAlchemy dialect modules."""

# Group 1 is the type name, group 3 is the (optional) parenthesized
# parameter list without its parentheses.
_DATATYPE_REGEXP = re.compile(r"(\w+)(\((.*)\))?")
"""Regular expression to match data types in the form ``type(length)``."""
def string_to_typeengine(
    type_string: str, dialect: Dialect | None = None, length: int | None = None
) -> TypeEngine:
    """Convert a type string such as ``"VARCHAR(64)"`` to a SQLAlchemy type.

    Parameters
    ----------
    type_string : `str`
        A type name, optionally followed by a parenthesized, comma-separated
        list of parameters. Parameters made only of decimal digits are
        passed as `int`; all others are passed as (stripped) strings.
    dialect : `~sqlalchemy.engine.interfaces.Dialect`, optional
        Dialect whose SQLAlchemy module is searched for the type class. When
        not given, `sqlalchemy.types` is searched.
    length : `int`, optional
        Length assigned to the resulting type object when it has a
        ``length`` attribute that was not set from ``type_string``.

    Returns
    -------
    `~sqlalchemy.types.TypeEngine`
        The instantiated type object.

    Raises
    ------
    ValueError
        Raised if ``type_string`` cannot be parsed, the dialect is not
        supported, or no type class with that name exists.
    """
    match = _DATATYPE_REGEXP.search(type_string)
    if not match:
        raise ValueError(f"Invalid type string: {type_string}")

    type_name, _, params = match.groups()
    if dialect is None:
        type_module = sqa_types
    else:
        try:
            type_module = _DIALECT_MODULES[dialect.name]
        except KeyError:
            raise ValueError(f"Unsupported dialect: {dialect.name}") from None
    type_class = getattr(type_module, type_name.upper(), None)

    if not type_class:
        # Report the name that failed to resolve (``type_class`` is None here).
        raise ValueError(f"Unsupported type: {type_name}")

    if params:
        # Strip each parameter so that e.g. "DECIMAL(10, 2)" yields (10, 2).
        # ``isdecimal`` (unlike ``isdigit``) only accepts characters that
        # ``int()`` can parse.
        stripped = [param.strip() for param in params.split(",")]
        args = [int(param) if param.isdecimal() else param for param in stripped]
        type_obj = type_class(*args)
    else:
        type_obj = type_class()

    # Only fill in a length that the type string itself did not provide.
    if hasattr(type_obj, "length") and getattr(type_obj, "length") is None and length is not None:
        type_obj.length = length

    return type_obj
class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = Field(None, gt=0)
    """The length of the column."""

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """The PostgreSQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """The VOTable datatype of the column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Raises
        ------
        ValueError
            Raised if the column has a default value and is also
            autoincremented, or if the default value's Python type does not
            match the column's datatype.
        """
        value = self.value
        if value is None:
            return self
        if self.autoincrement is True:
            raise ValueError("Column cannot have both a default value and be autoincremented")
        felis_type = FelisType.felis_type(self.datatype)
        if felis_type.is_numeric:
            # ``bool`` is a subclass of ``int``, so it must be excluded
            # explicitly or ``True``/``False`` would pass as integers.
            if felis_type in (Byte, Short, Int, Long) and (
                not isinstance(value, int) or isinstance(value, bool)
            ):
                raise ValueError("Default value must be an int for integer type columns")
            elif felis_type in (Float, Double) and not isinstance(value, float):
                raise ValueError("Default value must be a decimal number for float and double columns")
        elif felis_type in (String, Char, Unicode, Text):
            if not isinstance(value, str):
                raise ValueError("Default value must be a string for string columns")
            if not len(value):
                raise ValueError("Default value must be a non-empty string for string columns")
        elif felis_type is Boolean and not isinstance(value, bool):
            raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        The value is parsed with `astropy.io.votable.ucd.parse_ucd`, checking
        the controlled vocabulary.

        Raises
        ------
        ValueError
            Raised if the UCD cannot be parsed.
        """
        if ivoa_ucd is not None:
            try:
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that units are valid.

        At most one of the FITS and IVOA units may be given, and the one
        that is given must be accepted by `astropy.units.Unit`.

        Raises
        ------
        ValueError
            Raised if both units are given or the unit cannot be parsed.
        """
        if not isinstance(values, Mapping):
            # Not raw mapping input; nothing to inspect here.
            return values
        # ``populate_by_name`` allows each unit to be supplied either by its
        # alias or by its field name, so look under both keys.
        fits_unit = values.get("fits:tunit", values.get("fits_tunit"))
        ivoa_unit = values.get("ivoa:unit", values.get("ivoa_unit"))

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        A missing length on a sized type is an error; a length on an
        unsized type only logs a warning.

        Raises
        ------
        ValueError
            Raised if the datatype is sized and no length is given.
        """
        if not isinstance(values, Mapping):
            return values
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        # The identifier may be given by alias (``@id``) or by field name.
        column_id = values.get("@id", values.get("id"))
        location = f" in column '{column_id}'" if column_id is not None else ""
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(f"Length must be provided for type '{datatype}'" + location)
        elif not felis_type.is_sized and length is not None:
            logger.warning("The datatype '%s' does not support a specified length%s", datatype, location)
        return values

    @model_validator(mode="after")
    def check_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Active only when the validation context has a truthy
        ``check_redundant_datatypes`` entry. Each dialect-specific datatype
        override is compiled for its dialect and compared with the compiled
        generic ``datatype``; an override that compiles identically is
        redundant.

        Raises
        ------
        ValueError
            Raised if an override compiles to the same type as ``datatype``.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self

        # Field names are ``<dialect>_datatype``. An exit when no overrides
        # are set skips building the generic type for nothing.
        overrides = {
            dialect_name: getattr(self, f"{dialect_name}_datatype", None) for dialect_name in _DIALECTS
        }
        if all(override is None for override in overrides.values()):
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        datatype_obj = datatype_func(length) if felis_type.is_sized else datatype_func()

        for dialect_name, dialect in _DIALECTS.items():
            datatype_string = overrides[dialect_name]
            if not datatype_string:
                continue
            db_annotation = f"{dialect_name}_datatype"
            db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
            compiled = datatype_obj.compile(dialect)
            db_compiled = db_datatype_obj.compile(dialect)
            if compiled == db_compiled:
                raise ValueError(
                    f"'{db_annotation}: {datatype_string}' is a redundant override of "
                    f"'datatype: {self.datatype}' in column '{self.id}'"
                    + ("" if length is None else f" with length {length}")
                )
            logger.debug(
                "Type override of 'datatype: %s' with '%s: %s' in column '%s' compiled to '%s' and '%s'",
                self.datatype,
                db_annotation,
                datatype_string,
                self.id,
                compiled,
                db_compiled,
            )
        return self
class Constraint(BaseObject):
    """A database table constraint.

    The concrete constraint models in this module subclass this one and add
    their kind-specific fields.
    """

    deferrable: bool = False
    """If `True` then this constraint will be declared as deferrable."""

    initially: str | None = Field(default=None)
    """Value for ``INITIALLY`` clause, only used if ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Additional annotations for this constraint."""

    type: str | None = Field(default=None, alias="@type")
    """The type of the constraint."""
class CheckConstraint(Constraint):
    """A check constraint on a table, defined by a single expression."""

    expression: str
    """The expression evaluated by the check constraint."""
class UniqueConstraint(Constraint):
    """A unique constraint over one or more columns of a table."""

    columns: list[str]
    """The columns whose combined values must be unique."""
class Index(BaseObject):
    """A database table index.

    An index is defined on either a list of columns or a list of
    expressions; exactly one of the two must be supplied.
    """

    columns: list[str] | None = Field(default=None)
    """The columns in the index."""

    expressions: list[str] | None = Field(default=None)
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Require exactly one of ``columns`` or ``expressions`` in the input.

        Only the presence of the keys in the raw input is checked.

        Raises
        ------
        ValueError
            Raised if both keys are present, or if neither is.
        """
        has_columns = "columns" in values
        has_expressions = "expressions" in values
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not (has_columns or has_expressions):
            raise ValueError("Must define columns or expressions")
        return values
class ForeignKeyConstraint(Constraint):
    """A foreign key constraint on a table.

    These will be reflected in the TAP_SCHEMA keys and key_columns data.
    """

    columns: list[str]
    """The local columns that make up the foreign key."""

    referenced_columns: list[str] = Field(..., alias="referencedColumns")
    """The columns that the foreign key refers to."""
class Table(BaseObject):
    """A database table."""

    columns: Sequence[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Create constraints from the ``constraints`` field.

        Each raw constraint mapping is validated as the `Constraint` subclass
        selected by its ``@type`` (``ForeignKey``, ``Unique`` or ``Check``),
        using the same validation context as the table. Items that are
        already `Constraint` instances are kept unchanged. The input mapping
        is copied rather than modified, so the same raw data can be
        validated more than once.

        Raises
        ------
        ValueError
            Raised if a constraint is not a mapping or has an unknown or
            missing type.
        """
        if not isinstance(values, Mapping):
            return values
        raw_constraints = values.get("constraints")
        if raw_constraints is None:
            # Absent, or explicitly null: leave it to normal field validation.
            return values

        constraint_classes: dict[str, type[Constraint]] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }
        new_constraints: list[Constraint] = []
        for item in raw_constraints:
            if isinstance(item, Constraint):
                new_constraints.append(item)
                continue
            if not isinstance(item, Mapping):
                raise ValueError(f"Invalid constraint definition: {item!r}")
            # ``populate_by_name`` allows the type under ``type`` as well.
            constraint_type = item.get("@type", item.get("type"))
            constraint_class = constraint_classes.get(constraint_type)
            if constraint_class is None:
                raise ValueError(f"Unknown constraint type: {constraint_type}")
            # Pass the context on so options such as ``require_description``
            # also apply to constraints.
            new_constraints.append(constraint_class.model_validate(item, context=info.context))
        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Raises
        ------
        ValueError
            Raised if two columns share a name.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns
class SchemaVersion(BaseModel):
    """Version information for a schema: its current version plus the
    versions it is compatible with."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""
class SchemaIdVisitor:
    """Visitor that fills a Schema object's ``id_map`` with its objects.

    The visit covers the schema itself, its tables, and each table's columns
    and constraints. The first object seen with a given ``id`` is stored in
    the map. The ID of any later object that reuses it is recorded in
    `duplicates`, and no error is raised here. The caller is expected to
    reject schemas with duplicates.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new visitor that has no schema attached yet."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Register ``obj`` under its ``id``, or record the ID as a duplicate.

        Nothing happens if no schema has been visited yet or if ``obj`` has
        no ``id`` attribute.
        """
        target = self.schema
        if target is None or not hasattr(obj, "id"):
            return
        key = obj.id
        registry = target.id_map
        if key in registry:
            self.duplicates.add(key)
        else:
            registry[key] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Attach ``schema`` to this visitor and register every object in it.

        Duplicates left over from a previous visit are discarded first.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Register a table, then its columns, then its constraints."""
        self.add(table)
        for col in table.columns:
            self.visit_column(col)
        for con in table.constraints:
            self.visit_constraint(con)

    def visit_column(self, column: Column) -> None:
        """Register a single column."""
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Register a single constraint."""
        self.add(constraint)
class Schema(BaseObject):
    """The database schema containing the tables."""

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects."""

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            Raised if two tables share a name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    def _create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        This method should not be called by users. It is called automatically
        by the ``model_post_init()`` method. If the ID map is already
        populated, this method will return immediately.

        Raises
        ------
        ValueError
            Raised if any ID is used by more than one object. The duplicate
            IDs are listed in sorted order so the message is deterministic.
        """
        if self.id_map:
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        logger.debug("Created schema ID map with %d objects", len(self.id_map))
        if visitor.duplicates:
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def model_post_init(self, ctx: Any) -> None:
        """Post-initialization hook for the model; builds the ID map."""
        self._create_id_map()

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Raises
        ------
        KeyError
            Raised if no object with the given ID is in the schema.
        """
        try:
            return self.id_map[id]
        except KeyError:
            raise KeyError(f"Object with ID '{id}' not found in schema") from None

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema."""
        return id in self.id_map