Coverage for python / felis / datamodel.py: 31%
719 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:42 +0000
1"""Define Pydantic data models for Felis."""
3# This file is part of felis.
4#
5# Developed for the LSST Data Management System.
6# This product includes software developed by the LSST Project
7# (https://www.lsst.org).
8# See the COPYRIGHT file at the top-level directory of this distribution
9# for details of code ownership.
10#
11# This program is free software: you can redistribute it and/or modify
12# it under the terms of the GNU General Public License as published by
13# the Free Software Foundation, either version 3 of the License, or
14# (at your option) any later version.
15#
16# This program is distributed in the hope that it will be useful,
17# but WITHOUT ANY WARRANTY; without even the implied warranty of
18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19# GNU General Public License for more details.
20#
21# You should have received a copy of the GNU General Public License
22# along with this program. If not, see <https://www.gnu.org/licenses/>.
24from __future__ import annotations
26import json
27import logging
28import sys
29from collections.abc import Sequence
30from enum import StrEnum, auto
31from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar
33import yaml
34from astropy import units as units # type: ignore
35from astropy.io.votable import ucd # type: ignore
36from lsst.resources import ResourcePath, ResourcePathExpression
37from pydantic import (
38 BaseModel,
39 ConfigDict,
40 Field,
41 PrivateAttr,
42 ValidationError,
43 ValidationInfo,
44 field_serializer,
45 field_validator,
46 model_validator,
47)
48from pydantic_core import InitErrorDetails
50from .db._dialects import get_supported_dialects, string_to_typeengine
51from .db._sqltypes import get_type_func
52from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode
54logger = logging.getLogger(__name__)
__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    # NOTE(review): ColumnGroup is a public model defined in this module but
    # was absent from __all__ — added for consistency; confirm it is intended
    # to be part of the public API.
    "ColumnGroup",
    "ColumnOverrides",
    "ColumnResourceRef",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Resource",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)
# Shared Pydantic configuration applied to every model in this module via
# ``model_config = CONFIG`` (see BaseObject below).
CONFIG = ConfigDict(
    populate_by_name=True,  # Allow population by field name as well as alias.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
    use_enum_values=False,  # Do not use enum values during serialization.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type for a description, which must be three or more characters long."""
class BaseObject(BaseModel):
    """Base model.

    All classes representing objects in the Felis data model should inherit
    from this class.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Check that the description is present if required.

        The check is opt-in: it only runs when the validation context
        contains ``check_description=True``.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `BaseObject`
            The object being validated.

        Raises
        ------
        ValueError
            Raised if the description is missing, empty, or shorter than
            `DESCR_MIN_LENGTH` characters.
        """
        context = info.context
        if not context or not context.get("check_description", False):
            # Check disabled — nothing to do.
            return self
        if self.description is None or self.description == "":
            raise ValueError("Description is required and must be non-empty")
        # The DescriptionStr type already enforces this for non-None values
        # at field level; this re-check guards the context-enabled path.
        if len(self.description) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self
136class DataType(StrEnum):
137 """``Enum`` representing the data types supported by Felis."""
139 boolean = auto()
140 byte = auto()
141 short = auto()
142 int = auto()
143 long = auto()
144 float = auto()
145 double = auto()
146 char = auto()
147 string = auto()
148 unicode = auto()
149 text = auto()
150 binary = auto()
151 timestamp = auto()
def validate_ivoa_ucd(ivoa_ucd: str) -> str:
    """Validate IVOA UCD values.

    Parameters
    ----------
    ivoa_ucd
        IVOA UCD value to check.  A `None` value is passed through
        unchanged (field validators may call this with an unset field).

    Returns
    -------
    `str`
        The IVOA UCD value if it is valid.

    Raises
    ------
    ValueError
        If the IVOA UCD value is invalid.
    """
    if ivoa_ucd is not None:
        try:
            # A semicolon indicates a compound UCD, which requires colon
            # handling in the parser.
            ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
        except ValueError as e:
            # Chain the original parser error for easier debugging.
            raise ValueError(f"Invalid IVOA UCD: {e}") from e
    return ivoa_ucd
class Column(BaseObject):
    """Column model."""

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    _is_resource_ref: bool = PrivateAttr(False)
    """Whether this column is a resource reference column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if the default value conflicts with autoincrement or has
            a type incompatible with the column's datatype.
        """
        if (value := self.value) is not None:
            # The walrus assignment above already guarantees value is not
            # None, so only the autoincrement conflict needs checking here.
            if self.autoincrement is True:
                raise ValueError("Column cannot have both a default value and be autoincremented")
            felis_type = FelisType.felis_type(self.datatype)
            if felis_type.is_numeric:
                if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
                elif felis_type in (Float, Double) and not isinstance(value, float):
                    raise ValueError("Default value must be a decimal number for float and double columns")
            elif felis_type in (String, Char, Unicode, Text):
                if not isinstance(value, str):
                    raise ValueError("Default value must be a string for string columns")
                if not len(value):
                    raise ValueError("Default value must be a non-empty string for string columns")
            elif felis_type is Boolean and not isinstance(value, bool):
                raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                # Chain the original astropy error for easier debugging.
                raise ValueError(f"Invalid unit: {e}") from e

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            # A superfluous length is only a warning, not an error.
            logger.warning(
                f"The datatype '{datatype}' does not support a specified length"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        return values

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self
        # Skip the check when every supported dialect has an explicit
        # override, since the generic datatype is then never used.
        # FIX(review): the attribute names use an underscore
        # (e.g. ``mysql_datatype``); the previous ``{dialect}:datatype``
        # lookup could never find an attribute, so this guard never fired.
        if all(
            getattr(self, f"{dialect}_datatype", None) is not None
            for dialect in get_supported_dialects()
        ):
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in get_supported_dialects().items():
            db_annotation = f"{dialect_name}_datatype"
            # Read the override directly rather than dumping the whole model
            # on every loop iteration.
            if datatype_string := getattr(self, db_annotation, None):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            self.datatype,
                            self.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        f"Type override of 'datatype: {self.datatype}' "
                        f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
                        f"compiled to '{datatype_obj.compile(dialect)}' and "
                        f"'{db_datatype_obj.compile(dialect)}'"
                    )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if precision is set on a non-timestamp column.
        """
        if self.precision is not None and self.datatype != DataType.timestamp:
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used.
        """
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate
            return values
        context = info.context if info.context else {}
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, arraysize of 1 is disallowed
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    if context.get("force_unbounded_arraysize", False):
                        arraysize = "*"
                        logger.debug(
                            f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with datatype "
                            + f"'{values['datatype']}' and length '{length}'"
                        )
                    else:
                        # Variable-length with an upper bound of `length`.
                        arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'"
                    + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'"
                )
        else:
            logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'")
            if isinstance(values["votable:arraysize"], int):
                logger.warning(
                    f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is "
                    + "deprecated"
                )
            # Normalize any user-provided value (e.g. int) to a string.
            values["votable:arraysize"] = str(arraysize)
        return values

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType) -> str:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize.

        Returns
        -------
        `str`
            The serialized `DataType` value.
        """
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str) -> DataType:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize.

        Returns
        -------
        `DataType`
            The deserialized `DataType` value.
        """
        return DataType(value)

    @model_validator(mode="after")
    def check_votable_xtype(self) -> Column:
        """Set the default value for the ``votable_xtype`` field, which
        corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
        standard.

        Returns
        -------
        `Column`
            The column being validated.

        Notes
        -----
        This is currently only set automatically for the Felis ``timestamp``
        datatype.
        """
        if self.datatype == DataType.timestamp and self.votable_xtype is None:
            self.votable_xtype = "timestamp"
        return self

    def _update_from_overrides(self, overrides: ColumnOverrides) -> None:
        """Update the column attributes from the given overrides.

        Parameters
        ----------
        overrides
            The column overrides to apply or `None` to skip applying overrides.

        Notes
        -----
        Using ``model_fields_set`` allows updating only the fields that are
        explicitly set in the `overrides` object. This prevents overwriting
        existing column attributes which were not explicitly provided.
        """
        if overrides.model_fields_set:
            logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set)
            for field in overrides.model_fields_set:
                setattr(self, field, getattr(overrides, field))
