Coverage for python/lsst/daf/butler/core/ddl.py: 54%

233 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-14 19:21 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21"""Classes for representing SQL data-definition language (DDL) in Python. 

22 

23This include "CREATE TABLE" etc. 

24 

25This provides an extra layer on top of SQLAlchemy's classes for these concepts, 

26because we need a level of indirection between logical tables and the actual 

27SQL, and SQLAlchemy's DDL classes always map 1-1 to SQL. 

28 

29We've opted for the rather more obscure "ddl" as the name of this module 

30instead of "schema" because the latter is too overloaded; in most SQL 

31databases, a "schema" is also another term for a namespace. 

32""" 

33from __future__ import annotations 

34 

35from lsst import sphgeom 

36 

37__all__ = ( 

38 "TableSpec", 

39 "FieldSpec", 

40 "ForeignKeySpec", 

41 "IndexSpec", 

42 "Base64Bytes", 

43 "Base64Region", 

44 "AstropyTimeNsecTai", 

45 "GUID", 

46) 

47 

48import logging 

49import uuid 

50from base64 import b64decode, b64encode 

51from collections.abc import Callable, Iterable 

52from dataclasses import dataclass 

53from math import ceil 

54from typing import TYPE_CHECKING, Any 

55 

56import astropy.time 

57import sqlalchemy 

58from lsst.sphgeom import Region 

59from lsst.utils.iteration import ensure_iterable 

60from sqlalchemy.dialects.postgresql import UUID 

61 

62from . import time_utils 

63from .config import Config 

64from .exceptions import ValidationError 

65from .named import NamedValueSet 

66from .utils import stripIfNotNone 

67 

68if TYPE_CHECKING: 

69 from .timespan import TimespanDatabaseRepresentation 

70 

71 

72_LOG = logging.getLogger(__name__) 

73 

74 

class SchemaValidationError(ValidationError):
    """Exceptions that indicate problems in Registry schema configuration."""

    @classmethod
    def translate(cls, caught: type[Exception], message: str) -> Callable:
        """Return decorator to re-raise exceptions as `SchemaValidationError`.

        Decorated functions must be class or instance methods, with a
        ``config`` parameter as their first argument.  This will be passed
        to ``message.format()`` as a keyword argument, along with ``err``,
        the original exception.

        Parameters
        ----------
        caught : `type` (`Exception` subclass)
            The type of exception to catch.
        message : `str`
            A `str.format` string that may contain named placeholders for
            ``config``, ``err``, or any keyword-only argument accepted by
            the decorated function.

        Returns
        -------
        decorator : `~collections.abc.Callable`
            Decorator that wraps a method so any ``caught`` exception it
            raises is re-raised as `SchemaValidationError`.
        """

        def decorate(func: Callable) -> Callable:
            def decorated(self: Any, config: Config, *args: Any, **kwargs: Any) -> Any:
                try:
                    return func(self, config, *args, **kwargs)
                except caught as err:
                    # Chain the original exception (``from err``) so
                    # tracebacks show the root cause instead of the
                    # misleading "during handling of the above exception,
                    # another exception occurred".
                    raise cls(message.format(config=str(config), err=err)) from err

            return decorated

        return decorate

107 

108 

