Coverage for python/lsst/daf/butler/core/ddl.py: 44%

179 statements  

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Classes for representing SQL data-definition language (DDL) in Python.

This includes "CREATE TABLE" etc.

This provides an extra layer on top of SQLAlchemy's classes for these concepts,
because we need a level of indirection between logical tables and the actual
SQL, and SQLAlchemy's DDL classes always map 1-1 to SQL.

We've opted for the rather more obscure "ddl" as the name of this module
instead of "schema" because the latter is too overloaded; in most SQL
databases, a "schema" is another term for a namespace.
"""

from __future__ import annotations

__all__ = ("TableSpec", "FieldSpec", "ForeignKeySpec", "Base64Bytes", "Base64Region",
           "AstropyTimeNsecTai", "GUID")

from base64 import b64encode, b64decode
import logging
from math import ceil
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
import uuid

import sqlalchemy
from sqlalchemy.dialects.postgresql import UUID
import astropy.time

from lsst.sphgeom import Region
from .config import Config
from .exceptions import ValidationError
from . import time_utils
from .utils import iterable, stripIfNotNone
from .named import NamedValueSet

if TYPE_CHECKING:
    from .timespan import TimespanDatabaseRepresentation


_LOG = logging.getLogger(__name__)


class SchemaValidationError(ValidationError):
    """Exceptions that indicate problems in Registry schema configuration."""

    @classmethod
    def translate(cls, caught: Type[Exception], message: str) -> Callable:
        """Return decorator to re-raise exceptions as `SchemaValidationError`.

        Decorated functions must be class or instance methods, with a
        ``config`` parameter as their first argument. This will be passed
        to ``message.format()`` as a keyword argument, along with ``err``,
        the original exception.

        Parameters
        ----------
        caught : `type` (`Exception` subclass)
            The type of exception to catch.
        message : `str`
            A `str.format` string that may contain named placeholders for
            ``config``, ``err``, or any keyword-only argument accepted by
            the decorated function.
        """
        def decorate(func: Callable) -> Callable:
            def decorated(self: Any, config: Config, *args: Any, **kwargs: Any) -> Any:
                try:
                    return func(self, config, *args, **kwargs)
                except caught as err:
                    raise cls(message.format(config=str(config), err=err))
            return decorated
        return decorate
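
# Example (editorial sketch, not part of the original module): `translate` is
# used throughout this file to turn a `KeyError` raised while reading config
# into a `SchemaValidationError`, e.g.:
#
#     @classmethod
#     @SchemaValidationError.translate(KeyError, "Missing key {err} in column config '{config}'.")
#     def fromConfig(cls, config: Config) -> FieldSpec:
#         ...
#
# A missing "name" key then surfaces as
# SchemaValidationError("Missing key 'name' in column config '...'.").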


class Base64Bytes(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `bytes`.

    Maps Python `bytes` to a base64-encoded `sqlalchemy.Text` field.
    """

    impl = sqlalchemy.Text

    cache_ok = True

    def __init__(self, nbytes: int, *args: Any, **kwargs: Any):
        length = 4*ceil(nbytes/3) if self.impl == sqlalchemy.String else None
        super().__init__(*args, length=length, **kwargs)
        self.nbytes = nbytes

    def process_bind_param(self, value: Optional[bytes], dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[str]:
        # 'value' is native `bytes`. We want to encode that to base64 `bytes`
        # and then ASCII `str`, because `str` is what SQLAlchemy expects for
        # String fields.
        if value is None:
            return None
        if not isinstance(value, bytes):
            raise TypeError(
                f"Base64Bytes fields require 'bytes' values; got '{value}' with type {type(value)}."
            )
        return b64encode(value).decode("ascii")

    def process_result_value(self, value: Optional[str], dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[bytes]:
        # 'value' is a `str` that must be ASCII because it's base64-encoded.
        # We want to transform that to base64-encoded `bytes` and then
        # native `bytes`.
        return b64decode(value.encode("ascii")) if value is not None else None
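
# Example (editorial sketch): the conversion methods can be exercised
# directly; in normal use SQLAlchemy calls them when binding parameters and
# reading results.
#
#     >>> t = Base64Bytes(nbytes=16)
#     >>> t.process_bind_param(b"hello", dialect=None)
#     'aGVsbG8='
#     >>> t.process_result_value('aGVsbG8=', dialect=None)
#     b'hello'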


class Base64Region(Base64Bytes):
    """A SQLAlchemy custom type for Python `sphgeom.Region`.

    Maps Python `sphgeom.Region` to a base64-encoded `sqlalchemy.Text` field.
    """

    def process_bind_param(self, value: Optional[Region], dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[str]:
        if value is None:
            return None
        return super().process_bind_param(value.encode(), dialect)

    def process_result_value(self, value: Optional[str], dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Region]:
        if value is None:
            return None
        return Region.decode(super().process_result_value(value, dialect))
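
# Example (editorial sketch, assuming `lsst.sphgeom` provides `Circle`,
# `UnitVector3d`, and `Angle` as usual): a region round-trips through its
# native binary encoding plus base64.
#
#     >>> from lsst.sphgeom import Angle, Circle, UnitVector3d
#     >>> region = Circle(UnitVector3d(1.0, 0.0, 0.0), Angle.fromDegrees(1.0))
#     >>> t = Base64Region(nbytes=2048)
#     >>> stored = t.process_bind_param(region, dialect=None)
#     >>> t.process_result_value(stored, dialect=None) == region
#     True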


class AstropyTimeNsecTai(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type for Python `astropy.time.Time`.

    Maps Python `astropy.time.Time` to a number of nanoseconds since Unix
    epoch in TAI scale.
    """

    impl = sqlalchemy.BigInteger

    cache_ok = True

    def process_bind_param(self, value: Optional[astropy.time.Time], dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[int]:
        if value is None:
            return None
        if not isinstance(value, astropy.time.Time):
            raise TypeError(f"Unsupported type: {type(value)}, expected astropy.time.Time")
        return time_utils.TimeConverter().astropy_to_nsec(value)

    def process_result_value(self, value: Optional[int], dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[astropy.time.Time]:
        # 'value' is nanoseconds since epoch, or None.
        if value is None:
            return None
        return time_utils.TimeConverter().nsec_to_astropy(value)
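
# Example (editorial sketch): times become integer TAI nanoseconds on the way
# in and `astropy.time.Time` on the way out, via `time_utils.TimeConverter`.
#
#     >>> t = AstropyTimeNsecTai()
#     >>> nsec = t.process_bind_param(astropy.time.Time("2020-01-01", scale="tai"), dialect=None)
#     >>> isinstance(nsec, int)
#     True
#     >>> t.process_result_value(nsec, dialect=None).scale
#     'tai'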


class GUID(sqlalchemy.TypeDecorator):
    """Platform-independent GUID type.

    Uses PostgreSQL's UUID type, otherwise uses CHAR(32), storing as
    stringified hex values.
    """

    impl = sqlalchemy.CHAR

    cache_ok = True

    def load_dialect_impl(self, dialect: sqlalchemy.Dialect) -> sqlalchemy.TypeEngine:
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(UUID())
        else:
            return dialect.type_descriptor(sqlalchemy.CHAR(32))

    def process_bind_param(self, value: Any, dialect: sqlalchemy.Dialect) -> Optional[str]:
        if value is None:
            return value

        # Coerce input to UUID type. In general a UUID is the only thing we
        # want on input, but there is existing code that passes ints.
        if isinstance(value, int):
            value = uuid.UUID(int=value)
        elif isinstance(value, bytes):
            value = uuid.UUID(bytes=value)
        elif isinstance(value, str):
            # hexstring
            value = uuid.UUID(hex=value)
        elif not isinstance(value, uuid.UUID):
            raise TypeError(f"Unexpected type of a bind value: {type(value)}")

        if dialect.name == 'postgresql':
            return str(value)
        else:
            return "%.32x" % value.int

    def process_result_value(self, value: Optional[str], dialect: sqlalchemy.Dialect) -> Optional[uuid.UUID]:
        if value is None:
            return value
        else:
            return uuid.UUID(hex=value)
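
# Example (editorial sketch): on a non-PostgreSQL backend UUIDs are stored as
# 32-character hex strings, and ints, bytes, and hex strings are coerced on
# input.
#
#     >>> dialect = sqlalchemy.create_engine("sqlite://").dialect
#     >>> g = GUID()
#     >>> g.process_bind_param(1, dialect)
#     '00000000000000000000000000000001'
#     >>> g.process_result_value('00000000000000000000000000000001', dialect)
#     UUID('00000000-0000-0000-0000-000000000001')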


VALID_CONFIG_COLUMN_TYPES = {
    "string": sqlalchemy.String,
    "int": sqlalchemy.BigInteger,
    "float": sqlalchemy.Float,
    "region": Base64Region,
    "bool": sqlalchemy.Boolean,
    "blob": sqlalchemy.LargeBinary,
    "datetime": AstropyTimeNsecTai,
    "hash": Base64Bytes,
    "uuid": GUID,
}


@dataclass
class FieldSpec:
    """A data class for defining a column in a logical `Registry` table."""

    name: str
    """Name of the column."""

    dtype: type
    """Type of the column; usually a `type` subclass provided by SQLAlchemy
    that defines both a Python type and a corresponding precise SQL type.
    """

    length: Optional[int] = None
    """Length of the type in the database, for variable-length types."""

    nbytes: Optional[int] = None
    """Natural length used for hash and encoded-region columns, to be
    converted into the post-encoding length.
    """

    primaryKey: bool = False
    """Whether this field is (part of) its table's primary key."""

    autoincrement: bool = False
    """Whether the database should insert automatically incremented values
    when no value is provided in an INSERT.
    """

    nullable: bool = True
    """Whether this field is allowed to be NULL."""

    default: Any = None
    """A server-side default value for this field.

    This is passed directly as the ``server_default`` argument to
    `sqlalchemy.schema.Column`. It does _not_ go through SQLAlchemy's usual
    type conversion or quoting for Python literals, and should hence be used
    with care. See the SQLAlchemy documentation for more information.
    """

    doc: Optional[str] = None
    """Documentation for this field."""

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, FieldSpec):
            return self.name == other.name
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in column config '{config}'.")
    def fromConfig(cls, config: Config, **kwargs: Any) -> FieldSpec:
        """Create a `FieldSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the column. Nested configuration keys
            correspond to `FieldSpec` attributes.
        **kwargs
            Additional keyword arguments that provide defaults for values
            not present in ``config``.

        Returns
        -------
        spec : `FieldSpec`
            Specification structure for the column.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        dtype = VALID_CONFIG_COLUMN_TYPES.get(config["type"])
        if dtype is None:
            raise SchemaValidationError(f"Invalid field type string: '{config['type']}'.")
        if not config["name"].islower():
            raise SchemaValidationError(f"Column name '{config['name']}' is not all lowercase.")
        self = cls(name=config["name"], dtype=dtype, **kwargs)
        self.length = config.get("length", self.length)
        self.nbytes = config.get("nbytes", self.nbytes)
        if self.length is not None and self.nbytes is not None:
            raise SchemaValidationError(f"Both length and nbytes provided for field '{self.name}'.")
        self.primaryKey = config.get("primaryKey", self.primaryKey)
        self.autoincrement = config.get("autoincrement", self.autoincrement)
        self.nullable = config.get("nullable", False if self.primaryKey else self.nullable)
        self.doc = stripIfNotNone(config.get("doc", None))
        return self
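
    # Example (editorial sketch, assuming `Config` accepts a plain mapping,
    # as it does elsewhere in daf_butler):
    #
    #     >>> spec = FieldSpec.fromConfig(Config({"name": "id", "type": "int", "primaryKey": True}))
    #     >>> spec.dtype is sqlalchemy.BigInteger, spec.nullable
    #     (True, False)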

    def isStringType(self) -> bool:
        """Indicate that this is a sqlalchemy.String field spec.

        Returns
        -------
        isString : `bool`
            The field refers to a `sqlalchemy.String` and not any other type.
            This can return `False` even if the object was created with a
            string type, if it has been decided that it should be implemented
            as a `sqlalchemy.Text` type.
        """
        # Short strings are retained as true String types; longer ones are
        # implemented as Text (see getSizedColumnType).
        return bool(self.dtype == sqlalchemy.String and self.length and self.length <= 32)

    def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine:
        """Return a sized version of the column type.

        Utilizes either (or neither) of ``self.length`` and ``self.nbytes``.

        Returns
        -------
        dtype : `sqlalchemy.types.TypeEngine`
            A SQLAlchemy column type object.
        """
        if self.length is not None:
            # Final check that we only size types that should remain String.
            if self.dtype == sqlalchemy.String and not self.isStringType():
                return sqlalchemy.Text
            return self.dtype(length=self.length)
        if self.nbytes is not None:
            return self.dtype(nbytes=self.nbytes)
        return self.dtype

    def getPythonType(self) -> type:
        """Return the Python type associated with this field's (SQL) dtype.

        Returns
        -------
        type : `type`
            Python type associated with this field's (SQL) `dtype`.
        """
        return self.dtype().python_type
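
# Example (editorial sketch): short strings keep a sized `String` type, while
# longer ones are widened to `Text` by `getSizedColumnType`.
#
#     >>> FieldSpec(name="name", dtype=sqlalchemy.String, length=16).getSizedColumnType()
#     String(length=16)
#     >>> FieldSpec(name="doc", dtype=sqlalchemy.String, length=1024).getSizedColumnType()
#     <class 'sqlalchemy.sql.sqltypes.Text'>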


@dataclass
class ForeignKeySpec:
    """Definition of a foreign key constraint in a logical `Registry` table."""

    table: str
    """Name of the target table."""

    source: Tuple[str, ...]
    """Tuple of source table column names."""

    target: Tuple[str, ...]
    """Tuple of target table column names."""

    onDelete: Optional[str] = None
    """SQL clause indicating how to handle deletes to the target table.

    If `None` (the default), a constraint violation exception is raised
    instead; otherwise this should be either "SET NULL" or "CASCADE".
    """

    addIndex: bool = True
    """If `True`, create an index on the columns of this foreign key in the
    source table.
    """

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in foreignKey config '{config}'.")
    def fromConfig(cls, config: Config) -> ForeignKeySpec:
        """Create a `ForeignKeySpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config : `Config`
            Configuration describing the constraint. Nested configuration
            keys correspond to `ForeignKeySpec` attributes.

        Returns
        -------
        spec : `ForeignKeySpec`
            Specification structure for the constraint.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        return cls(table=config["table"],
                   source=tuple(iterable(config["source"])),
                   target=tuple(iterable(config["target"])),
                   onDelete=config.get("onDelete", None))
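
# Example (editorial sketch): a foreign key from a hypothetical "run_id"
# column into a "run" table's "id" column, deleting dependent rows on delete.
#
#     >>> fk = ForeignKeySpec(table="run", source=("run_id",), target=("id",), onDelete="CASCADE")
#     >>> fk.addIndex
#     True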


@dataclass
class TableSpec:
    """A data class used to define a table or table-like query interface.

    Parameters
    ----------
    fields : `Iterable` [ `FieldSpec` ]
        Specifications for the columns in this table.
    unique : `Iterable` [ `tuple` [ `str` ] ], optional
        Non-primary-key unique constraints for the table.
    indexes : `Iterable` [ `tuple` [ `str` ] ], optional
        Indexes for the table.
    foreignKeys : `Iterable` [ `ForeignKeySpec` ], optional
        Foreign key constraints for the table.
    exclusion : `Iterable` [ `tuple` [ `str` or `type` ] ]
        Special constraints that prohibit overlaps between timespans over rows
        where other columns are equal. These take the same form as unique
        constraints, but each tuple may contain a single
        `TimespanDatabaseRepresentation` subclass representing a timespan
        column.
    recycleIds : `bool`, optional
        If `True`, allow databases that might normally recycle autoincrement
        IDs to do so (usually better for performance) on any autoincrement
        field in this table.
    doc : `str`, optional
        Documentation for the table.
    """

    def __init__(
        self, fields: Iterable[FieldSpec], *,
        unique: Iterable[Tuple[str, ...]] = (),
        indexes: Iterable[Tuple[str, ...]] = (),
        foreignKeys: Iterable[ForeignKeySpec] = (),
        exclusion: Iterable[Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...]] = (),
        recycleIds: bool = True,
        doc: Optional[str] = None,
    ):
        self.fields = NamedValueSet(fields)
        self.unique = set(unique)
        self.indexes = set(indexes)
        self.foreignKeys = list(foreignKeys)
        self.exclusion = set(exclusion)
        self.recycleIds = recycleIds
        self.doc = doc

    fields: NamedValueSet[FieldSpec]
    """Specifications for the columns in this table."""

    unique: Set[Tuple[str, ...]]
    """Non-primary-key unique constraints for the table."""

    indexes: Set[Tuple[str, ...]]
    """Indexes for the table."""

    foreignKeys: List[ForeignKeySpec]
    """Foreign key constraints for the table."""

    exclusion: Set[Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...]]
    """Exclusion constraints for the table.

    Exclusion constraints behave mostly like unique constraints, but may
    contain a database-native Timespan column that is restricted to not
    overlap across rows (for identical combinations of any non-Timespan
    columns in the constraint).
    """

    recycleIds: bool = True
    """If `True`, allow databases that might normally recycle autoincrement
    IDs to do so (usually better for performance) on any autoincrement field
    in this table.
    """

    doc: Optional[str] = None
    """Documentation for the table."""

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in table config '{config}'.")
    def fromConfig(cls, config: Config) -> TableSpec:
505 """Create a `ForeignKeySpec` from a subset of a `SchemaConfig`. 

506 

507 Parameters 

508 ---------- 

509 config: `Config` 

510 Configuration describing the constraint. Nested configuration keys 

511 correspond to `TableSpec` attributes. 

512 

513 Returns 

514 ------- 

515 spec: `TableSpec` 

516 Specification structure for the table. 

517 

518 Raises 

519 ------ 

520 SchemaValidationError 

521 Raised if configuration keys are missing or have invalid values. 

522 """ 

523 return cls( 

524 fields=NamedValueSet(FieldSpec.fromConfig(c) for c in config["columns"]), 

525 unique={tuple(u) for u in config.get("unique", ())}, 

526 foreignKeys=[ForeignKeySpec.fromConfig(c) for c in config.get("foreignKeys", ())], 

527 doc=stripIfNotNone(config.get("doc")), 

528 )
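
# Example (editorial sketch): building a small logical table directly rather
# than from config; the table and column names here are hypothetical.
#
#     >>> spec = TableSpec(
#     ...     fields=[
#     ...         FieldSpec(name="id", dtype=sqlalchemy.BigInteger, primaryKey=True, autoincrement=True),
#     ...         FieldSpec(name="name", dtype=sqlalchemy.String, length=64, nullable=False),
#     ...         FieldSpec(name="run_id", dtype=sqlalchemy.BigInteger),
#     ...     ],
#     ...     unique={("name",)},
#     ...     foreignKeys=[ForeignKeySpec(table="run", source=("run_id",), target=("id",))],
#     ...     doc="Hypothetical example table.",
#     ... )
#     >>> [f.name for f in spec.fields]
#     ['id', 'name', 'run_id']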