class Constraint(BaseObject):
    """Table constraint model."""

    deferrable: bool = False
    """Whether this constraint will be declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value for ``INITIALLY`` clause; only used if `deferrable` is
    `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Reject an ``INITIALLY`` clause on a non-deferrable constraint.

        Returns
        -------
        `Constraint`
            The constraint being validated.

        Raises
        ------
        ValueError
            Raised if ``initially`` is set while ``deferrable`` is `False`.
        """
        has_initially = self.initially is not None
        if has_initially and not self.deferrable:
            raise ValueError("INITIALLY clause can only be used if deferrable is True")
        return self
class CheckConstraint(Constraint):
    """Table check constraint model."""

    type: Literal["Check"] = Field(default="Check", alias="@type")
    """Type of the constraint; the discriminator for the constraint union."""

    expression: str
    """Expression for the check constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the '@type' discriminator through unchanged so it is
        included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value
class UniqueConstraint(Constraint):
    """Table unique constraint model."""

    type: Literal["Unique"] = Field(default="Unique", alias="@type")
    """Type of the constraint; the discriminator for the constraint union."""

    columns: list[str]
    """Columns in the unique constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the '@type' discriminator through unchanged so it is
        included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value
class ForeignKeyConstraint(Constraint):
    """Table foreign key constraint model.

    This constraint is used to define a foreign key relationship between two
    tables in the schema. There must be at least one column in the
    `columns` list, and at least one column in the `referenced_columns` list
    or a validation error will be raised.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    type: Literal["ForeignKey"] = Field(default="ForeignKey", alias="@type")
    """Type of the constraint; the discriminator for the constraint union."""

    columns: list[str] = Field(min_length=1)
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
    """The columns referenced by the foreign key."""

    on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is deleted."""

    on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is updated."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the '@type' discriminator through unchanged so it is
        included in serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value

    @model_validator(mode="after")
    def check_column_lengths(self) -> ForeignKeyConstraint:
        """Check that the `columns` and `referenced_columns` lists have the
        same length.

        Returns
        -------
        `ForeignKeyConstraint`
            The foreign key constraint being validated.

        Raises
        ------
        ValueError
            Raised if the `columns` and `referenced_columns` lists do not have
            the same length.
        """
        n_local = len(self.columns)
        n_remote = len(self.referenced_columns)
        if n_local != n_remote:
            raise ValueError(
                "Columns and referencedColumns must have the same length for a ForeignKey constraint"
            )
        return self
