Coverage for python/felis/datamodel.py: 62%
214 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-23 10:44 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
23from collections.abc import Mapping
24from enum import Enum
25from typing import Any, Literal
27from astropy import units as units # type: ignore
28from astropy.io.votable import ucd # type: ignore
29from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
# Module-level logger named after this module. No level is set here, so
# verbosity is left to the application's logging configuration.
logger = logging.getLogger(__name__)
# Public API of this module. ``DataType`` is included because it is the
# declared type of the public ``Column.datatype`` field.
__all__ = (
    "BaseObject",
    "Column",
    "Constraint",
    "CheckConstraint",
    "UniqueConstraint",
    "Index",
    "ForeignKeyConstraint",
    "Table",
    "SchemaVersion",
    "Schema",
    "DataType",
)
class BaseObject(BaseModel):
    """Common base for every named, identifiable Felis model object.

    Subclasses inherit the required ``name`` and ``id`` (read from the
    ``@id`` key) fields, the optional ``description`` field, and the shared
    model configuration below.
    """

    model_config = ConfigDict(
        extra="forbid",
        populate_by_name=True,
        use_enum_values=True,
    )
    """Shared pydantic configuration.

    Fields may be supplied by alias or by attribute name, unknown keys are
    rejected, and enum-typed fields keep the enum's value.
    """

    name: str
    """Name of the database object; every Felis object must have one."""

    id: str = Field(..., alias="@id")
    """Unique identifier of the object, supplied under the ``@id`` key.

    Every Felis object must have one.
    """

    description: str | None = Field(default=None)
    """Optional free-text description; ``None`` when not given."""
class DataType(Enum):
    """Closed set of column data types that Felis understands.

    Each member's value is the lowercase string used to spell that type in
    input data.
    """

    # Boolean and integer types.
    BOOLEAN = "boolean"
    BYTE = "byte"
    SHORT = "short"
    INT = "int"
    LONG = "long"

    # Floating-point types.
    FLOAT = "float"
    DOUBLE = "double"

    # Character and text types.
    CHAR = "char"
    STRING = "string"
    UNICODE = "unicode"
    TEXT = "text"

    # Raw bytes and timestamps.
    BINARY = "binary"
    TIMESTAMP = "timestamp"
class Column(BaseObject):
    """A column in a table."""

    datatype: DataType
    """The datatype of the column."""

    length: int | None = None
    """The length of the column."""

    nullable: bool = True
    """Whether the column can be `NULL`."""

    value: Any = None
    """The default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """The MySQL datatype of the column."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """The IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """The FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """The IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """The TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column; can be either 0 or 1.

    This could be a boolean instead of 0 or 1.
    """

    votable_arraysize: int | Literal["*"] | None = Field(None, alias="votable:arraysize")
    """The VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """The VOTable utype (usage-specific or unique type) of the column."""

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """The VOTable xtype (extended type) of the column."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str | None) -> str | None:
        """Check that IVOA UCD values are valid.

        ``None`` is returned unchanged. Any other value is parsed with
        astropy's ``ucd.parse_ucd`` against the controlled vocabulary.

        Raises
        ------
        ValueError
            If astropy rejects the UCD; the original error is chained.
        """
        if ivoa_ucd is not None:
            try:
                ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd)
            except ValueError as e:
                raise ValueError(f"Invalid IVOA UCD: {e}") from e
        return ivoa_ucd

    @model_validator(mode="before")
    @classmethod
    def check_units(cls, values: Any) -> Any:
        """Check that units are valid.

        Units may arrive under their aliases (``fits:tunit`` / ``ivoa:unit``).
        Because ``populate_by_name`` is enabled, they may also arrive under
        their field names (``fits_tunit`` / ``ivoa_unit``). Both spellings are
        checked, so neither can bypass validation.

        Raises
        ------
        ValueError
            If both a FITS and an IVOA unit are given, or if astropy cannot
            parse the unit.
        """
        if not isinstance(values, dict):
            # Non-mapping input (e.g. an existing model instance) is left to
            # pydantic's own validation.
            return values

        fits_unit = values.get("fits:tunit")
        if fits_unit is None:
            fits_unit = values.get("fits_tunit")
        ivoa_unit = values.get("ivoa:unit")
        if ivoa_unit is None:
            ivoa_unit = values.get("ivoa_unit")

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}") from e

        return values
