# Coverage-report artifact for python/lsst/daf/butler/ddl.py
# (coverage.py v7.4.0, generated 2024-01-16 10:44 +0000; 55% of 230 statements)

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27"""Classes for representing SQL data-definition language (DDL) in Python. 

28 

29This include "CREATE TABLE" etc. 

30 

31This provides an extra layer on top of SQLAlchemy's classes for these concepts, 

32because we need a level of indirection between logical tables and the actual 

33SQL, and SQLAlchemy's DDL classes always map 1-1 to SQL. 

34 

35We've opted for the rather more obscure "ddl" as the name of this module 

36instead of "schema" because the latter is too overloaded; in most SQL 

37databases, a "schema" is also another term for a namespace. 

38""" 

39from __future__ import annotations 

40 

41from lsst import sphgeom 

42 

# Public API of this module; names importable via ``from ...ddl import *``.
__all__ = (
    "TableSpec",
    "FieldSpec",
    "ForeignKeySpec",
    "IndexSpec",
    "Base64Bytes",
    "Base64Region",
    "AstropyTimeNsecTai",
    "GUID",
)

53 

import logging
import uuid
from base64 import b64decode, b64encode
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from functools import wraps
from math import ceil
from typing import TYPE_CHECKING, Any

import astropy.time
import sqlalchemy
from lsst.sphgeom import Region
from lsst.utils.iteration import ensure_iterable
from sqlalchemy.dialects.postgresql import UUID

from . import time_utils
from ._config import Config
from ._exceptions import ValidationError
from ._named import NamedValueSet
from .utils import stripIfNotNone

73 

74if TYPE_CHECKING: 

75 from .timespan_database_representation import TimespanDatabaseRepresentation 

76 

77 

# Module-level logger; not referenced in the code visible here, but kept
# available for diagnostics elsewhere in the module.
_LOG = logging.getLogger(__name__)

79 

80 

class SchemaValidationError(ValidationError):
    """Exceptions that indicate problems in Registry schema configuration."""

    @classmethod
    def translate(cls, caught: type[Exception], message: str) -> Callable:
        """Return decorator to re-raise exceptions as `SchemaValidationError`.

        Decorated functions must be class or instance methods, with a
        ``config`` parameter as their first argument.  This will be passed
        to ``message.format()`` as a keyword argument, along with ``err``,
        the original exception.

        Parameters
        ----------
        caught : `type` (`Exception` subclass)
            The type of exception to catch.
        message : `str`
            A `str.format` string that may contain named placeholders for
            ``config``, ``err``, or any keyword-only argument accepted by
            the decorated function.
        """

        def decorate(func: Callable) -> Callable:
            # Use functools.wraps so the decorated method keeps the wrapped
            # function's name/docstring (important for introspection and
            # generated documentation); the original wrapper lost them.
            @wraps(func)
            def decorated(self: Any, config: Config, *args: Any, **kwargs: Any) -> Any:
                try:
                    return func(self, config, *args, **kwargs)
                except caught as err:
                    # Chain the original exception so its traceback remains
                    # available to callers.
                    raise cls(message.format(config=str(config), err=err)) from err

            return decorated

        return decorate

113 

114 