# Discriminated union keyed on the '@type' field of each constraint model.
_ConstraintType = Annotated[
    CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
]
"""Type alias for a constraint type."""
class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that columns or expressions are specified, but not both.

        Parameters
        ----------
        values
            Values of the index.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the index.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if
            neither are specified.
        """
        has_columns = "columns" in values
        has_expressions = "expressions" in values
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values
# A ColumnRef holds the '@id' of a Column; it is dereferenced to the actual
# Column object by ColumnGroup._dereference_columns().
ColumnRef: TypeAlias = str
"""Type alias for a column reference."""
class ColumnGroup(BaseObject):
    """Column group model."""

    columns: list[ColumnRef | Column] = Field(..., min_length=1)
    """Columns in the group.

    May hold `ColumnRef` ID strings until `_dereference_columns` replaces
    them with the actual `Column` objects.
    """

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    table: Table | None = Field(None, exclude=True)
    """Reference to the parent table.

    Set by `Table.dereference_column_groups`; excluded from serialization.
    """

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_unique_columns(self) -> ColumnGroup:
        """Check that the columns list contains unique items.

        Returns
        -------
        `ColumnGroup`
            The column group being validated.

        Raises
        ------
        ValueError
            Raised if the columns list contains duplicate IDs.
        """
        # Compare by ID so that a mix of refs and Column objects is handled.
        column_ids = [col if isinstance(col, str) else col.id for col in self.columns]
        if len(column_ids) != len(set(column_ids)):
            raise ValueError("Columns in the group must be unique")
        return self

    def _dereference_columns(self) -> None:
        """Dereference ColumnRef to Column objects.

        Raises
        ------
        ValueError
            Raised if the parent ``table`` is unset or a referenced column
            cannot be found in it.
        """
        if self.table is None:
            raise ValueError("ColumnGroup must have a reference to its parent table")

        dereferenced_columns: list[ColumnRef | Column] = []
        for col in self.columns:
            if isinstance(col, str):
                # Dereference ColumnRef to Column object
                try:
                    col_obj = self.table._find_column_by_id(col)
                except KeyError as e:
                    raise ValueError(f"Column '{col}' not found in table '{self.table.name}'") from e
                dereferenced_columns.append(col_obj)
            else:
                dereferenced_columns.append(col)

        self.columns = dereferenced_columns

    @field_serializer("columns")
    def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
        """Serialize columns as their IDs.

        Parameters
        ----------
        columns
            The columns to serialize.

        Returns
        -------
        `list` [ `str` ]
            The serialized column IDs.
        """
        return [col if isinstance(col, str) else col.id for col in columns]
class ColumnOverrides(BaseModel):
    """Allowed overrides for a referenced column.

    Notes
    -----
    All of these fields are optional. Values of None may be explicitly set to
    override the corresponding attribute in the referenced column but only
    for certain fields (see validation in `_check_non_nullable_overrides`).
    """

    model_config = CONFIG.copy()

    datatype: DataType | None = None
    """New datatype for the column."""

    length: int | None = None
    """New length for the column."""

    description: str | None = None
    """New description for the column."""

    nullable: bool | None = None
    """New nullable flag for the column."""

    tap_principal: int | None = Field(default=None, alias="tap:principal")
    """Override for the TAP_SCHEMA 'principal' flag."""

    tap_column_index: int | None = Field(default=None, alias="tap:column_index")
    """Override for the TAP_SCHEMA column index."""

    @model_validator(mode="before")
    @classmethod
    def _check_non_nullable_overrides(cls, data: Any) -> Any:
        """Reject explicit null overrides for fields that must keep a value."""
        if not isinstance(data, dict):
            # Non-dict payloads are handled by downstream validation.
            return data
        for name in ("datatype", "length", "nullable", "tap_principal"):
            if name in data and data[name] is None:
                raise ValueError(f"The '{name}' field cannot be overridden to null")
        return data

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType | None) -> str | None:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize, or None.

        Returns
        -------
        `str` | None
            The serialized `DataType` value, or None if the input was None.
        """
        return None if value is None else str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str | None) -> DataType | None:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize, or None.

        Returns
        -------
        `DataType` | None
            The deserialized `DataType` value, or None if the input was None.
        """
        return None if value is None else DataType(value)
class ColumnResourceRef(BaseModel):
    """A column which is derived from an external resource."""

    ref_name: str | None = None
    """Name of the referenced column in the resource
    (if different from the key)."""

    overrides: ColumnOverrides | None = None
    """Optional overrides of the referenced column's attributes."""