class Constraint(BaseObject):
    """Base model for constraints declared on a database table.

    The concrete constraint kinds subclass this and share the fields below.
    """

    deferrable: bool = Field(default=False)
    """When ``True``, the constraint is declared as deferrable."""

    initially: str | None = Field(default=None)
    """Value for the ``INITIALLY`` clause; only used when ``deferrable`` is True."""

    annotations: Mapping[str, Any] = Field(default_factory=dict)
    """Free-form extra annotations attached to the constraint."""

    type: str | None = Field(default=None, alias="@type")
    """Kind of constraint, supplied under the ``@type`` key."""
class CheckConstraint(Constraint):
    """Table constraint defined by a check expression."""

    expression: str
    """Expression the check constraint evaluates (required)."""
class UniqueConstraint(Constraint):
    """Table constraint requiring uniqueness over a group of columns."""

    columns: list[str]
    """Columns covered by the unique constraint, given as strings (required)."""
class Index(BaseObject):
    """A database table index.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """The columns in the index."""

    expressions: list[str] | None = None
    """The expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: Any) -> Any:
        """Check that columns or expressions are specified, but not both.

        A key whose value is ``None`` counts as not specified. An explicit
        ``columns=None`` therefore cannot satisfy the "must define" rule on
        its own, and it does not conflict with a real ``expressions`` list.

        Raises
        ------
        ValueError
            If both columns and expressions are given, or if neither is.
        """
        if not isinstance(values, dict):
            # Non-mapping input (e.g. an existing model instance) is left to
            # pydantic's own validation.
            return values

        has_columns = values.get("columns") is not None
        has_expressions = values.get("expressions") is not None

        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values