class Base64Bytes(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `bytes`.

    Maps Python `bytes` to a base64-encoded `sqlalchemy.Text` field.

    Parameters
    ----------
    nbytes : `int` or `None`, optional
        Number of bytes.
    *args : `typing.Any`
        Parameters passed to base class constructor.
    **kwargs : `typing.Any`
        Keyword parameters passed to base class constructor.
    """

    impl = sqlalchemy.Text

    cache_ok = True

    def __init__(self, nbytes: int | None = None, *args: Any, **kwargs: Any):
        # Base64 represents every 3 raw bytes as 4 characters, so a sized
        # String column needs 4 * ceil(nbytes / 3) characters; an unsized
        # Text column (the default impl) gets no length.
        length: int | None = None
        if nbytes is not None and self.impl is sqlalchemy.String:
            length = 4 * ceil(nbytes / 3)
        super().__init__(*args, length=length, **kwargs)
        self.nbytes = nbytes

    def process_bind_param(self, value: bytes | None, dialect: sqlalchemy.engine.Dialect) -> str | None:
        # Native `bytes` -> base64 `bytes` -> ASCII `str`; SQLAlchemy
        # expects `str` values for String/Text columns.
        if value is None:
            return None
        if isinstance(value, bytes):
            return b64encode(value).decode("ascii")
        raise TypeError(
            f"Base64Bytes fields require 'bytes' values; got '{value}' with type {type(value)}."
        )

    def process_result_value(self, value: str | None, dialect: sqlalchemy.engine.Dialect) -> bytes | None:
        # Inverse of process_bind_param: ASCII `str` (base64) back to
        # native `bytes`.
        if value is None:
            return None
        return b64decode(value.encode("ascii"))

    @property
    def python_type(self) -> type[bytes]:
        return bytes

163 

164 

# Create an alias, used below (in FieldSpec.getPythonType) to disambiguate
# our Base64Bytes from built-in SQLAlchemy types.
LocalBase64Bytes = Base64Bytes

168 

169 

class Base64Region(Base64Bytes):
    """A SQLAlchemy custom type for Python `sphgeom.Region`.

    Maps Python `sphgeom.Region` to a base64-encoded `sqlalchemy.String`.
    """

    cache_ok = True  # must be set explicitly in every TypeDecorator subclass

    def process_bind_param(self, value: Region | None, dialect: sqlalchemy.engine.Dialect) -> str | None:
        # Serialize the region to raw bytes, then let the base class handle
        # the base64/ASCII conversion.
        if value is None:
            return None
        return super().process_bind_param(value.encode(), dialect)

    def process_result_value(self, value: str | None, dialect: sqlalchemy.engine.Dialect) -> Region | None:
        if value is None:
            return None
        raw = super().process_result_value(value, dialect)
        return Region.decode(raw)

    @property
    def python_type(self) -> type[sphgeom.Region]:
        return sphgeom.Region

191 

192 

class AstropyTimeNsecTai(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `astropy.time.Time`.

    Maps Python `astropy.time.Time` to a number of nanoseconds since Unix
    epoch in TAI scale.
    """

    impl = sqlalchemy.BigInteger

    cache_ok = True

    def process_bind_param(
        self, value: astropy.time.Time | None, dialect: sqlalchemy.engine.Dialect
    ) -> int | None:
        if value is None:
            return None
        if isinstance(value, astropy.time.Time):
            # Store as integer nanoseconds since epoch, TAI scale.
            return time_utils.TimeConverter().astropy_to_nsec(value)
        raise TypeError(f"Unsupported type: {type(value)}, expected astropy.time.Time")

    def process_result_value(
        self, value: int | None, dialect: sqlalchemy.engine.Dialect
    ) -> astropy.time.Time | None:
        # ``value`` is nanoseconds since epoch, or None.
        if value is None:
            return None
        return time_utils.TimeConverter().nsec_to_astropy(value)

222 

223 

# TODO: sqlalchemy 2 has internal support for UUID:
# https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Uuid
class GUID(sqlalchemy.TypeDecorator):
    """Platform-independent GUID type.

    Uses PostgreSQL's UUID type, otherwise uses CHAR(32), storing as
    stringified hex values.
    """

    impl = sqlalchemy.CHAR

    cache_ok = True

    def load_dialect_impl(self, dialect: sqlalchemy.Dialect) -> sqlalchemy.types.TypeEngine:
        # Native UUID column on PostgreSQL; a 32-character hex CHAR
        # everywhere else.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(sqlalchemy.CHAR(32))

    def process_bind_param(self, value: Any, dialect: sqlalchemy.Dialect) -> str | None:
        if value is None:
            return None

        # Coerce the input to a UUID.  In general a UUID on input is the
        # only thing we want, but there is code right now that also passes
        # ints, bytes, and hex strings.
        if isinstance(value, uuid.UUID):
            coerced = value
        elif isinstance(value, int):
            coerced = uuid.UUID(int=value)
        elif isinstance(value, bytes):
            coerced = uuid.UUID(bytes=value)
        elif isinstance(value, str):
            # hexstring
            coerced = uuid.UUID(hex=value)
        else:
            raise TypeError(f"Unexpected type of a bind value: {type(value)}")

        if dialect.name == "postgresql":
            return str(coerced)
        return "%.32x" % coerced.int

    def process_result_value(
        self, value: str | uuid.UUID | None, dialect: sqlalchemy.Dialect
    ) -> uuid.UUID | None:
        # sqlalchemy 2 converts to UUID internally, so the value may
        # already be a UUID instance.
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(hex=value)

272 

273 

# Mapping from the "type" strings accepted in schema configuration to the
# SQLAlchemy (or custom, defined above) column types they correspond to.
# Used by FieldSpec.fromConfig below.
VALID_CONFIG_COLUMN_TYPES = {
    "string": sqlalchemy.String,
    "int": sqlalchemy.BigInteger,
    "float": sqlalchemy.Float,
    "region": Base64Region,
    "bool": sqlalchemy.Boolean,
    "blob": sqlalchemy.LargeBinary,
    "datetime": AstropyTimeNsecTai,
    "hash": Base64Bytes,
    "uuid": GUID,
}

285 

286 

@dataclass
class FieldSpec:
    """A data class for defining a column in a logical `Registry` table."""

    name: str
    """Name of the column."""

    dtype: type
    """Type of the column; usually a `type` subclass provided by SQLAlchemy
    that defines both a Python type and a corresponding precise SQL type.
    """

    length: int | None = None
    """Length of the type in the database, for variable-length types."""

    nbytes: int | None = None
    """Natural length used for hash and encoded-region columns, to be converted
    into the post-encoding length.
    """

    primaryKey: bool = False
    """Whether this field is (part of) its table's primary key."""

    autoincrement: bool = False
    """Whether the database should insert automatically incremented values when
    no value is provided in an INSERT.
    """

    nullable: bool = True
    """Whether this field is allowed to be NULL. If ``primaryKey`` is
    `True`, during construction this value will be forced to `False`."""

    default: Any = None
    """A server-side default value for this field.

    This is passed directly as the ``server_default`` argument to
    `sqlalchemy.schema.Column`. It does _not_ go through SQLAlchemy's usual
    type conversion or quoting for Python literals, and should hence be used
    with care. See the SQLAlchemy documentation for more information.
    """

    doc: str | None = None
    """Documentation for this field."""

    def __post_init__(self) -> None:
        # A primary-key column can never be NULL, so override the default.
        if self.primaryKey:
            self.nullable = False

    def __eq__(self, other: Any) -> bool:
        # Column identity is determined by name alone.
        if not isinstance(other, FieldSpec):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in column config '{config}'.")
    def fromConfig(cls, config: Config, **kwargs: Any) -> FieldSpec:
        """Create a `FieldSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the column.  Nested configuration keys
            correspond to `FieldSpec` attributes.
        **kwargs
            Additional keyword arguments that provide defaults for values
            not present in config.

        Returns
        -------
        spec: `FieldSpec`
            Specification structure for the column.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        dtype = VALID_CONFIG_COLUMN_TYPES.get(config["type"])
        if dtype is None:
            raise SchemaValidationError(f"Invalid field type string: '{config['type']}'.")
        if not config["name"].islower():
            raise SchemaValidationError(f"Column name '{config['name']}' is not all lowercase.")
        spec = cls(name=config["name"], dtype=dtype, **kwargs)
        # Config entries override kwargs-provided defaults where present.
        spec.length = config.get("length", spec.length)
        spec.nbytes = config.get("nbytes", spec.nbytes)
        if spec.length is not None and spec.nbytes is not None:
            raise SchemaValidationError(f"Both length and nbytes provided for field '{spec.name}'.")
        spec.primaryKey = config.get("primaryKey", spec.primaryKey)
        spec.autoincrement = config.get("autoincrement", spec.autoincrement)
        spec.nullable = config.get("nullable", False if spec.primaryKey else spec.nullable)
        spec.doc = stripIfNotNone(config.get("doc", None))
        return spec

    @classmethod
    def for_region(cls, name: str = "region", nullable: bool = True, nbytes: int = 2048) -> FieldSpec:
        """Create a `FieldSpec` for a spatial region column.

        Parameters
        ----------
        name : `str`, optional
            Name for the field.
        nullable : `bool`, optional
            Whether NULL values are permitted.
        nbytes : `int`, optional
            Maximum number of bytes for serialized regions.  The actual
            column size will be larger to allow for base-64 encoding.

        Returns
        -------
        spec : `FieldSpec`
            Specification structure for a region column.
        """
        return cls(name, dtype=Base64Region, nullable=nullable, nbytes=nbytes)

    def isStringType(self) -> bool:
        """Indicate that this is a sqlalchemy.String field spec.

        Returns
        -------
        isString : `bool`
            The field refers to a `sqlalchemy.String` and not any other type.
            This can return `False` even if the object was created with a
            string type if it has been decided that it should be implemented
            as a `sqlalchemy.Text` type.
        """
        # Only short, explicitly-sized strings are retained as String.
        return bool(self.dtype is sqlalchemy.String and self.length and self.length <= 32)

    def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine | type:
        """Return a sized version of the column type.

        Utilizes either (or neither) of ``self.length`` and ``self.nbytes``.

        Returns
        -------
        dtype : `sqlalchemy.types.TypeEngine`
            A SQLAlchemy column type object.
        """
        if self.length is not None:
            # A String that is too long to keep as String (see isStringType)
            # becomes an unsized Text column instead.
            if self.dtype is sqlalchemy.String and not self.isStringType():
                return sqlalchemy.Text
            return self.dtype(length=self.length)
        if self.nbytes is None:
            return self.dtype
        return self.dtype(nbytes=self.nbytes)

    def getPythonType(self) -> type:
        """Return the Python type associated with this field's (SQL) dtype.

        Returns
        -------
        type : `type`
            Python type associated with this field's (SQL) `dtype`.
        """
        if issubclass(self.dtype, LocalBase64Bytes):
            # These types require the nbytes keyword to be constructed;
            # the assert convinces mypy of something that must be true.
            assert self.nbytes is not None
            return self.dtype(nbytes=self.nbytes).python_type
        return self.dtype().python_type  # type: ignore

456 

457 

@dataclass
class ForeignKeySpec:
    """Definition of a foreign key constraint in a logical `Registry` table."""

    table: str
    """Name of the target table."""

    source: tuple[str, ...]
    """Tuple of source table column names."""

    target: tuple[str, ...]
    """Tuple of target table column names."""

    onDelete: str | None = None
    """SQL clause indicating how to handle deletes to the target table.

    If not `None` (which indicates that a constraint violation exception should
    be raised), should be either "SET NULL" or "CASCADE".
    """

    addIndex: bool = True
    """If `True`, create an index on the columns of this foreign key in the
    source table.
    """

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in foreignKey config '{config}'.")
    def fromConfig(cls, config: Config) -> ForeignKeySpec:
        """Create a `ForeignKeySpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the constraint.  Nested configuration
            keys correspond to `ForeignKeySpec` attributes.

        Returns
        -------
        spec: `ForeignKeySpec`
            Specification structure for the constraint.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        # "source" and "target" may each be a single column name or a list
        # of names; normalize both to tuples.
        sourceColumns = tuple(ensure_iterable(config["source"]))
        targetColumns = tuple(ensure_iterable(config["target"]))
        return cls(
            table=config["table"],
            source=sourceColumns,
            target=targetColumns,
            onDelete=config.get("onDelete", None),
        )

510 

511 

@dataclass(frozen=True)
class IndexSpec:
    """Specification of an index on table columns.

    Parameters
    ----------
    *columns : `str`
        Names of the columns to index.
    **kwargs : `Any`
        Additional keyword arguments to pass directly to
        `sqlalchemy.schema.Index` constructor.  This could be used to provide
        backend-specific options, e.g. to create a ``GIST`` index in
        PostgreSQL one can pass ``postgresql_using="gist"``.
    """

    columns: tuple[str, ...]
    """Column names to include in the index (`Tuple` [ `str` ])."""

    kwargs: dict[str, Any]
    """Additional keyword arguments passed directly to
    `sqlalchemy.schema.Index` constructor (`dict` [ `str`, `Any` ]).
    """

    def __init__(self, *columns: str, **kwargs: Any):
        # The dataclass is frozen, so attributes must be assigned through
        # object.__setattr__ rather than plain assignment.
        object.__setattr__(self, "columns", tuple(columns))
        object.__setattr__(self, "kwargs", kwargs)

    def __hash__(self) -> int:
        # Hash only on the column names: ``kwargs`` is a dict and therefore
        # unhashable.  Equal instances (same columns and kwargs) still get
        # equal hashes, as required.
        return hash(self.columns)

541 

542 

@dataclass
class TableSpec:
    """A data class used to define a table or table-like query interface.

    Parameters
    ----------
    fields : `~collections.abc.Iterable` [ `FieldSpec` ]
        Specifications for the columns in this table.
    unique : `~collections.abc.Iterable` [ `tuple` [ `str` ] ], optional
        Non-primary-key unique constraints for the table.
    indexes : `~collections.abc.Iterable` [ `IndexSpec` ], optional
        Indexes for the table.
    foreignKeys : `~collections.abc.Iterable` [ `ForeignKeySpec` ], optional
        Foreign key constraints for the table.
    exclusion : `~collections.abc.Iterable` [ `tuple` [ `str` or `type` ] ]
        Special constraints that prohibit overlaps between timespans over rows
        where other columns are equal.  These take the same form as unique
        constraints, but each tuple may contain a single
        `TimespanDatabaseRepresentation` subclass representing a timespan
        column.
    recycleIds : `bool`, optional
        If `True`, allow databases that might normally recycle autoincrement
        IDs to do so (usually better for performance) on any autoincrement
        field in this table.
    doc : `str`, optional
        Documentation for the table.
    """

    def __init__(
        self,
        fields: Iterable[FieldSpec],
        *,
        unique: Iterable[tuple[str, ...]] = (),
        indexes: Iterable[IndexSpec] = (),
        foreignKeys: Iterable[ForeignKeySpec] = (),
        exclusion: Iterable[tuple[str | type[TimespanDatabaseRepresentation], ...]] = (),
        recycleIds: bool = True,
        doc: str | None = None,
    ):
        # Normalize all iterables to concrete containers so the spec is
        # safely mutable and re-iterable after construction.
        self.fields = NamedValueSet(fields)
        self.unique = set(unique)
        self.indexes = set(indexes)
        self.foreignKeys = list(foreignKeys)
        self.exclusion = set(exclusion)
        self.recycleIds = recycleIds
        self.doc = doc

    fields: NamedValueSet[FieldSpec]
    """Specifications for the columns in this table."""

    unique: set[tuple[str, ...]]
    """Non-primary-key unique constraints for the table."""

    indexes: set[IndexSpec]
    """Indexes for the table."""

    foreignKeys: list[ForeignKeySpec]
    """Foreign key constraints for the table."""

    exclusion: set[tuple[str | type[TimespanDatabaseRepresentation], ...]]
    """Exclusion constraints for the table.

    Exclusion constraints behave mostly like unique constraints, but may
    contain a database-native Timespan column that is restricted to not overlap
    across rows (for identical combinations of any non-Timespan columns in the
    constraint).
    """

    recycleIds: bool = True
    """If `True`, allow databases that might normally recycle autoincrement IDs
    to do so (usually better for performance) on any autoincrement field in
    this table.
    """

    doc: str | None = None
    """Documentation for the table."""

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in table config '{config}'.")
    def fromConfig(cls, config: Config) -> TableSpec:
        """Create a `TableSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the table.  Nested configuration keys
            correspond to `TableSpec` attributes.

        Returns
        -------
        spec: `TableSpec`
            Specification structure for the table.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        # Note: "indexes" and "exclusion" are not read from config here;
        # only columns, unique constraints, foreign keys, and doc are
        # config-driven.
        return cls(
            fields=NamedValueSet(FieldSpec.fromConfig(c) for c in config["columns"]),
            unique={tuple(u) for u in config.get("unique", ())},
            foreignKeys=[ForeignKeySpec.fromConfig(c) for c in config.get("foreignKeys", ())],
            doc=stripIfNotNone(config.get("doc")),
        )