Coverage for python / felis / datamodel.py: 32%
580 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:37 +0000
1"""Define Pydantic data models for Felis."""
3# This file is part of felis.
4#
5# Developed for the LSST Data Management System.
6# This product includes software developed by the LSST Project
7# (https://www.lsst.org).
8# See the COPYRIGHT file at the top-level directory of this distribution
9# for details of code ownership.
10#
11# This program is free software: you can redistribute it and/or modify
12# it under the terms of the GNU General Public License as published by
13# the Free Software Foundation, either version 3 of the License, or
14# (at your option) any later version.
15#
16# This program is distributed in the hope that it will be useful,
17# but WITHOUT ANY WARRANTY; without even the implied warranty of
18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19# GNU General Public License for more details.
20#
21# You should have received a copy of the GNU General Public License
22# along with this program. If not, see <https://www.gnu.org/licenses/>.
24from __future__ import annotations
26import json
27import logging
28import sys
29from collections.abc import Sequence
30from enum import StrEnum, auto
31from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar
33import yaml
34from astropy import units as units # type: ignore
35from astropy.io.votable import ucd # type: ignore
36from lsst.resources import ResourcePath, ResourcePathExpression
37from pydantic import (
38 BaseModel,
39 ConfigDict,
40 Field,
41 PrivateAttr,
42 ValidationError,
43 ValidationInfo,
44 field_serializer,
45 field_validator,
46 model_validator,
47)
48from pydantic_core import InitErrorDetails
50from .db._dialects import get_supported_dialects, string_to_typeengine
51from .db._sqltypes import get_type_func
52from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
logger = logging.getLogger(__name__)
# Module-level logger; the validators below emit debug/warning messages on it.

__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
    use_enum_values=False,  # Do not use enum values during serialization.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type for a description, which must be three or more characters long."""
