1"""Define Pydantic data models for Felis."""
3# This file is part of felis.
4#
5# Developed for the LSST Data Management System.
6# This product includes software developed by the LSST Project
7# (https://www.lsst.org).
8# See the COPYRIGHT file at the top-level directory of this distribution
9# for details of code ownership.
10#
11# This program is free software: you can redistribute it and/or modify
12# it under the terms of the GNU General Public License as published by
13# the Free Software Foundation, either version 3 of the License, or
14# (at your option) any later version.
15#
16# This program is distributed in the hope that it will be useful,
17# but WITHOUT ANY WARRANTY; without even the implied warranty of
18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19# GNU General Public License for more details.
20#
21# You should have received a copy of the GNU General Public License
22# along with this program. If not, see <https://www.gnu.org/licenses/>.
24from __future__ import annotations
26import json
27import logging
28import sys
29from collections.abc import Sequence
30from enum import StrEnum, auto
31from operator import itemgetter
32from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar
34import yaml
35from astropy import units as units # type: ignore
36from astropy.io.votable import ucd # type: ignore
37from lsst.resources import ResourcePath, ResourcePathExpression
38from pydantic import (
39 BaseModel,
40 ConfigDict,
41 Field,
42 PrivateAttr,
43 ValidationError,
44 ValidationInfo,
45 field_serializer,
46 field_validator,
47 model_validator,
48)
49from pydantic_core import InitErrorDetails
51from .db._dialects import get_supported_dialects, string_to_typeengine
52from .db._sqltypes import get_type_func
53from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
55logger = logging.getLogger(__name__)

__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    "ColumnOverrides",
    "ColumnResourceRef",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Resource",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
    use_enum_values=False,  # Do not use enum values during serialization.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type for a description, which must be three or more characters long."""


class BaseObject(BaseModel):
    """Base model.

    All classes representing objects in the Felis data model should inherit
    from this class.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `BaseObject`
            The object being validated.
        """
        context = info.context
        if not context or not context.get("check_description", False):
            return self
        if self.description is None or self.description == "":
            raise ValueError("Description is required and must be non-empty")
        if len(self.description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self
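

# Illustrative sketch (not from the original source): the description check
# above only runs when enabled through the Pydantic validation context, e.g.:
#
#     BaseObject.model_validate(
#         {"name": "obj", "@id": "#obj"},
#         context={"check_description": True},
#     )
#
# would raise a pydantic.ValidationError because no description was given.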


class DataType(StrEnum):
    """``Enum`` representing the data types supported by Felis."""

    boolean = auto()
    byte = auto()
    short = auto()
    int = auto()
    long = auto()
    float = auto()
    double = auto()
    char = auto()
    string = auto()
    unicode = auto()
    text = auto()
    binary = auto()
    timestamp = auto()
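

# Illustrative note (not from the original source): because ``DataType`` is a
# ``StrEnum`` built with ``auto()``, each member equals its lowercase name:
#
#     DataType.double == "double"  # True
#     str(DataType.timestamp)      # "timestamp"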


def validate_ivoa_ucd(ivoa_ucd: str) -> str:
    """Validate IVOA UCD values.

    Parameters
    ----------
    ivoa_ucd
        IVOA UCD value to check.

    Returns
    -------
    `str`
        The IVOA UCD value if it is valid.

    Raises
    ------
    ValueError
        If the IVOA UCD value is invalid.
    """
    if ivoa_ucd is not None:
        try:
            ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
        except ValueError as e:
            raise ValueError(f"Invalid IVOA UCD: {e}") from e
    return ivoa_ucd
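

# Illustrative usage (hypothetical values): a word combination from the IVOA
# controlled vocabulary passes through unchanged, while an unknown word raises
# ``ValueError``:
#
#     validate_ivoa_ucd("pos.eq.ra;meta.main")  # returns the input string
#     validate_ivoa_ucd("not.a.real.word")      # raises ValueError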


class Column(BaseObject):
    """Column model."""

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    _is_resource_ref: bool = PrivateAttr(False)
    """Whether this column is a resource reference column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Returns
        -------
        `Column`
            The column being validated.
        """
        if (value := self.value) is not None:
            if self.autoincrement is True:
                raise ValueError("Column cannot have both a default value and be autoincremented")
            felis_type = FelisType.felis_type(self.datatype)
            if felis_type.is_numeric:
                if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
                elif felis_type in (Float, Double) and not isinstance(value, float):
                    raise ValueError("Default value must be a decimal number for float and double columns")
            elif felis_type in (String, Char, Unicode, Text):
                if not isinstance(value, str):
                    raise ValueError("Default value must be a string for string columns")
                if not len(value):
                    raise ValueError("Default value must be a non-empty string for string columns")
            elif felis_type is Boolean and not isinstance(value, bool):
                raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided.
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            logger.warning(
                f"The datatype '{datatype}' does not support a specified length"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        return values
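
    # Illustrative note (hypothetical values): sized Felis types such as
    # ``char`` and ``string`` require a length, so column data like
    #
    #     {"name": "band", "@id": "#tab.band", "datatype": "char"}
    #
    # fails this check until a positive ``length`` is added.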

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self
        if all(
            getattr(self, f"{dialect}_datatype", None) is None
            for dialect in get_supported_dialects().keys()
        ):
            # Nothing to check if no dialect overrides are present.
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in get_supported_dialects().items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := self.model_dump().get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            self.datatype,
                            self.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        f"Type override of 'datatype: {self.datatype}' "
                        f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
                        f"compiled to '{datatype_obj.compile(dialect)}' and "
                        f"'{db_datatype_obj.compile(dialect)}'"
                    )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.
        """
        if self.precision is not None and self.datatype != "timestamp":
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used.
        """
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate.
            return values
        context = info.context if info.context else {}
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, an arraysize of 1 is
                # disallowed.
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    if context.get("force_unbounded_arraysize", False):
                        arraysize = "*"
                        logger.debug(
                            f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with "
                            + f"datatype '{values['datatype']}' and length '{length}'"
                        )
                    else:
                        arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'"
                    + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'"
                )
        else:
            logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'")
            if isinstance(values["votable:arraysize"], int):
                logger.warning(
                    f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is "
                    + "deprecated"
                )
                values["votable:arraysize"] = str(arraysize)
        return values
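
    # Illustrative summary (not from the original source) of the defaulting
    # rules above, assuming no explicit 'votable:arraysize' was provided:
    #
    #     char(8)            -> arraysize "8"
    #     string(32)         -> arraysize "32*" ("*" if forced unbounded)
    #     text or timestamp  -> arraysize "*"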

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType) -> str:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize.

        Returns
        -------
        `str`
            The serialized `DataType` value.
        """
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str) -> DataType:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize.

        Returns
        -------
        `DataType`
            The deserialized `DataType` value.
        """
        return DataType(value)

    @model_validator(mode="after")
    def check_votable_xtype(self) -> Column:
        """Set the default value for the ``votable_xtype`` field, which
        corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
        standard.

        Returns
        -------
        `Column`
            The column being validated.

        Notes
        -----
        This is currently only set automatically for the Felis ``timestamp``
        datatype.
        """
        if self.datatype == DataType.timestamp and self.votable_xtype is None:
            self.votable_xtype = "timestamp"
        return self

    def _update_from_overrides(self, overrides: ColumnOverrides) -> None:
        """Update the column attributes from the given overrides.

        Parameters
        ----------
        overrides
            The column overrides to apply.

        Notes
        -----
        Using ``model_fields_set`` allows updating only the fields that are
        explicitly set in the ``overrides`` object. This prevents overwriting
        existing column attributes which were not explicitly provided.
        """
        if overrides.model_fields_set:
            logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set)
            for field in overrides.model_fields_set:
                setattr(self, field, getattr(overrides, field))
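

# Illustrative usage (hypothetical values): because ``populate_by_name`` is
# enabled in CONFIG, a ``Column`` can be constructed with Python field names
# instead of the serialization aliases:
#
#     col = Column(
#         name="ra",
#         id="#object.ra",
#         datatype="double",
#         description="Right ascension of the object",
#     )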


class Constraint(BaseObject):
    """Table constraint model."""

    deferrable: bool = False
    """Whether this constraint will be declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value for ``INITIALLY`` clause; only used if `deferrable` is
    `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Check that the ``INITIALLY`` clause is only used if `deferrable` is
        `True`.

        Returns
        -------
        `Constraint`
            The constraint being validated.
        """
        if self.initially is not None and not self.deferrable:
            raise ValueError("INITIALLY clause can only be used if deferrable is True")
        return self


class CheckConstraint(Constraint):
    """Table check constraint model."""

    type: Literal["Check"] = Field("Check", alias="@type")
    """Type of the constraint."""

    expression: str
    """Expression for the check constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Ensure '@type' is included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value


class UniqueConstraint(Constraint):
    """Table unique constraint model."""

    type: Literal["Unique"] = Field("Unique", alias="@type")
    """Type of the constraint."""

    columns: list[str]
    """Columns in the unique constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Ensure '@type' is included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value


class ForeignKeyConstraint(Constraint):
    """Table foreign key constraint model.

    This constraint is used to define a foreign key relationship between two
    tables in the schema. There must be at least one column in the
    ``columns`` list and at least one column in the ``referenced_columns``
    list, or a validation error will be raised.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
    """Type of the constraint."""

    columns: list[str] = Field(min_length=1)
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
    """The columns referenced by the foreign key."""

    on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is deleted."""

    on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is updated."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Ensure '@type' is included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value

    @model_validator(mode="after")
    def check_column_lengths(self) -> ForeignKeyConstraint:
        """Check that the ``columns`` and ``referenced_columns`` lists have
        the same length.

        Returns
        -------
        `ForeignKeyConstraint`
            The foreign key constraint being validated.

        Raises
        ------
        ValueError
            Raised if the ``columns`` and ``referenced_columns`` lists do not
            have the same length.
        """
        if len(self.columns) != len(self.referenced_columns):
            raise ValueError(
                "Columns and referencedColumns must have the same length for a ForeignKey constraint"
            )
        return self
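

# Illustrative YAML (hypothetical names) for a foreign key constraint as it
# might appear in a Felis schema file, referring to columns by their IDs:
#
#     constraints:
#       - name: fk_source_object
#         "@id": "#fk_source_object"
#         "@type": ForeignKey
#         columns: ["#source.objectId"]
#         referencedColumns: ["#object.objectId"]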


_ConstraintType = Annotated[
    CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
]
"""Type alias for a constraint type."""


class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        Parameters
        ----------
        values
            Values of the index.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the index.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if
            neither is specified.
        """
        if "columns" in values and "expressions" in values:
            raise ValueError("Defining columns and expressions is not valid")
        elif "columns" not in values and "expressions" not in values:
            raise ValueError("Must define columns or expressions")
        return values
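

# Illustrative usage (hypothetical values): an index names either columns or
# expressions, never both:
#
#     Index(name="idx_ra", id="#idx_ra", columns=["#object.ra"])     # OK
#     Index(name="idx_r", id="#idx_r", expressions=["ra + decl"])    # OK
#
# Supplying both ``columns`` and ``expressions`` raises a validation error.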


ColumnRef: TypeAlias = str
"""Type alias for a column reference."""


class ColumnGroup(BaseObject):
    """Column group model."""

    columns: list[ColumnRef | Column] = Field(..., min_length=1)
    """Columns in the group."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column group."""

    table: Table | None = Field(None, exclude=True)
    """Reference to the parent table."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_unique_columns(self) -> ColumnGroup:
        """Check that the columns list contains unique items.

        Returns
        -------
        `ColumnGroup`
            The column group being validated.
        """
        column_ids = [col if isinstance(col, str) else col.id for col in self.columns]
        if len(column_ids) != len(set(column_ids)):
            raise ValueError("Columns in the group must be unique")
        return self

    def _dereference_columns(self) -> None:
        """Dereference `ColumnRef` entries to `Column` objects."""
        if self.table is None:
            raise ValueError("ColumnGroup must have a reference to its parent table")

        dereferenced_columns: list[ColumnRef | Column] = []
        for col in self.columns:
            if isinstance(col, str):
                # Dereference the ColumnRef to a Column object.
                try:
                    col_obj = self.table._find_column_by_id(col)
                except KeyError as e:
                    raise ValueError(f"Column '{col}' not found in table '{self.table.name}'") from e
                dereferenced_columns.append(col_obj)
            else:
                dereferenced_columns.append(col)

        self.columns = dereferenced_columns

    @field_serializer("columns")
    def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
        """Serialize columns as their IDs.

        Parameters
        ----------
        columns
            The columns to serialize.

        Returns
        -------
        `list` [ `str` ]
            The serialized column IDs.
        """
        return [col if isinstance(col, str) else col.id for col in columns]


class ColumnOverrides(BaseModel):
    """Allowed overrides for a referenced column.

    Notes
    -----
    All of these fields are optional. A value of `None` may be explicitly set
    to override the corresponding attribute in the referenced column, but only
    for certain fields (see the validation in
    `_check_non_nullable_overrides`).
    """

    model_config = CONFIG.copy()

    datatype: DataType | None = None
    """New datatype for the column."""

    length: int | None = None
    """New length for the column."""

    description: str | None = None
    """New description for the column."""

    nullable: bool | None = None
    """New nullable flag for the column."""

    tap_principal: int | None = Field(default=None, alias="tap:principal")
    """Override for the TAP_SCHEMA 'principal' flag."""

    tap_column_index: int | None = Field(default=None, alias="tap:column_index")
    """Override for the TAP_SCHEMA column index."""

    @model_validator(mode="before")
    @classmethod
    def _check_non_nullable_overrides(cls, data: Any) -> Any:
        """Check that certain fields are not overridden to null."""
        if not isinstance(data, dict):
            return data
        non_nullable_fields = ("datatype", "length", "nullable", "tap_principal")
        for name in non_nullable_fields:
            if name in data and data[name] is None:
                raise ValueError(f"The '{name}' field cannot be overridden to null")
        return data

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType | None) -> str | None:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize, or `None`.

        Returns
        -------
        `str` | None
            The serialized `DataType` value, or `None` if the input was
            `None`.
        """
        if value is None:
            return None
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str | None) -> DataType | None:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize, or `None`.

        Returns
        -------
        `DataType` | None
            The deserialized `DataType` value, or `None` if the input was
            `None`.
        """
        if value is None:
            return None
        return DataType(value)


class ColumnResourceRef(BaseModel):
    """A column which is derived from an external resource."""

    ref_name: str | None = None
    """Name of the referenced column in the resource
    (if different from the key)."""

    overrides: ColumnOverrides | None = None
    """Optional overrides of the referenced column's attributes."""


# Type aliases for the nested mapping structure of referenced columns.
ResourceColumnMap: TypeAlias = dict[str, ColumnResourceRef | None]
ResourceTableMap: TypeAlias = dict[str, ResourceColumnMap]
ResourceMap: TypeAlias = dict[str, ResourceTableMap]
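

# Illustrative YAML (hypothetical names) for the nested ``columnRefs``
# mapping: resource name -> table name -> local column name:
#
#     columnRefs:
#       base_schema:          # resource name declared under 'resources'
#         Object:             # table in the external resource
#           objectId: null    # take the referenced column as-is
#           ra:
#             ref_name: coord_ra
#             overrides:
#               description: "Right ascension copied from Object"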


class Table(BaseObject):
    """Table model."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """Primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field("MyISAM", alias="mysql:engine")
    """MySQL engine to use for the table."""

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """MySQL charset to use for the table."""

    columns: list[Column] = Field(default_factory=list)
    """Columns in the table."""

    column_refs: ResourceMap = Field(default_factory=dict, alias="columnRefs")
    """Referenced columns from external resources."""

    column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
    """Column groups in the table."""

    constraints: list[_ConstraintType] = Field(default_factory=list)
    """Constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """Indexes on the table."""

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Parameters
        ----------
        columns
            The columns to check.

        Returns
        -------
        `list` [ `Column` ]
            The columns if they are unique.

        Raises
        ------
        ValueError
            Raised if column names are not unique.
        """
        if len(columns) != len(set(column.name for column in columns)):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def check_tap_table_index(self, info: ValidationInfo) -> Table:
        """Check that the table has a TAP table index.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the table is missing a TAP table index.
        """
        context = info.context
        if not context or not context.get("check_tap_table_indexes", False):
            return self
        if self.tap_table_index is None:
            raise ValueError("Table is missing a TAP table index")
        return self

    @model_validator(mode="after")
    def check_tap_principal(self, info: ValidationInfo) -> Table:
        """Check that at least one column is flagged as 'principal' for TAP
        purposes.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the table is missing a column flagged as 'principal'.
        """
        context = info.context
        if not context or not context.get("check_tap_principal", False):
            return self
        for col in self.columns:
            if col.tap_principal == 1:
                return self
        raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")

    def _find_column_by_id(self, id: str) -> Column:
        """Find a column by ID.

        Parameters
        ----------
        id
            The ID of the column to find.

        Returns
        -------
        `Column`
            The column with the given ID.

        Raises
        ------
        KeyError
            Raised if the column is not found.
        """
        for column in self.columns:
            if column.id == id:
                return column
        raise KeyError(f"Column '{id}' not found in table '{self.name}'")

    def _find_column_by_name(self, name: str) -> Column:
        """Find a column by name, raising `KeyError` if it is not found."""
        for column in self.columns:
            if column.name == name:
                return column
        raise KeyError(f"Column '{name}' not found in table '{self.name}'")

    @model_validator(mode="after")
    def dereference_column_groups(self: Table) -> Table:
        """Dereference columns in column groups.

        Returns
        -------
        `Table`
            The table with dereferenced column groups.
        """
        for group in self.column_groups:
            group.table = self
            group._dereference_columns()
        return self

    @field_serializer("columns")
    def _serialize_columns(self, columns: list[Column]) -> list[dict[str, Any]]:
        """Serialize only non-resource columns."""
        return [
            col.model_dump(
                by_alias=True,
                exclude_none=True,
                exclude_defaults=True,
            )
            for col in columns
            if not col._is_resource_ref
        ]


class SchemaVersion(BaseModel):
    """Schema version model."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""


class SchemaIdVisitor:
    """Visit a schema and build the map of IDs to objects.

    Notes
    -----
    Duplicates are added to a set when they are encountered, which can be
    accessed via the ``duplicates`` attribute. The presence of duplicates will
    not throw an error here. Only the first object with a given ID will be
    added to the map, but this should not matter, since a ``ValidationError``
    will be thrown by the ``model_validator`` method if any duplicates are
    found in the schema.
    """

    def __init__(self) -> None:
        """Create a new SchemaIdVisitor."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        if hasattr(obj, "id"):
            obj_id = getattr(obj, "id")
            if self.schema is not None:
                if obj_id in self.schema._id_map:
                    self.duplicates.add(obj_id)
                else:
                    self.schema._id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit the objects in a schema and build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(self.schema)
        for table in self.schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)


T = TypeVar("T", bound=BaseObject)


def _strip_ids(data: Any) -> Any:
    """Recursively strip '@id' fields from a dictionary or list.

    Parameters
    ----------
    data
        The data to strip IDs from, which can be a dictionary, list, or any
        other type. Other types will be returned unchanged.

    Returns
    -------
    `Any`
        The data with all '@id' fields removed.
    """
    if isinstance(data, dict):
        data.pop("@id", None)
        for k, v in data.items():
            data[k] = _strip_ids(v)
        return data
    elif isinstance(data, list):
        return [_strip_ids(item) for item in data]
    else:
        return data
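

# Illustrative usage (hypothetical values):
#
#     _strip_ids({"@id": "#tab", "name": "tab", "columns": [{"@id": "#c"}]})
#     # -> {"name": "tab", "columns": [{}]}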


def _append_error(
    errors: list[InitErrorDetails],
    loc: tuple,
    input_value: Any,
    error_message: str,
    error_type: str = "value_error",
) -> None:
    """Append an error to the errors list.

    Parameters
    ----------
    errors : list[InitErrorDetails]
        The list of errors to append to.
    loc : tuple
        The location of the error in the schema.
    input_value : Any
        The input value that caused the error.
    error_message : str
        The error message to include in the context.
    error_type : str, optional
        The Pydantic error type; defaults to ``"value_error"``.
    """
    errors.append(
        {
            "type": error_type,
            "loc": loc,
            "input": input_value,
            "ctx": {"error": error_message},
        }
    )


class Resource(BaseModel):
    """A resource definition referencing an external schema."""

    uri: str = Field(..., description="Resource URI or path")
    """URI of the schema resource, which may be a local path, ``resource://``
    URI, or remote URL."""


class Schema(BaseObject, Generic[T]):
    """Database schema model.

    This represents a database schema, which contains one or more tables.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    resources: dict[str, Resource] = Field(default_factory=dict)
    """External resources referenced by this schema."""

    _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
    """Map of IDs to objects."""

    _resource_map: dict[str, Schema] = PrivateAttr(default_factory=dict)
    """Map of resource names to loaded schemas."""

    @model_validator(mode="after")
    def _load_resources(self: Schema, info: ValidationInfo) -> Schema:
        """Load external resources referenced by this schema into an internal
        mapping of resource names to their `Schema` objects.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if a resource cannot be loaded.
        """
        if info.context:
            context = info.context.copy()
            # Ignore this flag when loading the resources themselves.
            context.pop("dereference_resources", None)
        else:
            context = {}

        for resource_name, resource in self.resources.items():
            uri = resource.uri
            try:
                loaded_schema = Schema.from_uri(uri, context=context)
                self._resource_map[resource_name] = loaded_schema
                logger.debug(f"Loaded resource '{resource_name}' from URI '{uri}'")
            except Exception as e:
                raise ValueError(f"Failed to load resource '{resource_name}' from URI '{uri}': {e}") from e
        return self

    def _find_table_by_name(self, name: str) -> Table:
        """Find a table by name.

        Parameters
        ----------
        name
            The name of the table to find.

        Returns
        -------
        `Table`
            The table with the given name.

        Raises
        ------
        KeyError
            Raised if the table is not found.
        """
        for table in self.tables:
            if table.name == name:
                return table
        raise KeyError(f"Table '{name}' not found in schema '{self.name}'")

    @model_validator(mode="after")
    def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema:
        """Dereference columns from external resources and add them to the
        tables in this schema.
        """
        context = info.context
        column_ref_index_increment: int | None = None
        dereference_resources = False
        if context is not None:
            dereference_resources = context.get("dereference_resources", False)
            column_ref_index_increment = context.get("column_ref_index_increment", None)

        for table in self.tables:
            if column_refs := table.column_refs:
                for resource_name, tables in column_refs.items():
                    resource_schema = self._resource_map.get(resource_name)
                    if resource_schema is None:
                        raise ValueError(f"Schema resource '{resource_name}' was not found in resources")
                    self._process_column_refs(
                        table,
                        tables,
                        resource_schema,
                        dereference_resources,
                        column_ref_index_increment,
                    )
                if dereference_resources and len(table.column_refs) > 0:
                    # Clear column refs in the table if fully dereferencing.
                    logger.debug(
                        f"Clearing columnRefs in table '{table.name}' after dereferencing resource columns"
                    )
                    table.column_refs = {}
        return self

    @classmethod
    def _process_column_refs(
        cls,
        table: Table,
        ref_tables: ResourceTableMap,
        resource_schema: Schema,
        dereference_resources: bool = False,
        column_ref_index_increment: int | None = None,
    ) -> None:
        """Process column references from an external resource and add them
        to the given table as columns.
        """
        current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1

        for table_name, columns in ref_tables.items():
            try:
                resource_table = resource_schema._find_table_by_name(table_name)
            except KeyError as e:
                raise ValueError(
                    f"Table '{table_name}' not found in resource '{resource_schema.name}'"
                ) from e
            for local_column_name, column_ref in columns.items():
                if column_ref is not None and column_ref.ref_name is not None:
                    # Use the specified ref_name.
                    ref_column_name = column_ref.ref_name
                else:
                    # Use the local column name if no ref_name was specified.
                    ref_column_name = local_column_name

                # Check that the referenced column exists in the resource.
                try:
                    base_column = resource_table._find_column_by_name(ref_column_name)
                except KeyError:
                    if column_ref is not None and column_ref.ref_name is not None:
                        # The ref_name was specified but the column was not
                        # found.
                        raise ValueError(
                            f"Column '{ref_column_name}' not found in table '{table_name}' "
                            f"from resource '{resource_schema.name}'"
                        )
                    # The ref_name was not specified and the local column
                    # name was not found.
                    raise ValueError(
                        f"Column '{local_column_name}' not found in table '{table_name}' "
                        f"from resource '{resource_schema.name}' and no ref_name provided"
                    )

                # Create a copy of the base column.
                column_copy = base_column.model_copy()

                # Set the local name (the key from the mapping).
                column_copy.name = local_column_name

                if not dereference_resources:
                    # Flag the column as a resource reference so it will not
                    # be written out during serialization.
                    column_copy._is_resource_ref = True

                # Apply overrides to the referenced column definition.
                overrides = column_ref.overrides if column_ref is not None else None
                if overrides is not None:
                    column_copy._update_from_overrides(overrides)

                # Manually set the ID of the copied column, as ID generation
                # has already occurred by now.
                column_copy.id = f"{table.id}.{local_column_name}"

                # Apply automatic assignment of 'tap:column_index', if
                # enabled.
                if column_ref_index_increment is not None:
                    if (not overrides) or (overrides.tap_column_index is None):
                        column_copy.tap_column_index = current_column_index
                        current_column_index += column_ref_index_increment
                        logger.debug(
                            f"Automatically assigned 'tap:column_index' {column_copy.tap_column_index} to "
                            f"column '{local_column_name}' in table '{table_name}' from resource "
                            f"'{resource_schema.name}'"
                        )
                    else:
                        # Skip automatic assignment of 'tap:column_index' if
                        # it is already overridden.
                        logger.debug(
                            f"Skipping automatic assignment of 'tap:column_index' for column "
                            f"'{local_column_name}' in table '{table_name}' from resource "
                            f"'{resource_schema.name}' as it is already overridden to "
                            f"{column_copy.tap_column_index}"
                        )
                table.columns.append(column_copy)
                logger.debug(
                    f"Dereferenced column '{local_column_name}' from table '{table_name}' "
                    f"in resource '{resource_schema.name}' into table '{table.name}'"
                )

    @model_validator(mode="before")
    @classmethod
    def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Generate IDs for objects that do not have them.

        Parameters
        ----------
        values
            The values of the schema.
        info
            Validation context used to determine if ID generation is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the schema with generated IDs.
        """
        context = info.context
        if not context or not context.get("id_generation", False):
            logger.debug("Skipping ID generation")
            return values
        schema_name = values["name"]
        if "@id" not in values:
            values["@id"] = f"#{schema_name}"
            logger.debug(f"Generated ID '{values['@id']}' for schema '{schema_name}'")
        if "tables" in values:
            for table in values["tables"]:
                if "@id" not in table:
                    table["@id"] = f"#{table['name']}"
                    logger.debug(f"Generated ID '{table['@id']}' for table '{table['name']}'")
                if "columns" in table:
                    for column in table["columns"]:
                        if "@id" not in column:
                            column["@id"] = f"#{table['name']}.{column['name']}"
                            logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
                if "columnGroups" in table:
                    for column_group in table["columnGroups"]:
                        if "@id" not in column_group:
                            column_group["@id"] = f"#{table['name']}.{column_group['name']}"
                            logger.debug(
                                f"Generated ID '{column_group['@id']}' for column group "
                                f"'{column_group['name']}'"
                            )
                if "constraints" in table:
                    for constraint in table["constraints"]:
                        if "@id" not in constraint:
                            constraint["@id"] = f"#{constraint['name']}"
                            logger.debug(
                                f"Generated ID '{constraint['@id']}' for constraint '{constraint['name']}'"
                            )
                if "indexes" in table:
                    for index in table["indexes"]:
                        if "@id" not in index:
                            index["@id"] = f"#{index['name']}"
                            logger.debug(f"Generated ID '{index['@id']}' for index '{index['name']}'")
        return values
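
    # Illustrative note (hypothetical names): with ``id_generation`` enabled
    # in the validation context, missing IDs are derived from object names;
    # e.g. a schema "sdm", table "Object", and column "ra" receive the IDs
    # "#sdm", "#Object", and "#Object.ra" respectively.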

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Parameters
        ----------
        tables
            The tables to check.

        Returns
        -------
        `list` [ `Table` ]
            The tables if they are unique.

        Raises
        ------
        ValueError
            Raised if table names are not unique.
        """
        if len(tables) != len(set(table.name for table in tables)):
            raise ValueError("Table names must be unique")
        return tables

    @model_validator(mode="after")
    def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
        """Check that the TAP table indexes are unique.

        Parameters
        ----------
        info
            The validation context used to determine if the check is enabled.

        Returns
        -------
        `Schema`
            The schema being validated.
        """
        context = info.context
        if not context or not context.get("check_tap_table_indexes", False):
            return self
        table_indices = set()
        for table in self.tables:
            table_index = table.tap_table_index
            if table_index is not None:
                if table_index in table_indices:
                    raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
                table_indices.add(table_index)
        return self

    @model_validator(mode="after")
    def check_unique_constraint_names(self: Schema) -> Schema:
        """Check for duplicate constraint names in the schema.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if duplicate constraint names are found in the schema.
        """
        constraint_names = set()
        duplicate_names = []

        for table in self.tables:
            for constraint in table.constraints:
                constraint_name = constraint.name
                if constraint_name in constraint_names:
                    duplicate_names.append(constraint_name)
                else:
                    constraint_names.add(constraint_name)

        if duplicate_names:
            raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}")

        return self

    @model_validator(mode="after")
    def check_unique_index_names(self: Schema) -> Schema:
        """Check for duplicate index names in the schema.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if duplicate index names are found in the schema.
        """
        index_names = set()
        duplicate_names = []

        for table in self.tables:
            for index in table.indexes:
                index_name = index.name
                if index_name in index_names:
                    duplicate_names.append(index_name)
                else:
                    index_names.add(index_name)

        if duplicate_names:
            raise ValueError(f"Duplicate index names found in schema: {duplicate_names}")

        return self

    @model_validator(mode="after")
    def create_id_map(self: Schema) -> Schema:
        """Create a map of IDs to objects.

        Returns
        -------
        `Schema`
            The schema with the ID map created.

        Raises
        ------
        ValueError
            Raised if duplicate identifiers are found in the schema.
        """
        if self._id_map:
            logger.debug("Ignoring call to create_id_map() - ID map was already populated")
            return self
        visitor: SchemaIdVisitor = SchemaIdVisitor()
        visitor.visit_schema(self)
        if len(visitor.duplicates):
            raise ValueError(
                "Duplicate IDs found in schema:\n  " + "\n  ".join(visitor.duplicates) + "\n"
            )
        logger.debug("Created ID map with %d entries", len(self._id_map))
        return self

    def _validate_column_id(
        self: Schema,
        column_id: str,
        loc: tuple,
        errors: list[InitErrorDetails],
    ) -> None:
        """Validate a column ID from a constraint and append errors if it is
        invalid.

        Parameters
        ----------
        column_id : str
            The column ID to validate.
        loc : tuple
            The location of the error in the schema.
        errors : list[InitErrorDetails]
            The list of errors to append to.
        """
        if column_id not in self:
            _append_error(
                errors,
                loc,
                column_id,
                f"Column ID '{column_id}' not found in schema",
            )
        elif not isinstance(self[column_id], Column):
            _append_error(
                errors,
                loc,
                column_id,
                f"ID '{column_id}' does not refer to a Column object",
            )

    def _validate_foreign_key_column(
        self: Schema,
        column_id: str,
        table: Table,
        loc: tuple,
        errors: list[InitErrorDetails],
    ) -> None:
        """Validate a foreign key column ID from a constraint and append
        errors if it is invalid.

        Parameters
        ----------
        column_id : str
            The foreign key column ID to validate.
        table : Table
            The table containing the constraint.
        loc : tuple
            The location of the error in the schema.
        errors : list[InitErrorDetails]
            The list of errors to append to.
        """
        try:
            table._find_column_by_id(column_id)
        except KeyError:
            _append_error(
                errors,
                loc,
                column_id,
                f"Column '{column_id}' not found in table '{table.name}'",
            )

    @model_validator(mode="after")
    def check_constraints(self: Schema) -> Schema:
        """Check constraint objects for validity. This needs to be deferred
        until after the schema is fully loaded and the ID map is created.

        Raises
        ------
        pydantic.ValidationError
            Raised if any constraints are invalid.

        Returns
        -------
        `Schema`
            The schema being validated.
        """
        errors: list[InitErrorDetails] = []

        for table_index, table in enumerate(self.tables):
            for constraint_index, constraint in enumerate(table.constraints):
                column_ids: list[str] = []
                referenced_column_ids: list[str] = []

                if isinstance(constraint, ForeignKeyConstraint):
                    column_ids += constraint.columns
                    referenced_column_ids += constraint.referenced_columns
                elif isinstance(constraint, UniqueConstraint):
                    column_ids += constraint.columns
                # No extra checks are required on CheckConstraint objects.

                # Validate the constraint columns.
                for column_id in column_ids:
                    self._validate_column_id(
                        column_id,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "columns",
                            column_id,
                        ),
                        errors,
                    )
                    # Check that the constraint column is within the source
                    # table.
                    self._validate_foreign_key_column(
                        column_id,
                        table,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "columns",
                            column_id,
                        ),
                        errors,
                    )

                # Validate the referenced (primary key) columns.
                for referenced_column_id in referenced_column_ids:
                    self._validate_column_id(
                        referenced_column_id,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "referenced_columns",
                            referenced_column_id,
                        ),
                        errors,
                    )

        if errors:
            raise ValidationError.from_exception_data("Schema validation failed", errors)

        return self

    def __getitem__(self, id: str) -> BaseObject:
        """Get an object by its ID.

        Parameters
        ----------
        id
            The ID of the object to get.

        Raises
        ------
        KeyError
            Raised if the object with the given ID is not found in the schema.
        """
        if id not in self:
            raise KeyError(f"Object with ID '{id}' not found in schema")
        return self._id_map[id]

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema.

        Parameters
        ----------
        id
            The ID of the object to check.
        """
        return id in self._id_map

    def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
        """Find an object with the given type by its ID.

        Parameters
        ----------
        id
            The ID of the object to find.
        obj_type
            The type of the object to find.

        Returns
        -------
        `BaseObject`
            The object with the given ID and type.

        Raises
        ------
        KeyError
            If the object with the given ID is not found in the schema.
        TypeError
            If the object that is found does not have the right type.

        Notes
        -----
        The actual return type is the user-specified argument ``T``, which is
        expected to be a subclass of `BaseObject`.
        """
        obj = self[id]
        if not isinstance(obj, obj_type):
            raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
        return obj

    def get_table_by_column(self, column: Column) -> Table:
        """Find the table that contains a column.

        Parameters
        ----------
        column
            The column to find.

        Returns
        -------
        `Table`
            The table that contains the column.

        Raises
        ------
        ValueError
            If the column is not found in any table.
        """
        for table in self.tables:
            if column in table.columns:
                return table
        raise ValueError(f"Column '{column.name}' not found in any table")

    @classmethod
    def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema:
        """Load a `Schema` from a string representing a ``ResourcePath``.

        Parameters
        ----------
        resource_path
            The ``ResourcePath`` pointing to a YAML file.
        context
            Pydantic context to be used in validation.

        Returns
        -------
        `Schema`
            The Felis schema loaded from the resource path.

        Raises
        ------
        yaml.YAMLError
            Raised if there is an error loading the YAML data.
        ValueError
            Raised if there is an error reading the resource.
        pydantic.ValidationError
            Raised if the schema fails validation.
        """
        try:
            rp_stream = ResourcePath(resource_path).read()
        except Exception as e:
            raise ValueError(f"Error reading resource from '{resource_path}': {e}") from e
        yaml_data = yaml.safe_load(rp_stream)
        return Schema.model_validate(yaml_data, context=context)

    @classmethod
    def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema:
        """Load a `Schema` from a file stream which should contain YAML data.

        Parameters
        ----------
        source
            The file stream to read from.
        context
            Pydantic context to be used in validation.

        Returns
        -------
        `Schema`
            The Felis schema loaded from the stream.

        Raises
        ------
        yaml.YAMLError
            Raised if there is an error loading the YAML file.
        pydantic.ValidationError
            Raised if the schema fails validation.
        """
        logger.debug("Loading schema from: '%s'", source)
        yaml_data = yaml.safe_load(source)
        return Schema.model_validate(yaml_data, context=context)
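
    # Illustrative usage (hypothetical file name): validation flags are
    # passed through the Pydantic context when loading a schema:
    #
    #     with open("schema.yaml") as f:
    #         schema = Schema.from_stream(
    #             f, context={"id_generation": True, "check_description": True}
    #         )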

    def _model_dump(self, strip_ids: bool = False, sort_columns: bool = False) -> dict[str, Any]:
        """Dump the schema as a dictionary with some default arguments
        applied.

        Parameters
        ----------
        strip_ids
            Whether to strip the IDs from the dumped data. Defaults to
            `False`.
        sort_columns
            Whether to sort columns alphabetically by name. Defaults to
            `False`.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The dumped schema data as a dictionary.
        """
        data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
        if strip_ids:
            data = _strip_ids(data)
        if sort_columns:
            for table in data.get("tables", []):
                table["columns"] = sorted(table.get("columns", []), key=itemgetter("name"))
        return data

    def dump_yaml(
        self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False
    ) -> None:
        """Pretty print the schema as YAML.

        Parameters
        ----------
        stream
            The stream to write the YAML data to.
        strip_ids
            Whether to strip the IDs from the dumped data. Defaults to
            `False`.
        sort_columns
            Whether to sort columns alphabetically by name. Defaults to
            `False`.
        """
        data = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns)
        yaml.safe_dump(
            data,
            stream,
            default_flow_style=False,
            sort_keys=False,
        )

    def dump_json(
        self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False
    ) -> None:
        """Pretty print the schema as JSON.

        Parameters
        ----------
        stream
            The stream to write the JSON data to.
        strip_ids
            Whether to strip the IDs from the dumped data. Defaults to
            `False`.
        sort_columns
            Whether to sort columns alphabetically by name. Defaults to
            `False`.
        """
        data = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns)
        json.dump(
            data,
            stream,
            indent=4,
            sort_keys=False,
        )