# python/lsst/daf/butler/core/ddl.py
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
21"""Classes for representing SQL data-definition language (DDL) in Python.
23This include "CREATE TABLE" etc.
25This provides an extra layer on top of SQLAlchemy's classes for these concepts,
26because we need a level of indirection between logical tables and the actual
27SQL, and SQLAlchemy's DDL classes always map 1-1 to SQL.
29We've opted for the rather more obscure "ddl" as the name of this module
30instead of "schema" because the latter is too overloaded; in most SQL
31databases, a "schema" is also another term for a namespace.
32"""
from __future__ import annotations

from lsst import sphgeom

__all__ = (
    "TableSpec",
    "FieldSpec",
    "ForeignKeySpec",
    "Base64Bytes",
    "Base64Region",
    "AstropyTimeNsecTai",
    "GUID",
)

import logging
import uuid
from base64 import b64decode, b64encode
from dataclasses import dataclass
from math import ceil
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Set, Tuple, Type, Union

import astropy.time
import sqlalchemy
from lsst.sphgeom import Region
from lsst.utils.iteration import ensure_iterable
from sqlalchemy.dialects.postgresql import UUID

from . import time_utils
from .config import Config
from .exceptions import ValidationError
from .named import NamedValueSet
from .utils import stripIfNotNone

if TYPE_CHECKING:
    from .timespan import TimespanDatabaseRepresentation


_LOG = logging.getLogger(__name__)


class SchemaValidationError(ValidationError):
    """Exceptions that indicate problems in Registry schema configuration."""

    @classmethod
    def translate(cls, caught: Type[Exception], message: str) -> Callable:
        """Return decorator to re-raise exceptions as `SchemaValidationError`.

        Decorated functions must be class or instance methods, with a
        ``config`` parameter as their first argument. This will be passed
        to ``message.format()`` as a keyword argument, along with ``err``,
        the original exception.

        Parameters
        ----------
        caught : `type` (`Exception` subclass)
            The type of exception to catch.
        message : `str`
            A `str.format` string that may contain named placeholders for
            ``config``, ``err``, or any keyword-only argument accepted by
            the decorated function.
        """

        def decorate(func: Callable) -> Callable:
            def decorated(self: Any, config: Config, *args: Any, **kwargs: Any) -> Any:
                try:
                    return func(self, config, *args, **kwargs)
                except caught as err:
                    raise cls(message.format(config=str(config), err=err))

            return decorated

        return decorate
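

# Illustrative usage sketch, not part of the original module: how
# ``SchemaValidationError.translate`` wraps a config-consuming method so that
# a chosen exception type is re-raised with context. The ``_DemoLoader``
# class and its "name" key are hypothetical.
def _demo_schema_validation_error_translate() -> None:
    class _DemoLoader:
        @SchemaValidationError.translate(KeyError, "Missing key {err} in config '{config}'.")
        def load(self, config: Config) -> str:
            # Any KeyError raised here is translated into a
            # SchemaValidationError carrying the formatted message.
            return config["name"]

    try:
        # A plain dict stands in for a real `Config` in this sketch.
        _DemoLoader().load({})  # type: ignore[arg-type]
    except SchemaValidationError as err:
        _LOG.debug("Re-raised as expected: %s", err)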


class Base64Bytes(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `bytes`.

    Maps Python `bytes` to a base64-encoded `sqlalchemy.Text` field.
    """

    impl = sqlalchemy.Text

    cache_ok = True

    def __init__(self, nbytes: int | None = None, *args: Any, **kwargs: Any):
        if nbytes is not None:
            # Base64 encodes every 3 bytes as 4 characters, so a sized
            # column (only relevant if a subclass overrides ``impl`` with
            # `sqlalchemy.String`) needs 4*ceil(nbytes/3) characters.
            length = 4 * ceil(nbytes / 3) if self.impl == sqlalchemy.String else None
        else:
            length = None
        super().__init__(*args, length=length, **kwargs)
        self.nbytes = nbytes

    def process_bind_param(
        self, value: Optional[bytes], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[str]:
        # 'value' is native `bytes`. We want to encode that to base64 `bytes`
        # and then ASCII `str`, because `str` is what SQLAlchemy expects for
        # String fields.
        if value is None:
            return None
        if not isinstance(value, bytes):
            raise TypeError(
                f"Base64Bytes fields require 'bytes' values; got '{value}' with type {type(value)}."
            )
        return b64encode(value).decode("ascii")

    def process_result_value(
        self, value: Optional[str], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[bytes]:
        # 'value' is a `str` that must be ASCII because it's base64-encoded.
        # We want to transform that to base64-encoded `bytes` and then
        # native `bytes`.
        return b64decode(value.encode("ascii")) if value is not None else None

    @property
    def python_type(self) -> Type[bytes]:
        return bytes


# Create an alias, for use below, to disambiguate from the built-in
# sqlalchemy type.
LocalBase64Bytes = Base64Bytes
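

# Illustrative sketch, not part of the original module: the base64 round trip
# Base64Bytes performs, shown with the same stdlib primitives it uses.
def _demo_base64_bytes_round_trip() -> None:
    raw = b"\x00\x01\xfe\xff"
    # What process_bind_param would store ...
    encoded = b64encode(raw).decode("ascii")
    # ... and what process_result_value would hand back.
    assert b64decode(encoded.encode("ascii")) == raw
    # The sizing rule from __init__: 4 raw bytes need 4 * ceil(4 / 3) == 8
    # base64 characters, matching the encoded length (padding included).
    assert len(encoded) == 4 * ceil(4 / 3) == 8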


class Base64Region(Base64Bytes):
    """A SQLAlchemy custom type for Python `sphgeom.Region`.

    Maps Python `sphgeom.Region` to a base64-encoded `sqlalchemy.String`.
    """

    cache_ok = True  # This has to be set explicitly in each class.

    def process_bind_param(
        self, value: Optional[Region], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[str]:
        if value is None:
            return None
        return super().process_bind_param(value.encode(), dialect)

    def process_result_value(
        self, value: Optional[str], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Region]:
        if value is None:
            return None
        return Region.decode(super().process_result_value(value, dialect))

    @property
    def python_type(self) -> Type[sphgeom.Region]:
        return sphgeom.Region
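

# Illustrative sketch, not part of the original module: the extra layer
# Base64Region adds, i.e. `Region.encode`/`Region.decode` around the base64
# step. The one-degree circle here is an arbitrary example region.
def _demo_base64_region_round_trip() -> None:
    from lsst.sphgeom import Angle, Circle, UnitVector3d

    region: Region = Circle(UnitVector3d(1.0, 0.0, 0.0), Angle.fromDegrees(1.0))
    # Region.encode() yields bytes, which Base64Bytes then base64-encodes;
    # reading goes through b64decode and then Region.decode().
    stored = b64encode(region.encode()).decode("ascii")
    restored = Region.decode(b64decode(stored.encode("ascii")))
    assert isinstance(restored, Circle)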


class AstropyTimeNsecTai(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `astropy.time.Time`.

    Maps Python `astropy.time.Time` to a number of nanoseconds since Unix
    epoch in TAI scale.
    """

    impl = sqlalchemy.BigInteger

    cache_ok = True

    def process_bind_param(
        self, value: Optional[astropy.time.Time], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[int]:
        if value is None:
            return None
        if not isinstance(value, astropy.time.Time):
            raise TypeError(f"Unsupported type: {type(value)}, expected astropy.time.Time")
        return time_utils.TimeConverter().astropy_to_nsec(value)

    def process_result_value(
        self, value: Optional[int], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[astropy.time.Time]:
        # 'value' is nanoseconds since epoch, or None.
        if value is None:
            return None
        return time_utils.TimeConverter().nsec_to_astropy(value)
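

# Illustrative sketch, not part of the original module: the conversion
# AstropyTimeNsecTai performs, via the same `time_utils.TimeConverter`
# helper used above.
def _demo_astropy_time_nsec_tai() -> None:
    t = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
    converter = time_utils.TimeConverter()
    nsec = converter.astropy_to_nsec(t)  # integer nanoseconds since epoch, TAI
    restored = converter.nsec_to_astropy(nsec)
    # The round trip is exact at nanosecond resolution.
    assert converter.astropy_to_nsec(restored) == nsec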


class GUID(sqlalchemy.TypeDecorator):
    """Platform-independent GUID type.

    Uses PostgreSQL's UUID type, otherwise uses CHAR(32), storing as
    stringified hex values.
    """

    impl = sqlalchemy.CHAR

    cache_ok = True

    def load_dialect_impl(self, dialect: sqlalchemy.Dialect) -> sqlalchemy.TypeEngine:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        else:
            return dialect.type_descriptor(sqlalchemy.CHAR(32))

    def process_bind_param(self, value: Any, dialect: sqlalchemy.Dialect) -> Optional[str]:
        if value is None:
            return value

        # Coerce input to UUID type; in general a UUID is the only input we
        # want, but there is code right now that uses ints.
        if isinstance(value, int):
            value = uuid.UUID(int=value)
        elif isinstance(value, bytes):
            value = uuid.UUID(bytes=value)
        elif isinstance(value, str):
            # hexstring
            value = uuid.UUID(hex=value)
        elif not isinstance(value, uuid.UUID):
            raise TypeError(f"Unexpected type of a bind value: {type(value)}")

        if dialect.name == "postgresql":
            return str(value)
        else:
            return "%.32x" % value.int

    def process_result_value(
        self, value: Optional[str], dialect: sqlalchemy.Dialect
    ) -> Optional[uuid.UUID]:
        if value is None:
            return value
        else:
            return uuid.UUID(hex=value)
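

# Illustrative sketch, not part of the original module: the coercions
# GUID.process_bind_param accepts, and the 32-character hex form used for
# CHAR(32) storage on non-PostgreSQL backends.
def _demo_guid_coercion() -> None:
    u = uuid.uuid4()
    # int, bytes, and hex-string inputs all coerce to the same UUID.
    assert uuid.UUID(int=u.int) == uuid.UUID(bytes=u.bytes) == uuid.UUID(hex=u.hex) == u
    # Stored as zero-padded hex, which process_result_value parses back.
    stored = "%.32x" % u.int
    assert len(stored) == 32 and uuid.UUID(hex=stored) == u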


VALID_CONFIG_COLUMN_TYPES = {
    "string": sqlalchemy.String,
    "int": sqlalchemy.BigInteger,
    "float": sqlalchemy.Float,
    "region": Base64Region,
    "bool": sqlalchemy.Boolean,
    "blob": sqlalchemy.LargeBinary,
    "datetime": AstropyTimeNsecTai,
    "hash": Base64Bytes,
    "uuid": GUID,
}


@dataclass
class FieldSpec:
    """A data class for defining a column in a logical `Registry` table."""

    name: str
    """Name of the column."""

    dtype: type
    """Type of the column; usually a `type` subclass provided by SQLAlchemy
    that defines both a Python type and a corresponding precise SQL type.
    """

    length: Optional[int] = None
    """Length of the type in the database, for variable-length types."""

    nbytes: Optional[int] = None
    """Natural length used for hash and encoded-region columns, to be
    converted into the post-encoding length.
    """

    primaryKey: bool = False
    """Whether this field is (part of) its table's primary key."""

    autoincrement: bool = False
    """Whether the database should insert automatically incremented values
    when no value is provided in an INSERT.
    """

    nullable: bool = True
    """Whether this field is allowed to be NULL. If ``primaryKey`` is
    `True`, during construction this value will be forced to `False`.
    """

    default: Any = None
    """A server-side default value for this field.

    This is passed directly as the ``server_default`` argument to
    `sqlalchemy.schema.Column`. It does _not_ go through SQLAlchemy's usual
    type conversion or quoting for Python literals, and should hence be used
    with care. See the SQLAlchemy documentation for more information.
    """

    doc: Optional[str] = None
    """Documentation for this field."""

    def __post_init__(self) -> None:
        if self.primaryKey:
            # Change the default to match primaryKey.
            self.nullable = False

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, FieldSpec):
            return self.name == other.name
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in column config '{config}'.")
    def fromConfig(cls, config: Config, **kwargs: Any) -> FieldSpec:
        """Create a `FieldSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the column. Nested configuration keys
            correspond to `FieldSpec` attributes.
        **kwargs
            Additional keyword arguments that provide defaults for values
            not present in config.

        Returns
        -------
        spec : `FieldSpec`
            Specification structure for the column.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        dtype = VALID_CONFIG_COLUMN_TYPES.get(config["type"])
        if dtype is None:
            raise SchemaValidationError(f"Invalid field type string: '{config['type']}'.")
        if not config["name"].islower():
            raise SchemaValidationError(f"Column name '{config['name']}' is not all lowercase.")
        self = cls(name=config["name"], dtype=dtype, **kwargs)
        self.length = config.get("length", self.length)
        self.nbytes = config.get("nbytes", self.nbytes)
        if self.length is not None and self.nbytes is not None:
            raise SchemaValidationError(f"Both length and nbytes provided for field '{self.name}'.")
        self.primaryKey = config.get("primaryKey", self.primaryKey)
        self.autoincrement = config.get("autoincrement", self.autoincrement)
        self.nullable = config.get("nullable", False if self.primaryKey else self.nullable)
        self.doc = stripIfNotNone(config.get("doc", None))
        return self

    @classmethod
    def for_region(cls, name: str = "region", nullable: bool = True, nbytes: int = 2048) -> FieldSpec:
        """Create a `FieldSpec` for a spatial region column.

        Parameters
        ----------
        name : `str`, optional
            Name for the field.
        nullable : `bool`, optional
            Whether NULL values are permitted.
        nbytes : `int`, optional
            Maximum number of bytes for serialized regions. The actual column
            size will be larger to allow for base64 encoding.

        Returns
        -------
        spec : `FieldSpec`
            Specification structure for a region column.
        """
        return cls(name, nullable=nullable, dtype=Base64Region, nbytes=nbytes)

    def isStringType(self) -> bool:
        """Report whether this field spec refers to a `sqlalchemy.String`.

        Returns
        -------
        isString : `bool`
            `True` if the field refers to a `sqlalchemy.String` and not any
            other type. This can return `False` even if the object was
            created with a string type, if it has been decided that it
            should be implemented as a `sqlalchemy.Text` type instead.
        """
        if self.dtype == sqlalchemy.String and self.length and self.length <= 32:
            # Only short strings are retained as true String columns.
            return True
        return False

    def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine:
        """Return a sized version of the column type.

        Utilizes either (or neither) of ``self.length`` and ``self.nbytes``.

        Returns
        -------
        dtype : `sqlalchemy.types.TypeEngine`
            A SQLAlchemy column type object.
        """
        if self.length is not None:
            # Last-chance check that we are only looking at a possible String.
            if self.dtype == sqlalchemy.String and not self.isStringType():
                return sqlalchemy.Text
            return self.dtype(length=self.length)
        if self.nbytes is not None:
            return self.dtype(nbytes=self.nbytes)
        return self.dtype

    def getPythonType(self) -> type:
        """Return the Python type associated with this field's (SQL) dtype.

        Returns
        -------
        type : `type`
            Python type associated with this field's (SQL) `dtype`.
        """
        # The nbytes keyword is needed to construct these objects.
        if issubclass(self.dtype, LocalBase64Bytes):
            # Satisfy mypy for something that must be true.
            assert self.nbytes is not None
            return self.dtype(nbytes=self.nbytes).python_type
        else:
            return self.dtype().python_type  # type: ignore
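

# Illustrative sketch, not part of the original module: constructing
# FieldSpec instances directly and from config, and how the sizing rules
# above play out. The "visit" and "code" column names are hypothetical.
def _demo_field_spec() -> None:
    # Short strings stay String; longer ones would be implemented as Text
    # by getSizedColumnType().
    short = FieldSpec(name="code", dtype=sqlalchemy.String, length=16)
    assert short.isStringType()
    # Primary-key fields are forced to be non-nullable in __post_init__.
    key = FieldSpec(name="id", dtype=GUID, primaryKey=True)
    assert not key.nullable
    # fromConfig maps a "type" string through VALID_CONFIG_COLUMN_TYPES;
    # a plain dict stands in for a real `Config` in this sketch.
    spec = FieldSpec.fromConfig({"name": "visit", "type": "int"})  # type: ignore[arg-type]
    assert spec.dtype is sqlalchemy.BigInteger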


@dataclass
class ForeignKeySpec:
    """Definition of a foreign key constraint in a logical `Registry` table."""

    table: str
    """Name of the target table."""

    source: Tuple[str, ...]
    """Tuple of source table column names."""

    target: Tuple[str, ...]
    """Tuple of target table column names."""

    onDelete: Optional[str] = None
    """SQL clause indicating how to handle deletes to the target table.

    If not `None` (which indicates that a constraint violation exception
    should be raised), should be either "SET NULL" or "CASCADE".
    """

    addIndex: bool = True
    """If `True`, create an index on the columns of this foreign key in the
    source table.
    """

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in foreignKey config '{config}'.")
    def fromConfig(cls, config: Config) -> ForeignKeySpec:
        """Create a `ForeignKeySpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the constraint. Nested configuration
            keys correspond to `ForeignKeySpec` attributes.

        Returns
        -------
        spec : `ForeignKeySpec`
            Specification structure for the constraint.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        return cls(
            table=config["table"],
            source=tuple(ensure_iterable(config["source"])),
            target=tuple(ensure_iterable(config["target"])),
            onDelete=config.get("onDelete", None),
        )
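

# Illustrative sketch, not part of the original module: the same foreign key
# built directly and from config, for a hypothetical "visit" -> "instrument"
# relationship. Note that ensure_iterable lets config use scalars for
# one-element column lists.
def _demo_foreign_key_spec() -> None:
    direct = ForeignKeySpec(table="instrument", source=("instrument",), target=("name",))
    from_config = ForeignKeySpec.fromConfig(
        # A plain dict stands in for a real `Config` in this sketch.
        {"table": "instrument", "source": "instrument", "target": "name"}  # type: ignore[arg-type]
    )
    assert direct == from_config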


@dataclass
class TableSpec:
    """A data class used to define a table or table-like query interface.

    Parameters
    ----------
    fields : `Iterable` [ `FieldSpec` ]
        Specifications for the columns in this table.
    unique : `Iterable` [ `tuple` [ `str` ] ], optional
        Non-primary-key unique constraints for the table.
    indexes : `Iterable` [ `tuple` [ `str` ] ], optional
        Indexes for the table.
    foreignKeys : `Iterable` [ `ForeignKeySpec` ], optional
        Foreign key constraints for the table.
    exclusion : `Iterable` [ `tuple` [ `str` or `type` ] ]
        Special constraints that prohibit overlaps between timespans over
        rows where other columns are equal. These take the same form as
        unique constraints, but each tuple may contain a single
        `TimespanDatabaseRepresentation` subclass representing a timespan
        column.
    recycleIds : `bool`, optional
        If `True`, allow databases that might normally recycle autoincrement
        IDs to do so (usually better for performance) on any autoincrement
        field in this table.
    doc : `str`, optional
        Documentation for the table.
    """

    def __init__(
        self,
        fields: Iterable[FieldSpec],
        *,
        unique: Iterable[Tuple[str, ...]] = (),
        indexes: Iterable[Tuple[str, ...]] = (),
        foreignKeys: Iterable[ForeignKeySpec] = (),
        exclusion: Iterable[Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...]] = (),
        recycleIds: bool = True,
        doc: Optional[str] = None,
    ):
        self.fields = NamedValueSet(fields)
        self.unique = set(unique)
        self.indexes = set(indexes)
        self.foreignKeys = list(foreignKeys)
        self.exclusion = set(exclusion)
        self.recycleIds = recycleIds
        self.doc = doc

    fields: NamedValueSet[FieldSpec]
    """Specifications for the columns in this table."""

    unique: Set[Tuple[str, ...]]
    """Non-primary-key unique constraints for the table."""

    indexes: Set[Tuple[str, ...]]
    """Indexes for the table."""

    foreignKeys: List[ForeignKeySpec]
    """Foreign key constraints for the table."""

    exclusion: Set[Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...]]
    """Exclusion constraints for the table.

    Exclusion constraints behave mostly like unique constraints, but may
    contain a database-native Timespan column that is restricted to not
    overlap across rows (for identical combinations of any non-Timespan
    columns in the constraint).
    """

    recycleIds: bool = True
    """If `True`, allow databases that might normally recycle autoincrement
    IDs to do so (usually better for performance) on any autoincrement field
    in this table.
    """

    doc: Optional[str] = None
    """Documentation for the table."""

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in table config '{config}'.")
    def fromConfig(cls, config: Config) -> TableSpec:
        """Create a `TableSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the table. Nested configuration keys
            correspond to `TableSpec` attributes.

        Returns
        -------
        spec : `TableSpec`
            Specification structure for the table.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        return cls(
            fields=NamedValueSet(FieldSpec.fromConfig(c) for c in config["columns"]),
            unique={tuple(u) for u in config.get("unique", ())},
            foreignKeys=[ForeignKeySpec.fromConfig(c) for c in config.get("foreignKeys", ())],
            doc=stripIfNotNone(config.get("doc")),
        )
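

# Illustrative sketch, not part of the original module: a small table
# definition for a hypothetical "instrument" table, combining the specs
# defined above.
def _demo_table_spec() -> None:
    spec = TableSpec(
        fields=[
            FieldSpec(name="name", dtype=sqlalchemy.String, length=16, primaryKey=True),
            FieldSpec(name="visit_max", dtype=sqlalchemy.BigInteger, nullable=True),
            FieldSpec.for_region(),
        ],
        doc="Hypothetical instrument table.",
    )
    # Fields are held in a NamedValueSet, addressable by column name.
    assert "region" in spec.fields.names
    # The primary-key column was forced non-nullable by FieldSpec.
    assert not spec.fields["name"].nullable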