# Type aliases for the nested mapping structure of referenced columns:
# resource name -> table name -> column name -> optional ref details.
ResourceColumnMap: TypeAlias = dict[str, ColumnResourceRef | None]
ResourceTableMap: TypeAlias = dict[str, ResourceColumnMap]
ResourceMap: TypeAlias = dict[str, ResourceTableMap]
class Table(BaseObject):
    """Table model."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """Primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field("MyISAM", alias="mysql:engine")
    """MySQL engine to use for the table."""

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """MySQL charset to use for the table."""

    columns: list[Column] = Field(default_factory=list)
    """Columns in the table."""

    column_refs: ResourceMap = Field(default_factory=dict, alias="columnRefs")
    """Referenced columns from external resources."""

    column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
    """Column groups in the table."""

    constraints: list[_ConstraintType] = Field(default_factory=list)
    """Constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """Indexes on the table."""

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Parameters
        ----------
        columns
            The columns to check.

        Returns
        -------
        `list` [ `Column` ]
            The columns if they are unique.

        Raises
        ------
        ValueError
            Raised if column names are not unique.
        """
        if len(columns) != len(set(column.name for column in columns)):
            raise ValueError("Column names must be unique")
        return columns

    @model_validator(mode="after")
    def check_tap_table_index(self, info: ValidationInfo) -> Table:
        """Check that the table has a TAP table index.

        Only enabled when the validation context contains
        ``check_tap_table_indexes=True``.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the table is missing a TAP table index.
        """
        context = info.context
        if not context or not context.get("check_tap_table_indexes", False):
            return self
        if self.tap_table_index is None:
            raise ValueError("Table is missing a TAP table index")
        return self

    @model_validator(mode="after")
    def check_tap_principal(self, info: ValidationInfo) -> Table:
        """Check that at least one column is flagged as 'principal' for TAP
        purposes.

        Only enabled when the validation context contains
        ``check_tap_principal=True``.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Table`
            The table being validated.

        Raises
        ------
        ValueError
            Raised if the table is missing a column flagged as 'principal'.
        """
        context = info.context
        if not context or not context.get("check_tap_principal", False):
            return self
        for col in self.columns:
            if col.tap_principal == 1:
                return self
        raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'")

    def _find_column_by_id(self, id: str) -> Column:
        """Find a column by ID.

        Parameters
        ----------
        id
            The ID of the column to find.

        Returns
        -------
        `Column`
            The column with the given ID.

        Raises
        ------
        KeyError
            Raised if the column is not found.
        """
        for column in self.columns:
            if column.id == id:
                return column
        raise KeyError(f"Column '{id}' not found in table '{self.name}'")

    def _find_column_by_name(self, name: str) -> Column:
        """Find a column by name.

        Parameters
        ----------
        name
            The name of the column to find.

        Returns
        -------
        `Column`
            The column with the given name.

        Raises
        ------
        KeyError
            Raised if the column is not found.
        """
        for column in self.columns:
            if column.name == name:
                return column
        raise KeyError(f"Column '{name}' not found in table '{self.name}'")

    @model_validator(mode="after")
    def dereference_column_groups(self: Table) -> Table:
        """Dereference columns in column groups.

        Returns
        -------
        `Table`
            The table with dereferenced column groups.
        """
        for group in self.column_groups:
            # Each group needs a back-reference to this table before it can
            # resolve its ColumnRef entries.
            group.table = self
            group._dereference_columns()
        return self

    @field_serializer("columns")
    def _serialize_columns(self, columns: list[Column]) -> list[dict[str, Any]]:
        """Serialize only non-resource columns."""
        # Columns pulled in from external resources (_is_resource_ref) are
        # omitted so they are not duplicated in the serialized schema.
        return [
            col.model_dump(
                by_alias=True,
                exclude_none=True,
                exclude_defaults=True,
            )
            for col in columns
            if not col._is_resource_ref
        ]
class SchemaVersion(BaseModel):
    """Schema version model."""

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""
class SchemaIdVisitor:
    """Visit a schema and build the map of IDs to objects.

    Notes
    -----
    Duplicates are added to a set when they are encountered, which can be
    accessed via the ``duplicates`` attribute. The presence of duplicates will
    not throw an error. Only the first object with a given ID will be added to
    the map, but this should not matter, since a ``ValidationError`` will be
    thrown by the ``model_validator`` method if any duplicates are found in the
    schema.
    """

    def __init__(self) -> None:
        """Create a new visitor with no target schema and no duplicates."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Record an object in the schema's ID map, tracking duplicates.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        # Objects without an 'id' attribute (or a missing schema) are ignored.
        if not hasattr(obj, "id") or self.schema is None:
            return
        obj_id = obj.id
        id_map = self.schema._id_map
        if obj_id in id_map:
            # Keep only the first object; remember the clash.
            self.duplicates.add(obj_id)
        else:
            id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit the objects in a schema and build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for tbl in schema.tables:
            self.visit_table(tbl)

    def visit_table(self, table: Table) -> None:
        """Visit a table object and its columns and constraints.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for col in table.columns:
            self.visit_column(col)
        for cons in table.constraints:
            self.visit_constraint(cons)

    def visit_column(self, column: Column) -> None:
        """Visit a column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)
1200T = TypeVar("T", bound=BaseObject)
1203def _strip_ids(data: Any) -> Any:
1204 """Recursively strip '@id' fields from a dictionary or list.
1206 Parameters
1207 ----------
1208 data
1209 The data to strip IDs from, which can be a dictionary, list, or any
1210 other type. Other types will be returned unchanged.
1211 """
1212 if isinstance(data, dict):
1213 data.pop("@id", None)
1214 for k, v in data.items():
1215 data[k] = _strip_ids(v)
1216 return data
1217 elif isinstance(data, list):
1218 return [_strip_ids(item) for item in data]
1219 else:
1220 return data
1223def _append_error(
1224 errors: list[InitErrorDetails],
1225 loc: tuple,
1226 input_value: Any,
1227 error_message: str,
1228 error_type: str = "value_error",
1229) -> None:
1230 """Append an error to the errors list.
1232 Parameters
1233 ----------
1234 errors : list[InitErrorDetails]
1235 The list of errors to append to.
1236 loc : tuple
1237 The location of the error in the schema.
1238 input_value : Any
1239 The input value that caused the error.
1240 error_message : str
1241 The error message to include in the context.
1242 """
1243 errors.append(
1244 {
1245 "type": error_type,
1246 "loc": loc,
1247 "input": input_value,
1248 "ctx": {"error": error_message},
1249 }
1250 )
class Resource(BaseModel):
    """A resource definition referencing an external schema.

    The URI is loaded via `Schema.from_uri` when the containing schema is
    validated.
    """

    uri: str = Field(..., description="Resource URI or path")
    """URI of the schema resource which may be a local path, ``resource://``,
    or remote URL."""
class Schema(BaseObject, Generic[T]):
    """Database schema model.

    This represents a database schema, which contains one or more tables.
    Validators on this class generate object IDs, load external resources,
    and enforce uniqueness of table, constraint, and index names.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    resources: dict[str, Resource] = Field(default_factory=dict)
    """External resources referenced by this schema."""

    _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
    """Map of IDs to objects."""

    _resource_map: dict[str, Schema] = PrivateAttr(default_factory=dict)
    """Map of resource names to loaded schemas."""