class Base64Bytes(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `bytes`.

    Maps Python `bytes` to a base64-encoded `sqlalchemy.Text` field.
    """

    impl = sqlalchemy.Text

    cache_ok = True

    def __init__(self, nbytes: int | None = None, *args: Any, **kwargs: Any):
        # Base64 encodes every 3 bytes as 4 characters, so a sized column
        # needs 4*ceil(nbytes/3) characters; the size only applies when the
        # implementation type is a sized String (not Text).
        length: int | None = None
        if nbytes is not None and self.impl == sqlalchemy.String:
            length = 4 * ceil(nbytes / 3)
        super().__init__(*args, length=length, **kwargs)
        self.nbytes = nbytes

    def process_bind_param(self, value: bytes | None, dialect: sqlalchemy.engine.Dialect) -> str | None:
        # Native `bytes` in, base64 `bytes` next, then ASCII `str` out,
        # because `str` is what SQLAlchemy expects for String fields.
        if value is None:
            return None
        if not isinstance(value, bytes):
            raise TypeError(
                f"Base64Bytes fields require 'bytes' values; got '{value}' with type {type(value)}."
            )
        return b64encode(value).decode("ascii")

    def process_result_value(self, value: str | None, dialect: sqlalchemy.engine.Dialect) -> bytes | None:
        # The stored `str` is ASCII-safe because it is base64-encoded;
        # reverse the bind transformation to recover native `bytes`.
        if value is None:
            return None
        return b64decode(value.encode("ascii"))

    @property
    def python_type(self) -> type[bytes]:
        # The Python-side type this column maps to.
        return bytes

148 

149 

# Create an alias, used below to disambiguate this custom type from the
# built-in sqlalchemy types when both are referenced in the same scope.
LocalBase64Bytes = Base64Bytes

153 

154 

class Base64Region(Base64Bytes):
    """A SQLAlchemy custom type for Python `sphgeom.Region`.

    Maps Python `sphgeom.Region` to a base64-encoded `sqlalchemy.String`.
    """

    cache_ok = True  # have to be set explicitly in each class

    def process_bind_param(self, value: Region | None, dialect: sqlalchemy.engine.Dialect) -> str | None:
        # Serialize the region to bytes, then let the base class
        # base64-encode it.
        if value is None:
            return None
        return super().process_bind_param(value.encode(), dialect)

    def process_result_value(self, value: str | None, dialect: sqlalchemy.engine.Dialect) -> Region | None:
        # Reverse of process_bind_param: base64-decode via the base class,
        # then deserialize the region.
        if value is None:
            return None
        decoded = super().process_result_value(value, dialect)
        return Region.decode(decoded)

    @property
    def python_type(self) -> type[sphgeom.Region]:
        # The Python-side type this column maps to.
        return sphgeom.Region

177 

class AstropyTimeNsecTai(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `astropy.time.Time`.

    Maps Python `astropy.time.Time` to a number of nanoseconds since Unix
    epoch in TAI scale.
    """

    impl = sqlalchemy.BigInteger

    cache_ok = True

    def process_bind_param(
        self, value: astropy.time.Time | None, dialect: sqlalchemy.engine.Dialect
    ) -> int | None:
        # Convert an astropy Time to integer TAI nanoseconds for storage.
        if value is None:
            return None
        if not isinstance(value, astropy.time.Time):
            raise TypeError(f"Unsupported type: {type(value)}, expected astropy.time.Time")
        return time_utils.TimeConverter().astropy_to_nsec(value)

    def process_result_value(
        self, value: int | None, dialect: sqlalchemy.engine.Dialect
    ) -> astropy.time.Time | None:
        # Stored value is nanoseconds since epoch (TAI), or NULL.
        if value is None:
            return None
        return time_utils.TimeConverter().nsec_to_astropy(value)

208 

# TODO: sqlalchemy 2 has internal support for UUID:
# https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Uuid
class GUID(sqlalchemy.TypeDecorator):
    """Platform-independent GUID type.

    Uses PostgreSQL's UUID type, otherwise uses CHAR(32), storing as
    stringified hex values.
    """

    impl = sqlalchemy.CHAR

    cache_ok = True

    def load_dialect_impl(self, dialect: sqlalchemy.Dialect) -> sqlalchemy.types.TypeEngine:
        # PostgreSQL has a native UUID column type; everywhere else fall
        # back to a 32-character hex string.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(sqlalchemy.CHAR(32))

    def process_bind_param(self, value: Any, dialect: sqlalchemy.Dialect) -> str | None:
        if value is None:
            return value

        # Coerce input to UUID type, in general having UUID on input is the
        # only thing that we want but there is code right now that uses ints.
        if not isinstance(value, uuid.UUID):
            if isinstance(value, int):
                value = uuid.UUID(int=value)
            elif isinstance(value, bytes):
                value = uuid.UUID(bytes=value)
            elif isinstance(value, str):
                # hexstring
                value = uuid.UUID(hex=value)
            else:
                raise TypeError(f"Unexpected type of a bind value: {type(value)}")

        # Native type for PostgreSQL, zero-padded hex everywhere else.
        return str(value) if dialect.name == "postgresql" else "%.32x" % value.int

    def process_result_value(
        self, value: str | uuid.UUID | None, dialect: sqlalchemy.Dialect
    ) -> uuid.UUID | None:
        if value is None or isinstance(value, uuid.UUID):
            # sqlalchemy 2 converts to UUID internally
            return value
        return uuid.UUID(hex=value)

259 

260 

# Mapping from the "type" strings allowed in schema configuration files to
# the SQLAlchemy (or custom, defined above) column types they select.
VALID_CONFIG_COLUMN_TYPES = {
    "string": sqlalchemy.String,
    "int": sqlalchemy.BigInteger,
    "float": sqlalchemy.Float,
    "region": Base64Region,
    "bool": sqlalchemy.Boolean,
    "blob": sqlalchemy.LargeBinary,
    "datetime": AstropyTimeNsecTai,
    "hash": Base64Bytes,
    "uuid": GUID,
}

272 

273 

@dataclass
class FieldSpec:
    """A data class for defining a column in a logical `Registry` table."""

    name: str
    """Name of the column."""

    dtype: type
    """Type of the column; usually a `type` subclass provided by SQLAlchemy
    that defines both a Python type and a corresponding precise SQL type.
    """

    length: int | None = None
    """Length of the type in the database, for variable-length types."""

    nbytes: int | None = None
    """Natural length used for hash and encoded-region columns, to be converted
    into the post-encoding length.
    """

    primaryKey: bool = False
    """Whether this field is (part of) its table's primary key."""

    autoincrement: bool = False
    """Whether the database should insert automatically incremented values when
    no value is provided in an INSERT.
    """

    nullable: bool = True
    """Whether this field is allowed to be NULL. If ``primaryKey`` is
    `True`, during construction this value will be forced to `False`."""

    default: Any = None
    """A server-side default value for this field.

    This is passed directly as the ``server_default`` argument to
    `sqlalchemy.schema.Column`.  It does _not_ go through SQLAlchemy's usual
    type conversion or quoting for Python literals, and should hence be used
    with care.  See the SQLAlchemy documentation for more information.
    """

    doc: str | None = None
    """Documentation for this field."""

    def __post_init__(self) -> None:
        if self.primaryKey:
            # Change the default to match primaryKey.
            self.nullable = False

    def __eq__(self, other: Any) -> bool:
        # Fields are identified by name alone (so they can live in a
        # NamedValueSet); other attributes do not participate in equality.
        if isinstance(other, FieldSpec):
            return self.name == other.name
        else:
            return NotImplemented

    def __hash__(self) -> int:
        # Must match __eq__: hash on name only.
        return hash(self.name)

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in column config '{config}'.")
    def fromConfig(cls, config: Config, **kwargs: Any) -> FieldSpec:
        """Create a `FieldSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config: `Config`
            Configuration describing the column.  Nested configuration keys
            correspond to `FieldSpec` attributes.
        **kwargs
            Additional keyword arguments that provide defaults for values
            not present in config.

        Returns
        -------
        spec: `FieldSpec`
            Specification structure for the column.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        dtype = VALID_CONFIG_COLUMN_TYPES.get(config["type"])
        if dtype is None:
            raise SchemaValidationError(f"Invalid field type string: '{config['type']}'.")
        if not config["name"].islower():
            raise SchemaValidationError(f"Column name '{config['name']}' is not all lowercase.")
        self = cls(name=config["name"], dtype=dtype, **kwargs)
        self.length = config.get("length", self.length)
        self.nbytes = config.get("nbytes", self.nbytes)
        # length (post-encoding size) and nbytes (pre-encoding size) are
        # alternative ways to size the column; exactly one may be given.
        if self.length is not None and self.nbytes is not None:
            raise SchemaValidationError(f"Both length and nbytes provided for field '{self.name}'.")
        self.primaryKey = config.get("primaryKey", self.primaryKey)
        self.autoincrement = config.get("autoincrement", self.autoincrement)
        # Primary-key columns default to NOT NULL unless config overrides.
        self.nullable = config.get("nullable", False if self.primaryKey else self.nullable)
        self.doc = stripIfNotNone(config.get("doc", None))
        return self

    @classmethod
    def for_region(cls, name: str = "region", nullable: bool = True, nbytes: int = 2048) -> FieldSpec:
        """Create a `FieldSpec` for a spatial region column.

        Parameters
        ----------
        name : `str`, optional
            Name for the field.
        nullable : `bool`, optional
            Whether NULL values are permitted.
        nbytes : `int`, optional
            Maximum number of bytes for serialized regions.  The actual column
            size will be larger to allow for base-64 encoding.

        Returns
        -------
        spec : `FieldSpec`
            Specification structure for a region column.
        """
        return cls(name, nullable=nullable, dtype=Base64Region, nbytes=nbytes)

    def isStringType(self) -> bool:
        """Indicate that this is a sqlalchemy.String field spec.

        Returns
        -------
        isString : `bool`
            The field refers to a `sqlalchemy.String` and not any other type.
            This can return `False` even if the object was created with a
            string type if it has been decided that it should be implemented
            as a `sqlalchemy.Text` type.
        """
        # Only short sized strings (1..32 chars) are retained as String;
        # unsized or longer strings are implemented as Text.  The doubled
        # dtype check in the original was redundant; truthiness of
        # ``self.length`` is preserved (length == 0 is not a String).
        return bool(self.dtype == sqlalchemy.String and self.length and self.length <= 32)

    def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine | type:
        """Return a sized version of the column type.

        Utilizes either (or neither) of ``self.length`` and ``self.nbytes``.

        Returns
        -------
        dtype : `sqlalchemy.types.TypeEngine`
            A SQLAlchemy column type object.
        """
        if self.length is not None:
            # Last chance check that we are only looking at possible String
            if self.dtype == sqlalchemy.String and not self.isStringType():
                return sqlalchemy.Text
            return self.dtype(length=self.length)
        if self.nbytes is not None:
            return self.dtype(nbytes=self.nbytes)
        return self.dtype

    def getPythonType(self) -> type:
        """Return the Python type associated with this field's (SQL) dtype.

        Returns
        -------
        type : `type`
            Python type associated with this field's (SQL) `dtype`.
        """
        # to construct these objects, nbytes keyword is needed
        if issubclass(self.dtype, LocalBase64Bytes):
            # satisfy mypy for something that must be true
            assert self.nbytes is not None
            return self.dtype(nbytes=self.nbytes).python_type
        else:
            return self.dtype().python_type  # type: ignore

445 

@dataclass
class ForeignKeySpec:
    """Definition of a foreign key constraint in a logical `Registry` table."""

    table: str
    """Name of the target table."""

    source: tuple[str, ...]
    """Tuple of source table column names."""

    target: tuple[str, ...]
    """Tuple of target table column names."""

    onDelete: str | None = None
    """SQL clause indicating how to handle deletes to the target table.

    If not `None` (which indicates that a constraint violation exception should
    be raised), should be either "SET NULL" or "CASCADE".
    """

    addIndex: bool = True
    """If `True`, create an index on the columns of this foreign key in the
    source table.
    """

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in foreignKey config '{config}'.")
    def fromConfig(cls, config: Config) -> ForeignKeySpec:
        """Create a `ForeignKeySpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config: `Config`
            Configuration describing the constraint.  Nested configuration
            keys correspond to `ForeignKeySpec` attributes.

        Returns
        -------
        spec: `ForeignKeySpec`
            Specification structure for the constraint.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        # Scalars and sequences are both accepted in config; normalize each
        # side of the key to a tuple of column names.
        source_columns = tuple(ensure_iterable(config["source"]))
        target_columns = tuple(ensure_iterable(config["target"]))
        return cls(
            table=config["table"],
            source=source_columns,
            target=target_columns,
            onDelete=config.get("onDelete", None),
        )

498 

499 

@dataclass(frozen=True)
class IndexSpec:
    """Specification of an index on table columns.

    Parameters
    ----------
    *columns : `str`
        Names of the columns to index.
    **kwargs: `Any`
        Additional keyword arguments to pass directly to
        `sqlalchemy.schema.Index` constructor.  This could be used to provide
        backend-specific options, e.g. to create a ``GIST`` index in PostgreSQL
        one can pass ``postgresql_using="gist"``.
    """

    def __init__(self, *columns: str, **kwargs: Any):
        # Hand-written __init__ (dataclass does not overwrite an __init__
        # defined in the class body) so columns can be given positionally
        # and arbitrary index options gathered via **kwargs.
        # object.__setattr__ is required because the dataclass is frozen.
        object.__setattr__(self, "columns", tuple(columns))
        object.__setattr__(self, "kwargs", kwargs)

    def __hash__(self) -> int:
        # Hash on the column tuple only; ``kwargs`` is a dict and therefore
        # unhashable.
        return hash(self.columns)

    columns: tuple[str, ...]
    """Column names to include in the index (`Tuple` [ `str` ])."""

    kwargs: dict[str, Any]
    """Additional keyword arguments passed directly to
    `sqlalchemy.schema.Index` constructor (`dict` [ `str`, `Any` ]).
    """

529 

530 

@dataclass
class TableSpec:
    """A data class used to define a table or table-like query interface.

    Parameters
    ----------
    fields : `~collections.abc.Iterable` [ `FieldSpec` ]
        Specifications for the columns in this table.
    unique : `~collections.abc.Iterable` [ `tuple` [ `str` ] ], optional
        Non-primary-key unique constraints for the table.
    indexes: `~collections.abc.Iterable` [ `IndexSpec` ], optional
        Indexes for the table.
    foreignKeys : `~collections.abc.Iterable` [ `ForeignKeySpec` ], optional
        Foreign key constraints for the table.
    exclusion : `~collections.abc.Iterable` [ `tuple` [ `str` or `type` ] ]
        Special constraints that prohibit overlaps between timespans over rows
        where other columns are equal.  These take the same form as unique
        constraints, but each tuple may contain a single
        `TimespanDatabaseRepresentation` subclass representing a timespan
        column.
    recycleIds : `bool`, optional
        If `True`, allow databases that might normally recycle autoincrement
        IDs to do so (usually better for performance) on any autoincrement
        field in this table.
    doc : `str`, optional
        Documentation for the table.
    """

    def __init__(
        self,
        fields: Iterable[FieldSpec],
        *,
        unique: Iterable[tuple[str, ...]] = (),
        indexes: Iterable[IndexSpec] = (),
        foreignKeys: Iterable[ForeignKeySpec] = (),
        exclusion: Iterable[tuple[str | type[TimespanDatabaseRepresentation], ...]] = (),
        recycleIds: bool = True,
        doc: str | None = None,
    ):
        # Hand-written __init__ (dataclass does not overwrite an __init__
        # defined in the class body) so arbitrary iterables can be
        # normalized into the concrete container types declared below.
        self.fields = NamedValueSet(fields)
        self.unique = set(unique)
        self.indexes = set(indexes)
        self.foreignKeys = list(foreignKeys)
        self.exclusion = set(exclusion)
        self.recycleIds = recycleIds
        self.doc = doc

    fields: NamedValueSet[FieldSpec]
    """Specifications for the columns in this table."""

    unique: set[tuple[str, ...]]
    """Non-primary-key unique constraints for the table."""

    indexes: set[IndexSpec]
    """Indexes for the table."""

    foreignKeys: list[ForeignKeySpec]
    """Foreign key constraints for the table."""

    exclusion: set[tuple[str | type[TimespanDatabaseRepresentation], ...]]
    """Exclusion constraints for the table.

    Exclusion constraints behave mostly like unique constraints, but may
    contain a database-native Timespan column that is restricted to not overlap
    across rows (for identical combinations of any non-Timespan columns in the
    constraint).
    """

    recycleIds: bool = True
    """If `True`, allow databases that might normally recycle autoincrement IDs
    to do so (usually better for performance) on any autoincrement field in
    this table.
    """

    doc: str | None = None
    """Documentation for the table."""

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in table config '{config}'.")
    def fromConfig(cls, config: Config) -> TableSpec:
        """Create a `TableSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config: `Config`
            Configuration describing the table.  Nested configuration keys
            correspond to `TableSpec` attributes.

        Returns
        -------
        spec: `TableSpec`
            Specification structure for the table.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        # NOTE(review): only "columns", "unique", "foreignKeys", and "doc"
        # are read from config here; "indexes" and "exclusion" are not
        # configurable via this method — confirm that is intentional.
        return cls(
            fields=NamedValueSet(FieldSpec.fromConfig(c) for c in config["columns"]),
            unique={tuple(u) for u in config.get("unique", ())},
            foreignKeys=[ForeignKeySpec.fromConfig(c) for c in config.get("foreignKeys", ())],
            doc=stripIfNotNone(config.get("doc")),
        )