class BaseObject(BaseModel):
    """Base model.

    All classes representing objects in the Felis data model should inherit
    from this class.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `BaseObject`
            The object being validated.

        Raises
        ------
        ValueError
            Raised if the description is missing, empty, or too short.
        """
        # Opt-in check: only runs when the caller supplies a validation
        # context with ``check_description`` set to a true value.
        context = info.context
        if not context or not context.get("check_description", False):
            return self
        if self.description is None or self.description == "":
            raise ValueError("Description is required and must be non-empty")
        # NOTE(review): the field-level ``min_length`` on ``DescriptionStr``
        # already rejects too-short descriptions before this model validator
        # runs, so this length check appears to be defensive only.
        if len(self.description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self
class DataType(StrEnum):
    """``Enum`` representing the data types supported by Felis."""

    # On a ``StrEnum``, ``auto()`` makes each member's value the lower-cased
    # member name, e.g. ``DataType.boolean == "boolean"``.
    boolean = auto()
    byte = auto()
    short = auto()
    int = auto()
    long = auto()
    float = auto()
    double = auto()
    char = auto()
    string = auto()
    unicode = auto()
    text = auto()
    binary = auto()
    timestamp = auto()
def validate_ivoa_ucd(ivoa_ucd: str) -> str:
    """Validate IVOA UCD values.

    Parameters
    ----------
    ivoa_ucd
        IVOA UCD value to check. A ``None`` value is passed through
        unchanged so that optional UCD fields validate cleanly.

    Returns
    -------
    `str`
        The IVOA UCD value if it is valid.

    Raises
    ------
    ValueError
        If the IVOA UCD value is invalid.
    """
    if ivoa_ucd is not None:
        try:
            # ``has_colon`` is enabled for composite UCDs, whose words are
            # separated by semicolons.
            ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
        except ValueError as e:
            # Chain the original astropy error so the root cause stays
            # visible in the traceback (was previously unchained).
            raise ValueError(f"Invalid IVOA UCD: {e}") from e
    return ivoa_ucd
class Column(BaseObject):
    """Column model."""

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid for the column's datatype.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if the default value conflicts with autoincrement or does
            not match the column's datatype.
        """
        if (value := self.value) is not None:
            # The walrus above already guarantees value is not None; the
            # previous redundant re-test was dropped.
            if self.autoincrement is True:
                raise ValueError("Column cannot have both a default value and be autoincremented")
            felis_type = FelisType.felis_type(self.datatype)
            if felis_type.is_numeric:
                if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
                elif felis_type in (Float, Double) and not isinstance(value, float):
                    raise ValueError("Default value must be a decimal number for float and double columns")
            elif felis_type in (String, Char, Unicode, Text):
                if not isinstance(value, str):
                    raise ValueError("Default value must be a string for string columns")
                if not len(value):
                    raise ValueError("Default value must be a non-empty string for string columns")
            elif felis_type is Boolean and not isinstance(value, bool):
                raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                # Chain the astropy error so the root cause stays visible.
                raise ValueError(f"Invalid unit: {e}") from e

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            # Only a warning: an extraneous length is tolerated but ignored.
            logger.warning(
                f"The datatype '{datatype}' does not support a specified length"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        return values

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self
        # Skip the check when the column carries no dialect-specific
        # overrides at all. The previous code called ``getattr`` with the
        # field *alias* (e.g. ``mysql:datatype``), which never resolves on
        # the model (the attributes are ``mysql_datatype`` etc.), so the
        # early return was dead code; use the field names instead.
        if all(
            getattr(self, f"{dialect}_datatype", None) is None
            for dialect in get_supported_dialects().keys()
        ):
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in get_supported_dialects().items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := self.model_dump().get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            self.datatype,
                            self.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        f"Type override of 'datatype: {self.datatype}' "
                        f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
                        f"compiled to '{datatype_obj.compile(dialect)}' and "
                        f"'{db_datatype_obj.compile(dialect)}'"
                    )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if precision is set on a non-timestamp column.
        """
        if self.precision is not None and self.datatype != "timestamp":
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used.
        """
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate
            return values
        context = info.context if info.context else {}
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, arraysize of 1 is disallowed
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    if context.get("force_unbounded_arraysize", False):
                        arraysize = "*"
                        logger.debug(
                            f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with datatype "
                            + f"'{values['datatype']}' and length '{length}'"
                        )
                    else:
                        # Bounded variable-length array, e.g. "64*".
                        arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'"
                    + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'"
                )
        else:
            logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'")
            if isinstance(values["votable:arraysize"], int):
                logger.warning(
                    f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is "
                    + "deprecated"
                )
                values["votable:arraysize"] = str(arraysize)
        return values

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType) -> str:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize.

        Returns
        -------
        `str`
            The serialized `DataType` value.
        """
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str) -> DataType:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize.

        Returns
        -------
        `DataType`
            The deserialized `DataType` value.
        """
        return DataType(value)

    @model_validator(mode="after")
    def check_votable_xtype(self) -> Column:
        """Set the default value for the ``votable_xtype`` field, which
        corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
        standard.

        Returns
        -------
        `Column`
            The column being validated.

        Notes
        -----
        This is currently only set automatically for the Felis ``timestamp``
        datatype.
        """
        if self.datatype == DataType.timestamp and self.votable_xtype is None:
            self.votable_xtype = "timestamp"
        return self
class Constraint(BaseObject):
    """Table constraint model."""

    deferrable: bool = False
    """Whether this constraint will be declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value for ``INITIALLY`` clause; only used if `deferrable` is
    `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Check that the ``INITIALLY`` clause is only used if `deferrable` is
        `True`.

        Returns
        -------
        `Constraint`
            The constraint being validated.

        Raises
        ------
        ValueError
            Raised if ``initially`` is set while ``deferrable`` is `False`.
        """
        if self.initially is not None and not self.deferrable:
            raise ValueError("INITIALLY clause can only be used if deferrable is True")
        return self
class CheckConstraint(Constraint):
    """Table check constraint model."""

    type: Literal["Check"] = Field("Check", alias="@type")
    """Type of the constraint."""

    expression: str
    """Expression for the check constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Ensure '@type' is included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        # Identity serializer: registering it forces pydantic to emit the
        # discriminator field when serializing.
        return value
class UniqueConstraint(Constraint):
    """Table unique constraint model."""

    type: Literal["Unique"] = Field("Unique", alias="@type")
    """Type of the constraint."""

    columns: list[str]
    """Columns in the unique constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Ensure '@type' is included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        # Identity serializer: registering it forces pydantic to emit the
        # discriminator field when serializing.
        return value
class ForeignKeyConstraint(Constraint):
    """Table foreign key constraint model.

    This constraint is used to define a foreign key relationship between two
    tables in the schema. There must be at least one column in the
    `columns` list, and at least one column in the `referenced_columns` list
    or a validation error will be raised.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
    """Type of the constraint."""

    columns: list[str] = Field(min_length=1)
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
    """The columns referenced by the foreign key."""

    on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is deleted."""

    on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is updated."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Ensure '@type' is included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        # Identity serializer: registering it forces pydantic to emit the
        # discriminator field when serializing.
        return value

    @model_validator(mode="after")
    def check_column_lengths(self) -> ForeignKeyConstraint:
        """Check that the `columns` and `referenced_columns` lists have the
        same length.

        Returns
        -------
        `ForeignKeyConstraint`
            The foreign key constraint being validated.

        Raises
        ------
        ValueError
            Raised if the `columns` and `referenced_columns` lists do not have
            the same length.
        """
        if len(self.columns) != len(self.referenced_columns):
            raise ValueError(
                "Columns and referencedColumns must have the same length for a ForeignKey constraint"
            )
        return self
# The ``type`` field (serialized as ``@type``) selects which concrete
# constraint model pydantic instantiates during validation.
_ConstraintType = Annotated[
    CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
]
"""Type alias for a constraint type."""
class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        Parameters
        ----------
        values
            Values of the index.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the index.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if neither
            are specified.
        """
        # Runs in "before" mode, so ``values`` contains only keys that were
        # explicitly present in the input data.
        if "columns" in values and "expressions" in values:
            raise ValueError("Defining columns and expressions is not valid")
        elif "columns" not in values and "expressions" not in values:
            raise ValueError("Must define columns or expressions")
        return values
# A bare string used to refer to a ``Column`` by its ``@id``; dereferenced to
# the actual ``Column`` object by ``ColumnGroup._dereference_columns``.
ColumnRef: TypeAlias = str
"""Type alias for a column reference."""
class ColumnGroup(BaseObject):
    """Column group model."""

    columns: list[ColumnRef | Column] = Field(..., min_length=1)
    """Columns in the group."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    table: Table | None = Field(None, exclude=True)
    """Reference to the parent table."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_unique_columns(self) -> ColumnGroup:
        """Check that the columns list contains unique items.

        Returns
        -------
        `ColumnGroup`
            The column group being validated.

        Raises
        ------
        ValueError
            Raised if the column IDs in the group are not unique.
        """
        # Entries may be plain ID strings or Column objects; compare by ID.
        column_ids = [col if isinstance(col, str) else col.id for col in self.columns]
        if len(column_ids) != len(set(column_ids)):
            raise ValueError("Columns in the group must be unique")
        return self

    def _dereference_columns(self) -> None:
        """Dereference ColumnRef to Column objects.

        Raises
        ------
        ValueError
            Raised if the parent table reference is unset, or if a referenced
            column ID cannot be found in the parent table.
        """
        if self.table is None:
            raise ValueError("ColumnGroup must have a reference to its parent table")

        dereferenced_columns: list[ColumnRef | Column] = []
        for col in self.columns:
            if isinstance(col, str):
                # Dereference ColumnRef to Column object
                try:
                    col_obj = self.table._find_column_by_id(col)
                except KeyError as e:
                    raise ValueError(f"Column '{col}' not found in table '{self.table.name}'") from e
                dereferenced_columns.append(col_obj)
            else:
                dereferenced_columns.append(col)

        self.columns = dereferenced_columns

    @field_serializer("columns")
    def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
        """Serialize columns as their IDs.

        Parameters
        ----------
        columns
            The columns to serialize.

        Returns
        -------
        `list` [ `str` ]
            The serialized column IDs.
        """
        return [col if isinstance(col, str) else col.id for col in columns]
class Table(BaseObject):
    """Table model."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """Primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field("MyISAM", alias="mysql:engine")
    """MySQL engine to use for the table."""

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """MySQL charset to use for the table."""

    columns: Sequence[Column]
    """Columns in the table."""

    column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
    """Column groups in the table."""

    constraints: list[_ConstraintType] = Field(default_factory=list)
    """Constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """Indexes on the table."""

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Parameters
        ----------
        columns
            The columns to check.

        Returns
        -------
        `list` [ `Column` ]
            The columns if they are unique.

        Raises
        ------
        ValueError
            Raised if column names are not unique.
        """
        if len(columns) != len(set(column.name for column in columns)):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def check_tap_table_index(self, info: ValidationInfo) -> Table:
        """Check that the table has a TAP table index.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised If the table is missing a TAP table index.
        """
        # Opt-in check controlled by the ``check_tap_table_indexes`` context
        # flag.
        context = info.context
        if not context or not context.get("check_tap_table_indexes", False):
            return self
        if self.tap_table_index is None:
            raise ValueError("Table is missing a TAP table index")
        return self

    @model_validator(mode="after")
    def check_tap_principal(self, info: ValidationInfo) -> Table:
        """Check that at least one column is flagged as 'principal' for TAP
        purposes.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the table is missing a column flagged as 'principal'.
        """
        # Opt-in check controlled by the ``check_tap_principal`` context flag.
        context = info.context
        if not context or not context.get("check_tap_principal", False):
            return self
        for col in self.columns:
            if col.tap_principal == 1:
                return self
        raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")

    def _find_column_by_id(self, id: str) -> Column:
        """Find a column by ID.

        Parameters
        ----------
        id
            The ID of the column to find.

        Returns
        -------
        `Column`
            The column with the given ID.

        Raises
        ------
        KeyError
            Raised if the column is not found.
        """
        for column in self.columns:
            if column.id == id:
                return column
        raise KeyError(f"Column '{id}' not found in table '{self.name}'")

    @model_validator(mode="after")
    def dereference_column_groups(self: Table) -> Table:
        """Dereference columns in column groups.

        Returns
        -------
        `Table`
            The table with dereferenced column groups.
        """
        # Wire each group back to its parent table before dereferencing so
        # that ``_dereference_columns`` can look up column IDs.
        for group in self.column_groups:
            group.table = self
            group._dereference_columns()
        return self
# NOTE: extends ``BaseModel`` directly rather than ``BaseObject``, since a
# version descriptor has no ``name`` or ``@id`` of its own.
class SchemaVersion(BaseModel):
    """Schema version model."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""
class SchemaIdVisitor:
    """Visit a schema and build the map of IDs to objects.

    Notes
    -----
    Duplicates are added to a set when they are encountered, which can be
    accessed via the ``duplicates`` attribute. The presence of duplicates will
    not throw an error. Only the first object with a given ID will be added to
    the map, but this should not matter, since a ``ValidationError`` will be
    thrown by the ``model_validator`` method if any duplicates are found in the
    schema.
    """

    def __init__(self) -> None:
        """Create a new visitor with no target schema."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Record an object in the target schema's ID map.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        if not hasattr(obj, "id") or self.schema is None:
            return
        obj_id = obj.id
        id_map = self.schema._id_map
        if obj_id in id_map:
            # Keep the first object registered under this ID; just remember
            # that a duplicate was seen.
            self.duplicates.add(obj_id)
        else:
            id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit the objects in a schema and build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table, including its columns and constraints.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a single column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a single constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)
1056T = TypeVar("T", bound=BaseObject)
1059def _strip_ids(data: Any) -> Any:
1060 """Recursively strip '@id' fields from a dictionary or list.
1062 Parameters
1063 ----------
1064 data
1065 The data to strip IDs from, which can be a dictionary, list, or any
1066 other type. Other types will be returned unchanged.
1067 """
1068 if isinstance(data, dict):
1069 data.pop("@id", None)
1070 for k, v in data.items():
1071 data[k] = _strip_ids(v)
1072 return data
1073 elif isinstance(data, list):
1074 return [_strip_ids(item) for item in data]
1075 else:
1076 return data
1079def _append_error(
1080 errors: list[InitErrorDetails],
1081 loc: tuple,
1082 input_value: Any,
1083 error_message: str,
1084 error_type: str = "value_error",
1085) -> None:
1086 """Append an error to the errors list.
1088 Parameters
1089 ----------
1090 errors : list[InitErrorDetails]
1091 The list of errors to append to.
1092 loc : tuple
1093 The location of the error in the schema.
1094 input_value : Any
1095 The input value that caused the error.
1096 error_message : str
1097 The error message to include in the context.
1098 """
1099 errors.append(
1100 {
1101 "type": error_type,
1102 "loc": loc,
1103 "input": input_value,
1104 "ctx": {"error": error_message},
1105 }
1106 )
class Schema(BaseObject, Generic[T]):
    """Database schema model.

    This represents a database schema, which contains one or more tables.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    # Private map populated after validation; keys are object "@id" values.
    # NOTE(review): the value type is ``Any`` — presumably BaseObject
    # instances; confirm against SchemaIdVisitor usage.
    _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
    """Map of IDs to objects."""
    @model_validator(mode="before")
    @classmethod
    def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Generate IDs for objects that do not have them.

        Parameters
        ----------
        values
            The values of the schema.
        info
            Validation context used to determine if ID generation is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the schema with generated IDs.

        Notes
        -----
        ID generation is opt-in via the ``id_generation`` key of the
        validation context. Generated IDs use ``#<name>`` for the schema,
        tables, constraints, and indexes, and ``#<table>.<name>`` for
        columns and column groups. Existing ``@id`` values are never
        overwritten.
        """
        context = info.context
        if not context or not context.get("id_generation", False):
            logger.debug("Skipping ID generation")
            return values
        schema_name = values["name"]
        if "@id" not in values:
            values["@id"] = f"#{schema_name}"
            logger.debug(f"Generated ID '{values['@id']}' for schema '{schema_name}'")
        if "tables" in values:
            for table in values["tables"]:
                if "@id" not in table:
                    table["@id"] = f"#{table['name']}"
                    logger.debug(f"Generated ID '{table['@id']}' for table '{table['name']}'")
                if "columns" in table:
                    for column in table["columns"]:
                        if "@id" not in column:
                            column["@id"] = f"#{table['name']}.{column['name']}"
                            logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
                if "columnGroups" in table:
                    for column_group in table["columnGroups"]:
                        if "@id" not in column_group:
                            column_group["@id"] = f"#{table['name']}.{column_group['name']}"
                            logger.debug(
                                f"Generated ID '{column_group['@id']}' for column group "
                                f"'{column_group['name']}'"
                            )
                if "constraints" in table:
                    for constraint in table["constraints"]:
                        if "@id" not in constraint:
                            constraint["@id"] = f"#{constraint['name']}"
                            logger.debug(
                                f"Generated ID '{constraint['@id']}' for constraint '{constraint['name']}'"
                            )
                if "indexes" in table:
                    for index in table["indexes"]:
                        if "@id" not in index:
                            index["@id"] = f"#{index['name']}"
                            logger.debug(f"Generated ID '{index['@id']}' for index '{index['name']}'")
        return values
1181 @field_validator("tables", mode="after")
1182 @classmethod
1183 def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
1184 """Check that table names are unique.
1186 Parameters
1187 ----------
1188 tables
1189 The tables to check.
1191 Returns
1192 -------
1193 `list` [ `Table` ]
1194 The tables if they are unique.
1196 Raises
1197 ------
1198 ValueError
1199 Raised if table names are not unique.
1200 """
1201 if len(tables) != len(set(table.name for table in tables)):
1202 raise ValueError("Table names must be unique")
1203 return tables
1205 @model_validator(mode="after")
1206 def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
1207 """Check that the TAP table indexes are unique.
1209 Parameters
1210 ----------
1211 info
1212 The validation context used to determine if the check is enabled.
1214 Returns
1215 -------
1216 `Schema`
1217 The schema being validated.
1218 """
1219 context = info.context
1220 if not context or not context.get("check_tap_table_indexes", False):
1221 return self
1222 table_indicies = set()
1223 for table in self.tables:
1224 table_index = table.tap_table_index
1225 if table_index is not None:
1226 if table_index in table_indicies:
1227 raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
1228 table_indicies.add(table_index)
1229 return self
1231 @model_validator(mode="after")
1232 def check_unique_constraint_names(self: Schema) -> Schema:
1233 """Check for duplicate constraint names in the schema.
1235 Returns
1236 -------
1237 `Schema`
1238 The schema being validated.
1240 Raises
1241 ------
1242 ValueError
1243 Raised if duplicate constraint names are found in the schema.
1244 """
1245 constraint_names = set()
1246 duplicate_names = []
1248 for table in self.tables:
1249 for constraint in table.constraints:
1250 constraint_name = constraint.name
1251 if constraint_name in constraint_names:
1252 duplicate_names.append(constraint_name)
1253 else:
1254 constraint_names.add(constraint_name)
1256 if duplicate_names:
1257 raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}")
1259 return self
1261 @model_validator(mode="after")
1262 def check_unique_index_names(self: Schema) -> Schema:
1263 """Check for duplicate index names in the schema.
1265 Returns
1266 -------
1267 `Schema`
1268 The schema being validated.
1270 Raises
1271 ------
1272 ValueError
1273 Raised if duplicate index names are found in the schema.
1274 """
1275 index_names = set()
1276 duplicate_names = []
1278 for table in self.tables:
1279 for index in table.indexes:
1280 index_name = index.name
1281 if index_name in index_names:
1282 duplicate_names.append(index_name)
1283 else:
1284 index_names.add(index_name)
1286 if duplicate_names:
1287 raise ValueError(f"Duplicate index names found in schema: {duplicate_names}")
1289 return self
1291 @model_validator(mode="after")
1292 def create_id_map(self: Schema) -> Schema:
1293 """Create a map of IDs to objects.
1295 Returns
1296 -------
1297 `Schema`
1298 The schema with the ID map created.
1300 Raises
1301 ------
1302 ValueError
1303 Raised if duplicate identifiers are found in the schema.
1304 """
1305 if self._id_map:
1306 logger.debug("Ignoring call to create_id_map() - ID map was already populated")
1307 return self
1308 visitor: SchemaIdVisitor = SchemaIdVisitor()
1309 visitor.visit_schema(self)
1310 if len(visitor.duplicates):
1311 raise ValueError(
1312 "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
1313 )
1314 logger.debug("Created ID map with %d entries", len(self._id_map))
1315 return self
1317 def _validate_column_id(
1318 self: Schema,
1319 column_id: str,
1320 loc: tuple,
1321 errors: list[InitErrorDetails],
1322 ) -> None:
1323 """Validate a column ID from a constraint and append errors if invalid.
1325 Parameters
1326 ----------
1327 schema : Schema
1328 The schema being validated.
1329 column_id : str
1330 The column ID to validate.
1331 loc : tuple
1332 The location of the error in the schema.
1333 errors : list[InitErrorDetails]
1334 The list of errors to append to.
1335 """
1336 if column_id not in self:
1337 _append_error(
1338 errors,
1339 loc,
1340 column_id,
1341 f"Column ID '{column_id}' not found in schema",
1342 )
1343 elif not isinstance(self[column_id], Column):
1344 _append_error(
1345 errors,
1346 loc,
1347 column_id,
1348 f"ID '{column_id}' does not refer to a Column object",
1349 )
1351 def _validate_foreign_key_column(
1352 self: Schema,
1353 column_id: str,
1354 table: Table,
1355 loc: tuple,
1356 errors: list[InitErrorDetails],
1357 ) -> None:
1358 """Validate a foreign key column ID from a constraint and append errors
1359 if invalid.
1361 Parameters
1362 ----------
1363 schema : Schema
1364 The schema being validated.
1365 column_id : str
1366 The foreign key column ID to validate.
1367 loc : tuple
1368 The location of the error in the schema.
1369 errors : list[InitErrorDetails]
1370 The list of errors to append to.
1371 """
1372 try:
1373 table._find_column_by_id(column_id)
1374 except KeyError:
1375 _append_error(
1376 errors,
1377 loc,
1378 column_id,
1379 f"Column '{column_id}' not found in table '{table.name}'",
1380 )
1382 @model_validator(mode="after")
1383 def check_constraints(self: Schema) -> Schema:
1384 """Check constraint objects for validity. This needs to be deferred
1385 until after the schema is fully loaded and the ID map is created.
1387 Raises
1388 ------
1389 pydantic.ValidationError
1390 Raised if any constraints are invalid.
1392 Returns
1393 -------
1394 `Schema`
1395 The schema being validated.
1396 """
1397 errors: list[InitErrorDetails] = []
1399 for table_index, table in enumerate(self.tables):
1400 for constraint_index, constraint in enumerate(table.constraints):
1401 column_ids: list[str] = []
1402 referenced_column_ids: list[str] = []
1404 if isinstance(constraint, ForeignKeyConstraint):
1405 column_ids += constraint.columns
1406 referenced_column_ids += constraint.referenced_columns
1407 elif isinstance(constraint, UniqueConstraint):
1408 column_ids += constraint.columns
1409 # No extra checks are required on CheckConstraint objects.
1411 # Validate the foreign key columns
1412 for column_id in column_ids:
1413 self._validate_column_id(
1414 column_id,
1415 (
1416 "tables",
1417 table_index,
1418 "constraints",
1419 constraint_index,
1420 "columns",
1421 column_id,
1422 ),
1423 errors,
1424 )
1425 # Check that the foreign key column is within the source
1426 # table.
1427 self._validate_foreign_key_column(
1428 column_id,
1429 table,
1430 (
1431 "tables",
1432 table_index,
1433 "constraints",
1434 constraint_index,
1435 "columns",
1436 column_id,
1437 ),
1438 errors,
1439 )
1441 # Validate the primary key (reference) columns
1442 for referenced_column_id in referenced_column_ids:
1443 self._validate_column_id(
1444 referenced_column_id,
1445 (
1446 "tables",
1447 table_index,
1448 "constraints",
1449 constraint_index,
1450 "referenced_columns",
1451 referenced_column_id,
1452 ),
1453 errors,
1454 )
1456 if errors:
1457 raise ValidationError.from_exception_data("Schema validation failed", errors)
1459 return self
1461 def __getitem__(self, id: str) -> BaseObject:
1462 """Get an object by its ID.
1464 Parameters
1465 ----------
1466 id
1467 The ID of the object to get.
1469 Raises
1470 ------
1471 KeyError
1472 Raised if the object with the given ID is not found in the schema.
1473 """
1474 if id not in self:
1475 raise KeyError(f"Object with ID '{id}' not found in schema")
1476 return self._id_map[id]
1478 def __contains__(self, id: str) -> bool:
1479 """Check if an object with the given ID is in the schema.
1481 Parameters
1482 ----------
1483 id
1484 The ID of the object to check.
1485 """
1486 return id in self._id_map
1488 def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
1489 """Find an object with the given type by its ID.
1491 Parameters
1492 ----------
1493 id
1494 The ID of the object to find.
1495 obj_type
1496 The type of the object to find.
1498 Returns
1499 -------
1500 BaseObject
1501 The object with the given ID and type.
1503 Raises
1504 ------
1505 KeyError
1506 If the object with the given ID is not found in the schema.
1507 TypeError
1508 If the object that is found does not have the right type.
1510 Notes
1511 -----
1512 The actual return type is the user-specified argument ``T``, which is
1513 expected to be a subclass of `BaseObject`.
1514 """
1515 obj = self[id]
1516 if not isinstance(obj, obj_type):
1517 raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
1518 return obj
1520 def get_table_by_column(self, column: Column) -> Table:
1521 """Find the table that contains a column.
1523 Parameters
1524 ----------
1525 column
1526 The column to find.
1528 Returns
1529 -------
1530 `Table`
1531 The table that contains the column.
1533 Raises
1534 ------
1535 ValueError
1536 If the column is not found in any table.
1537 """
1538 for table in self.tables:
1539 if column in table.columns:
1540 return table
1541 raise ValueError(f"Column '{column.name}' not found in any table")
1543 @classmethod
1544 def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema:
1545 """Load a `Schema` from a string representing a ``ResourcePath``.
1547 Parameters
1548 ----------
1549 resource_path
1550 The ``ResourcePath`` pointing to a YAML file.
1551 context
1552 Pydantic context to be used in validation.
1554 Returns
1555 -------
1556 `str`
1557 The ID of the object.
1559 Raises
1560 ------
1561 yaml.YAMLError
1562 Raised if there is an error loading the YAML data.
1563 ValueError
1564 Raised if there is an error reading the resource.
1565 pydantic.ValidationError
1566 Raised if the schema fails validation.
1567 """
1568 logger.debug(f"Loading schema from: '{resource_path}'")
1569 try:
1570 rp_stream = ResourcePath(resource_path).read()
1571 except Exception as e:
1572 raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e
1573 yaml_data = yaml.safe_load(rp_stream)
1574 return Schema.model_validate(yaml_data, context=context)
1576 @classmethod
1577 def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema:
1578 """Load a `Schema` from a file stream which should contain YAML data.
1580 Parameters
1581 ----------
1582 source
1583 The file stream to read from.
1584 context
1585 Pydantic context to be used in validation.
1587 Returns
1588 -------
1589 `Schema`
1590 The Felis schema loaded from the stream.
1592 Raises
1593 ------
1594 yaml.YAMLError
1595 Raised if there is an error loading the YAML file.
1596 pydantic.ValidationError
1597 Raised if the schema fails validation.
1598 """
1599 logger.debug("Loading schema from: '%s'", source)
1600 yaml_data = yaml.safe_load(source)
1601 return Schema.model_validate(yaml_data, context=context)
1603 def _model_dump(self, strip_ids: bool = False) -> dict[str, Any]:
1604 """Dump the schema as a dictionary with some default arguments
1605 applied.
1607 Parameters
1608 ----------
1609 strip_ids
1610 Whether to strip the IDs from the dumped data. Defaults to `False`.
1612 Returns
1613 -------
1614 `dict` [ `str`, `Any` ]
1615 The dumped schema data as a dictionary.
1616 """
1617 data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
1618 if strip_ids:
1619 data = _strip_ids(data)
1620 return data
1622 def dump_yaml(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1623 """Pretty print the schema as YAML.
1625 Parameters
1626 ----------
1627 stream
1628 The stream to write the YAML data to.
1629 strip_ids
1630 Whether to strip the IDs from the dumped data. Defaults to `False`.
1631 """
1632 data = self._model_dump(strip_ids=strip_ids)
1633 yaml.safe_dump(
1634 data,
1635 stream,
1636 default_flow_style=False,
1637 sort_keys=False,
1638 )
1640 def dump_json(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1641 """Pretty print the schema as JSON.
1643 Parameters
1644 ----------
1645 stream
1646 The stream to write the JSON data to.
1647 strip_ids
1648 Whether to strip the IDs from the dumped data. Defaults to `False`.
1649 """
1650 data = self._model_dump(strip_ids=strip_ids)
1651 json.dump(
1652 data,
1653 stream,
1654 indent=4,
1655 sort_keys=False,
1656 )