1282 @model_validator(mode="after")
1283 def _load_resources(self: Schema, info: ValidationInfo) -> Schema:
1284 """Load external resources referenced by this schema into an internal
1285 mapping of resource names to their `Schema` objects.
1287 Returns
1288 -------
1289 `Schema`
1290 The schema being validated.
1292 Raises
1293 ------
1294 ValueError
1295 Raised if a resource cannot be loaded.
1296 """
1297 if info.context:
1298 context = info.context.copy()
1299 # Ignore this flag for loading the resources themselves
1300 context.pop("dereference_resources", None)
1301 else:
1302 context = {}
1304 for resource_name, resource in self.resources.items():
1305 uri = resource.uri
1306 try:
1307 loaded_schema = Schema.from_uri(uri, context=context)
1308 self._resource_map[resource_name] = loaded_schema
1309 logger.debug(f"Loaded resource '{resource_name}' from URI '{uri}'")
1310 except Exception as e:
1311 raise ValueError(f"Failed to load resource '{resource_name}' from URI '{uri}': {e}") from e
1312 return self
1314 def _find_table_by_name(self, name: str) -> Table:
1315 """Find a table by name.
1317 Parameters
1318 ----------
1319 name
1320 The name of the table to find.
1322 Returns
1323 -------
1324 `Table`
1325 The table with the given name.
1327 Raises
1328 ------
1329 KeyError
1330 Raised if the table is not found.
1331 """
1332 for table in self.tables:
1333 if table.name == name:
1334 return table
1335 raise KeyError(f"Table '{name}' not found in schema '{self.name}'")
    @model_validator(mode="after")
    def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema:
        """Dereference columns from external resources and add them to the
        tables in this schema.

        Parameters
        ----------
        info
            Validation context providing the ``dereference_resources`` flag
            and the optional ``column_ref_index_increment`` setting.

        Returns
        -------
        `Schema`
            The schema being validated.

        Raises
        ------
        ValueError
            Raised if a referenced resource name is not present in the
            schema's loaded resources.
        """
        context = info.context
        column_ref_index_increment: int | None = None
        dereference_resources = False
        if context is not None:
            dereference_resources = context.get("dereference_resources", False)
            column_ref_index_increment = context.get("column_ref_index_increment", None)

        for table in self.tables:
            if column_refs := table.column_refs:
                # Each entry maps a resource name to the tables/columns to
                # pull in from that resource.
                for resource_name, tables in column_refs.items():
                    resource_schema = self._resource_map.get(resource_name)
                    if resource_schema is None:
                        raise ValueError(f"Schema resource '{resource_name}' was not found in resources")
                    self._process_column_refs(
                        table,
                        tables,
                        resource_schema,
                        dereference_resources,
                        column_ref_index_increment,
                    )
            if dereference_resources and len(table.column_refs) > 0:
                # Clear column refs in table if fully dereferencing
                logger.debug(
                    f"Clearing columnRefs in table '{table.name}' after dereferencing resource columns"
                )
                table.column_refs = {}
        return self
    @classmethod
    def _process_column_refs(
        cls,
        table: Table,
        ref_tables: ResourceTableMap,
        resource_schema: Schema,
        dereference_resources: bool = False,
        column_ref_index_increment: int | None = None,
    ) -> None:
        """Process column references from an external resource and add them
        to the given table as columns.

        Parameters
        ----------
        table
            The table receiving the dereferenced columns.
        ref_tables
            Mapping of resource table names to their column reference
            mappings (local column name to optional reference spec).
        resource_schema
            The loaded schema providing the referenced columns.
        dereference_resources
            If `True`, copied columns become permanent members of the table;
            otherwise they are flagged so they are skipped at serialization.
        column_ref_index_increment
            When set, 'tap:column_index' values are assigned automatically,
            starting at this value and increasing by it for each column.

        Raises
        ------
        ValueError
            Raised if a referenced table or column cannot be found in the
            resource schema.
        """
        # Starting index for automatic 'tap:column_index' assignment; -1 is a
        # placeholder that is never used when the increment is None.
        current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1

        for table_name, columns in ref_tables.items():
            try:
                resource_table = resource_schema._find_table_by_name(table_name)
            except KeyError as e:
                raise ValueError(
                    f"Table '{table_name}' not found in resource '{resource_schema.name}'"
                ) from e
            for local_column_name, column_ref in columns.items():
                if column_ref is not None and column_ref.ref_name is not None:
                    # Use specified ref_name
                    ref_column_name = column_ref.ref_name
                else:
                    # Use the local column name if no ref_name
                    # specified
                    ref_column_name = local_column_name

                # Check if referenced column exists in resource
                try:
                    base_column = resource_table._find_column_by_name(ref_column_name)
                except KeyError:
                    # The ref_name is specified but column is not
                    # found
                    if column_ref is not None and column_ref.ref_name is not None:
                        raise ValueError(
                            f"Column '{ref_column_name}' not found in table '{table_name}' "
                            f"from resource '{resource_schema.name}'"
                        )
                    # The ref_name is not specified and the local
                    # column name is not found
                    raise ValueError(
                        f"Column '{local_column_name}' not found in table '{table_name}' "
                        f"from resource '{resource_schema.name}' and no ref_name provided"
                    )

                # Create a copy of the base column
                column_copy = base_column.model_copy()

                # Set the local name (key from the mapping)
                column_copy.name = local_column_name

                if not dereference_resources:
                    # Flag the column as a resource reference so it will not be
                    # written out during serialization
                    column_copy._is_resource_ref = True

                # Apply overrides to the referenced column definition
                overrides = column_ref.overrides if column_ref is not None else None
                if overrides is not None:
                    column_copy._update_from_overrides(overrides)

                # Manually set the ID of the copied column as ID generation has
                # already occurred by now
                column_copy.id = f"{table.id}.{local_column_name}"

                # Apply automatic assignment of 'tap:column_index', if enabled
                if column_ref_index_increment is not None:
                    if (not overrides) or (overrides.tap_column_index is None):
                        column_copy.tap_column_index = current_column_index
                        current_column_index += column_ref_index_increment
                        logger.debug(
                            f"Automatically assigned 'tap:column_index' {column_copy.tap_column_index} to "
                            f"column '{local_column_name}' in table '{table_name}' from resource "
                            f"'{resource_schema.name}'"
                        )
                    else:
                        # Skip automatic assignment of 'tap:column_index' if it
                        # is already overridden
                        logger.debug(
                            f"Skipping automatic assignment of 'tap:column_index' for column "
                            f"'{local_column_name}' in table '{table_name}' from resource "
                            f"'{resource_schema.name}' as it is already overridden to "
                            f"{column_copy.tap_column_index}"
                        )
                table.columns.append(column_copy)
                logger.debug(
                    f"Dereferenced column '{local_column_name}' from table '{table_name}' "
                    f"in resource '{resource_schema.name}' into table '{table.name}'"
                )
    @model_validator(mode="before")
    @classmethod
    def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Generate IDs for objects that do not have them.

        Runs before model construction on the raw input dictionaries; only
        active when the ``id_generation`` flag is set in the validation
        context. Existing ``@id`` values are never overwritten.

        Parameters
        ----------
        values
            The values of the schema.
        info
            Validation context used to determine if ID generation is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the schema with generated IDs.
        """
        context = info.context
        if not context or not context.get("id_generation", False):
            logger.debug("Skipping ID generation")
            return values
        schema_name = values["name"]
        # Schema IDs look like '#<schema>'; nested object IDs are derived
        # from their table and object names below.
        if "@id" not in values:
            values["@id"] = f"#{schema_name}"
            logger.debug(f"Generated ID '{values['@id']}' for schema '{schema_name}'")
        if "tables" in values:
            for table in values["tables"]:
                if "@id" not in table:
                    table["@id"] = f"#{table['name']}"
                    logger.debug(f"Generated ID '{table['@id']}' for table '{table['name']}'")
                if "columns" in table:
                    for column in table["columns"]:
                        if "@id" not in column:
                            column["@id"] = f"#{table['name']}.{column['name']}"
                            logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
                if "columnGroups" in table:
                    for column_group in table["columnGroups"]:
                        if "@id" not in column_group:
                            column_group["@id"] = f"#{table['name']}.{column_group['name']}"
                            logger.debug(
                                f"Generated ID '{column_group['@id']}' for column group "
                                f"'{column_group['name']}'"
                            )
                if "constraints" in table:
                    for constraint in table["constraints"]:
                        if "@id" not in constraint:
                            constraint["@id"] = f"#{constraint['name']}"
                            logger.debug(
                                f"Generated ID '{constraint['@id']}' for constraint '{constraint['name']}'"
                            )
                if "indexes" in table:
                    for index in table["indexes"]:
                        if "@id" not in index:
                            index["@id"] = f"#{index['name']}"
                            logger.debug(f"Generated ID '{index['@id']}' for index '{index['name']}'")
        return values
1520 @field_validator("tables", mode="after")
1521 @classmethod
1522 def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
1523 """Check that table names are unique.
1525 Parameters
1526 ----------
1527 tables
1528 The tables to check.
1530 Returns
1531 -------
1532 `list` [ `Table` ]
1533 The tables if they are unique.
1535 Raises
1536 ------
1537 ValueError
1538 Raised if table names are not unique.
1539 """
1540 if len(tables) != len(set(table.name for table in tables)):
1541 raise ValueError("Table names must be unique")
1542 return tables
1544 @model_validator(mode="after")
1545 def check_tap_table_indexes(self, info: ValidationInfo) -> Schema:
1546 """Check that the TAP table indexes are unique.
1548 Parameters
1549 ----------
1550 info
1551 The validation context used to determine if the check is enabled.
1553 Returns
1554 -------
1555 `Schema`
1556 The schema being validated.
1557 """
1558 context = info.context
1559 if not context or not context.get("check_tap_table_indexes", False):
1560 return self
1561 table_indicies = set()
1562 for table in self.tables:
1563 table_index = table.tap_table_index
1564 if table_index is not None:
1565 if table_index in table_indicies:
1566 raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema")
1567 table_indicies.add(table_index)
1568 return self
1570 @model_validator(mode="after")
1571 def check_unique_constraint_names(self: Schema) -> Schema:
1572 """Check for duplicate constraint names in the schema.
1574 Returns
1575 -------
1576 `Schema`
1577 The schema being validated.
1579 Raises
1580 ------
1581 ValueError
1582 Raised if duplicate constraint names are found in the schema.
1583 """
1584 constraint_names = set()
1585 duplicate_names = []
1587 for table in self.tables:
1588 for constraint in table.constraints:
1589 constraint_name = constraint.name
1590 if constraint_name in constraint_names:
1591 duplicate_names.append(constraint_name)
1592 else:
1593 constraint_names.add(constraint_name)
1595 if duplicate_names:
1596 raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}")
1598 return self
1600 @model_validator(mode="after")
1601 def check_unique_index_names(self: Schema) -> Schema:
1602 """Check for duplicate index names in the schema.
1604 Returns
1605 -------
1606 `Schema`
1607 The schema being validated.
1609 Raises
1610 ------
1611 ValueError
1612 Raised if duplicate index names are found in the schema.
1613 """
1614 index_names = set()
1615 duplicate_names = []
1617 for table in self.tables:
1618 for index in table.indexes:
1619 index_name = index.name
1620 if index_name in index_names:
1621 duplicate_names.append(index_name)
1622 else:
1623 index_names.add(index_name)
1625 if duplicate_names:
1626 raise ValueError(f"Duplicate index names found in schema: {duplicate_names}")
1628 return self
1630 @model_validator(mode="after")
1631 def create_id_map(self: Schema) -> Schema:
1632 """Create a map of IDs to objects.
1634 Returns
1635 -------
1636 `Schema`
1637 The schema with the ID map created.
1639 Raises
1640 ------
1641 ValueError
1642 Raised if duplicate identifiers are found in the schema.
1643 """
1644 if self._id_map:
1645 logger.debug("Ignoring call to create_id_map() - ID map was already populated")
1646 return self
1647 visitor: SchemaIdVisitor = SchemaIdVisitor()
1648 visitor.visit_schema(self)
1649 if len(visitor.duplicates):
1650 raise ValueError(
1651 "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n"
1652 )
1653 logger.debug("Created ID map with %d entries", len(self._id_map))
1654 return self
    def _validate_column_id(
        self: Schema,
        column_id: str,
        loc: tuple,
        errors: list[InitErrorDetails],
    ) -> None:
        """Validate a column ID from a constraint and append errors if invalid.

        Parameters
        ----------
        column_id : str
            The column ID to validate.
        loc : tuple
            The location of the error in the schema.
        errors : list[InitErrorDetails]
            The list of errors to append to.
        """
        # The ID must exist in the schema's ID map and must refer to a
        # Column rather than some other object type.
        if column_id not in self:
            _append_error(
                errors,
                loc,
                column_id,
                f"Column ID '{column_id}' not found in schema",
            )
        elif not isinstance(self[column_id], Column):
            _append_error(
                errors,
                loc,
                column_id,
                f"ID '{column_id}' does not refer to a Column object",
            )
    def _validate_foreign_key_column(
        self: Schema,
        column_id: str,
        table: Table,
        loc: tuple,
        errors: list[InitErrorDetails],
    ) -> None:
        """Validate a foreign key column ID from a constraint and append errors
        if invalid.

        Parameters
        ----------
        column_id : str
            The foreign key column ID to validate.
        table : Table
            The table that must contain the column.
        loc : tuple
            The location of the error in the schema.
        errors : list[InitErrorDetails]
            The list of errors to append to.
        """
        try:
            table._find_column_by_id(column_id)
        except KeyError:
            _append_error(
                errors,
                loc,
                column_id,
                f"Column '{column_id}' not found in table '{table.name}'",
            )
    @model_validator(mode="after")
    def check_constraints(self: Schema) -> Schema:
        """Check constraint objects for validity. This needs to be deferred
        until after the schema is fully loaded and the ID map is created.

        All problems are accumulated into a single list so that one
        ``ValidationError`` reporting every invalid reference is raised at
        the end, rather than failing on the first.

        Raises
        ------
        pydantic.ValidationError
            Raised if any constraints are invalid.

        Returns
        -------
        `Schema`
            The schema being validated.
        """
        errors: list[InitErrorDetails] = []

        for table_index, table in enumerate(self.tables):
            for constraint_index, constraint in enumerate(table.constraints):
                column_ids: list[str] = []
                referenced_column_ids: list[str] = []

                if isinstance(constraint, ForeignKeyConstraint):
                    column_ids += constraint.columns
                    referenced_column_ids += constraint.referenced_columns
                elif isinstance(constraint, UniqueConstraint):
                    column_ids += constraint.columns
                # No extra checks are required on CheckConstraint objects.

                # Validate the foreign key columns
                for column_id in column_ids:
                    self._validate_column_id(
                        column_id,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "columns",
                            column_id,
                        ),
                        errors,
                    )
                    # Check that the foreign key column is within the source
                    # table.
                    self._validate_foreign_key_column(
                        column_id,
                        table,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "columns",
                            column_id,
                        ),
                        errors,
                    )

                # Validate the primary key (reference) columns
                for referenced_column_id in referenced_column_ids:
                    self._validate_column_id(
                        referenced_column_id,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "referenced_columns",
                            referenced_column_id,
                        ),
                        errors,
                    )

        if errors:
            raise ValidationError.from_exception_data("Schema validation failed", errors)

        return self
1800 def __getitem__(self, id: str) -> BaseObject:
1801 """Get an object by its ID.
1803 Parameters
1804 ----------
1805 id
1806 The ID of the object to get.
1808 Raises
1809 ------
1810 KeyError
1811 Raised if the object with the given ID is not found in the schema.
1812 """
1813 if id not in self:
1814 raise KeyError(f"Object with ID '{id}' not found in schema")
1815 return self._id_map[id]
1817 def __contains__(self, id: str) -> bool:
1818 """Check if an object with the given ID is in the schema.
1820 Parameters
1821 ----------
1822 id
1823 The ID of the object to check.
1824 """
1825 return id in self._id_map
1827 def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
1828 """Find an object with the given type by its ID.
1830 Parameters
1831 ----------
1832 id
1833 The ID of the object to find.
1834 obj_type
1835 The type of the object to find.
1837 Returns
1838 -------
1839 BaseObject
1840 The object with the given ID and type.
1842 Raises
1843 ------
1844 KeyError
1845 If the object with the given ID is not found in the schema.
1846 TypeError
1847 If the object that is found does not have the right type.
1849 Notes
1850 -----
1851 The actual return type is the user-specified argument ``T``, which is
1852 expected to be a subclass of `BaseObject`.
1853 """
1854 obj = self[id]
1855 if not isinstance(obj, obj_type):
1856 raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
1857 return obj
1859 def get_table_by_column(self, column: Column) -> Table:
1860 """Find the table that contains a column.
1862 Parameters
1863 ----------
1864 column
1865 The column to find.
1867 Returns
1868 -------
1869 `Table`
1870 The table that contains the column.
1872 Raises
1873 ------
1874 ValueError
1875 If the column is not found in any table.
1876 """
1877 for table in self.tables:
1878 if column in table.columns:
1879 return table
1880 raise ValueError(f"Column '{column.name}' not found in any table")
1882 @classmethod
1883 def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema:
1884 """Load a `Schema` from a string representing a ``ResourcePath``.
1886 Parameters
1887 ----------
1888 resource_path
1889 The ``ResourcePath`` pointing to a YAML file.
1890 context
1891 Pydantic context to be used in validation.
1893 Returns
1894 -------
1895 `str`
1896 The ID of the object.
1898 Raises
1899 ------
1900 yaml.YAMLError
1901 Raised if there is an error loading the YAML data.
1902 ValueError
1903 Raised if there is an error reading the resource.
1904 pydantic.ValidationError
1905 Raised if the schema fails validation.
1906 """
1907 try:
1908 rp_stream = ResourcePath(resource_path).read()
1909 except Exception as e:
1910 raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e
1911 yaml_data = yaml.safe_load(rp_stream)
1912 return Schema.model_validate(yaml_data, context=context)
1914 @classmethod
1915 def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema:
1916 """Load a `Schema` from a file stream which should contain YAML data.
1918 Parameters
1919 ----------
1920 source
1921 The file stream to read from.
1922 context
1923 Pydantic context to be used in validation.
1925 Returns
1926 -------
1927 `Schema`
1928 The Felis schema loaded from the stream.
1930 Raises
1931 ------
1932 yaml.YAMLError
1933 Raised if there is an error loading the YAML file.
1934 pydantic.ValidationError
1935 Raised if the schema fails validation.
1936 """
1937 logger.debug("Loading schema from: '%s'", source)
1938 yaml_data = yaml.safe_load(source)
1939 return Schema.model_validate(yaml_data, context=context)
1941 def _model_dump(self, strip_ids: bool = False) -> dict[str, Any]:
1942 """Dump the schema as a dictionary with some default arguments
1943 applied.
1945 Parameters
1946 ----------
1947 strip_ids
1948 Whether to strip the IDs from the dumped data. Defaults to `False`.
1950 Returns
1951 -------
1952 `dict` [ `str`, `Any` ]
1953 The dumped schema data as a dictionary.
1954 """
1955 data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
1956 if strip_ids:
1957 data = _strip_ids(data)
1958 return data
1960 def dump_yaml(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1961 """Pretty print the schema as YAML.
1963 Parameters
1964 ----------
1965 stream
1966 The stream to write the YAML data to.
1967 strip_ids
1968 Whether to strip the IDs from the dumped data. Defaults to `False`.
1969 """
1970 data = self._model_dump(strip_ids=strip_ids)
1971 yaml.safe_dump(
1972 data,
1973 stream,
1974 default_flow_style=False,
1975 sort_keys=False,
1976 )
1978 def dump_json(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None:
1979 """Pretty print the schema as JSON.
1981 Parameters
1982 ----------
1983 stream
1984 The stream to write the JSON data to.
1985 strip_ids
1986 Whether to strip the IDs from the dumped data. Defaults to `False`.
1987 """
1988 data = self._model_dump(strip_ids=strip_ids)
1989 json.dump(
1990 data,
1991 stream,
1992 indent=4,
1993 sort_keys=False,
1994 )