class ForeignKeyConstraint(Constraint):
    """Foreign key constraint declared on a table.

    Foreign keys are reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    columns: list[str]
    """Columns on this table that make up the foreign key."""

    referenced_columns: list[str] = Field(..., alias="referencedColumns")
    """Columns the foreign key refers to, supplied under ``referencedColumns``."""
class Table(BaseObject):
    """A database table."""

    columns: list[Column]
    """The columns in the table."""

    constraints: list[Constraint] = Field(default_factory=list)
    """The constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """The indexes on the table."""

    primaryKey: str | list[str] | None = None
    """The primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """The IVOA TAP_SCHEMA table index of the table."""

    mysql_engine: str | None = Field(None, alias="mysql:engine")
    """The mysql engine to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known engines in the future.
    """

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """The mysql charset to use for the table.

    For now this is a freeform string but it could be constrained to a list of
    known charsets in the future.
    """

    @model_validator(mode="before")
    @classmethod
    def create_constraints(cls, values: Any) -> Any:
        """Create constraints from the ``constraints`` field.

        Each raw mapping is dispatched on its ``@type`` key:

        - ``ForeignKey`` -> `ForeignKeyConstraint`
        - ``Unique`` -> `UniqueConstraint`
        - ``Check`` -> `CheckConstraint`

        Items that are already `Constraint` instances are kept unchanged. The
        caller's input mapping is not modified; a shallow copy carrying the
        converted list is returned. This means the same raw data can be
        validated more than once.

        Raises
        ------
        ValueError
            If a constraint mapping has a missing or unknown ``@type``.
        """
        if not isinstance(values, dict):
            return values
        raw_constraints = values.get("constraints")
        if raw_constraints is None:
            # Absent (the default applies) or explicitly None (pydantic will
            # reject it as not being a list).
            return values

        constraint_types: dict[str, Any] = {
            "ForeignKey": ForeignKeyConstraint,
            "Unique": UniqueConstraint,
            "Check": CheckConstraint,
        }
        new_constraints: list[Constraint] = []
        for item in raw_constraints:
            if isinstance(item, Constraint):
                # Already constructed, e.g. when a Table is built in code.
                new_constraints.append(item)
                continue
            constraint_type = item.get("@type")
            constraint_cls = constraint_types.get(constraint_type)
            if constraint_cls is None:
                raise ValueError(f"Unknown constraint type: {constraint_type}")
            new_constraints.append(constraint_cls(**item))

        # Return a copy rather than mutating the caller's input data.
        return {**values, "constraints": new_constraints}

    @field_validator("columns", mode="after")
    @classmethod
    def check_unique_column_names(cls, columns: list[Column]) -> list[Column]:
        """Check that column names are unique.

        Raises
        ------
        ValueError
            If two columns share a name.
        """
        if len(columns) != len({column.name for column in columns}):
            raise ValueError("Column names must be unique")
        return columns
class SchemaVersion(BaseModel):
    """Version information attached to a schema."""

    current: str
    """Version string of the schema (required)."""

    compatible: list[str] | None = Field(default=None)
    """Versions this schema is compatible with, if any."""

    read_compatible: list[str] | None = Field(default=None)
    """Versions this schema is read-compatible with, if any."""
class SchemaVisitor:
    """Visitor to build a Schema object's map of IDs to objects.

    The visitor walks the schema itself, its tables, and each table's columns,
    constraints and indexes. Each object is registered in ``schema.id_map``
    under its ``id``.

    Duplicates are collected in the `duplicates` set instead of raising an
    error, and only the first object seen with a given ID is added to the map.
    The caller is expected to check `duplicates` and report them.

    This class is intended for internal use only.
    """

    def __init__(self) -> None:
        """Create a new SchemaVisitor with no schema attached."""
        # Single string annotation: ``"Schema" | None`` would be a str/None
        # union, which raises TypeError if it is ever evaluated.
        self.schema: "Schema | None" = None
        self.duplicates: set[str] = set()

    def add(self, obj: "BaseObject") -> None:
        """Add an object to the ID map, or record its ID as a duplicate.

        Does nothing if no schema is attached or if ``obj`` has no ``id``
        attribute.
        """
        if not hasattr(obj, "id") or self.schema is None:
            return
        obj_id = obj.id
        if obj_id in self.schema.id_map:
            self.duplicates.add(obj_id)
        else:
            self.schema.id_map[obj_id] = obj

    def visit_schema(self, schema: "Schema") -> None:
        """Attach ``schema`` and register it and everything it contains.

        Any duplicates recorded by a previous visit are cleared first.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(self.schema)
        for table in self.schema.tables:
            self.visit_table(table)

    def visit_table(self, table: "Table") -> None:
        """Visit a table object and its columns, constraints and indexes."""
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)
        for index in table.indexes:
            self.visit_index(index)

    def visit_column(self, column: "Column") -> None:
        """Visit a column object."""
        self.add(column)

    def visit_constraint(self, constraint: "Constraint") -> None:
        """Visit a constraint object."""
        self.add(constraint)

    def visit_index(self, index: "Index") -> None:
        """Visit an index object; indexes carry an ``@id`` like other objects."""
        self.add(index)
class Schema(BaseObject):
    """The database schema."""

    version: SchemaVersion | None = None
    """The version of the schema."""

    tables: list[Table]
    """The tables in the schema."""

    id_map: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """Map of IDs to objects, filled in during validation.

    Excluded from serialization.
    """

    @field_validator("tables", mode="after")
    @classmethod
    def check_unique_table_names(cls, tables: list[Table]) -> list[Table]:
        """Check that table names are unique.

        Raises
        ------
        ValueError
            If two tables share a name.
        """
        if len(tables) != len({table.name for table in tables}):
            raise ValueError("Table names must be unique")
        return tables

    @model_validator(mode="after")
    def create_id_map(self) -> "Schema":
        """Create a map of IDs to objects.

        Raises
        ------
        ValueError
            If any ``@id`` value is used by more than one object. The
            duplicate IDs are listed in sorted order so the message is
            deterministic.
        """
        visitor = SchemaVisitor()
        visitor.visit_schema(self)
        logger.debug("ID map contains %d objects", len(self.id_map))
        if visitor.duplicates:
            raise ValueError(
                "Duplicate IDs found in schema:\n " + "\n ".join(sorted(visitor.duplicates)) + "\n"
            )
        return self

    def get_object_by_id(self, id: str) -> BaseObject:
        """Get an object by its unique "@id" field value.

        Raises
        ------
        ValueError
            If no object with that ID exists in the schema.
        """
        try:
            return self.id_map[id]
        except KeyError:
            raise ValueError(f"Object with ID {id} not found in schema") from None