Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 33%

222 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-19 03:43 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from sqlalchemy.sql.expression import ColumnElement as ColumnElement 

30 

31from ... import ddl, time_utils 

32 

# Public API of this module; helper classes below are private.
__all__ = ["PostgresqlDatabase"]

34 

35import re 

36from collections.abc import Callable, Iterable, Iterator, Mapping 

37from contextlib import closing, contextmanager 

38from typing import Any 

39 

40import psycopg2 

41import sqlalchemy 

42import sqlalchemy.dialects.postgresql 

43from sqlalchemy import sql 

44 

45from ..._named import NamedValueAbstractSet 

46from ..._timespan import Timespan 

47from ...name_shrinker import NameShrinker 

48from ...timespan_database_representation import TimespanDatabaseRepresentation 

49from ..interfaces import Database 

50 

# Extracts the leading "major.minor" pair from the string reported by
# PostgreSQL's ``SHOW server_version``.
_SERVER_VERSION_REGEX = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)")

52 

53 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with. If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy. Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                # Probing for psycopg2's DSN accessor doubles as a check that
                # the engine really is backed by the psycopg2 driver.
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the server-reported default schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Fail early with a helpful message if the required btree_gist
            # extension (see class docstring) is not enabled on this database.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
            # Parse the "major.minor" prefix of the server version string so
            # feature availability (e.g. ``has_any_aggregate``) can be gated
            # on it later.
            raw_pg_version = connection.execute(sqlalchemy.text("SHOW server_version")).scalar()
            if raw_pg_version is not None and (m := _SERVER_VERSION_REGEX.search(raw_pg_version)):
                pg_version = (int(m.group("major")), int(m.group("minor")))
            else:
                raise RuntimeError("Failed to get PostgreSQL server version.")
        self._init(
            engine=engine,
            origin=origin,
            namespace=namespace,
            writeable=writeable,
            dbname=dsn.get("dbname"),
            metadata=None,
            pg_version=pg_version,
        )

    def _init(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
        dbname: str,
        metadata: sqlalchemy.schema.MetaData | None,
        pg_version: tuple[int, int],
    ) -> None:
        # Initialization logic shared between ``__init__`` and ``clone``.
        super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
        self._writeable = writeable
        self.dbname = dbname
        self._pg_version = pg_version
        # Used to fit generated entity names (e.g. constraint names) within
        # the backend's identifier length limit.
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    def clone(self) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same engine and
        configuration, bypassing ``__init__`` and hence the connection-time
        checks it performs.
        """
        clone = self.__new__(type(self))
        clone._init(
            origin=self.origin,
            engine=self._engine,
            namespace=self.namespace,
            writeable=self._writeable,
            dbname=self.dbname,
            metadata=self._metadata,
            pg_version=self._pg_version,
        )
        return clone

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE(review): ``writeable`` is not consulted here; read-only
        # protection is enforced by this class (see ``_transaction``), not by
        # the engine itself.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        # Docstring inherited.
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0). But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently. And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type. When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Creating a temporary table needs a transaction that is not READ
        # ONLY even on an otherwise read-only database; see the
        # ``for_temp_tables`` handling in ``_transaction``.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate in the constraint via equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # The timespan column participates via range overlap (&&).
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            # Shrink to respect the backend's identifier length limit.
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first. But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            # Nothing to do; avoid issuing an empty INSERT.
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table. If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING on conflict.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Only primary-key conflicts are ignored; conflicts on other
            # unique constraints will still raise.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

    @property
    def has_distinct_on(self) -> bool:
        # Docstring inherited.
        return True

    @property
    def has_any_aggregate(self) -> bool:
        # Docstring inherited.
        # The ANY_VALUE aggregate used by apply_any_aggregate was added in
        # PostgreSQL 16.
        return self._pg_version >= (16, 0)

    def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalchemy.ColumnElement[Any]:
        # Docstring inherited.
        return sqlalchemy.func.any_value(column)

351 

352 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Backed by PostgreSQL's ``INT8RANGE`` type, so stored values can take
    advantage of the server's built-in range operators and the indexing and
    EXCLUSION table constraints built off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        """Convert a `Timespan` into the psycopg2 range sent to the server."""
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        tc = time_utils.TimeConverter()
        nsec_lo = value.nsec[0]
        nsec_hi = value.nsec[1]
        assert nsec_lo >= tc.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec_hi <= tc.max_nsec, "Guaranteed by Timespan.__init__."
        # Sentinel min/max bounds become unbounded range endpoints.
        return psycopg2.extras.NumericRange(
            lower=None if nsec_lo == tc.min_nsec else nsec_lo,
            upper=None if nsec_hi == tc.max_nsec else nsec_hi,
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        """Convert a psycopg2 range from the server back into a `Timespan`."""
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        tc = time_utils.TimeConverter()
        # Unbounded endpoints map back onto the converter's sentinel limits.
        begin_nsec = value.lower if value.lower is not None else tc.min_nsec
        end_nsec = value.upper if value.upper is not None else tc.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

394 

395 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the timespan field.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to the fully unbounded range.
        default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        out = {} if result is None else result
        out[key] = extent
        return out

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=sqlalchemy.sql.cast(literal, type_=_RangeTimespanType), name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly left of" operator.
            return self.column << other.column
        sql_ = sqlalchemy.sql
        # Range vs. instant: the range must be non-empty, bounded above, and
        # its (exclusive) upper bound must not exceed the instant.
        return sql_.and_(
            sql_.not_(sql_.func.upper_inf(self.column)),
            sql_.not_(sql_.func.isempty(self.column)),
            sql_.func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly right of" operator.
            return self.column >> other.column
        sql_ = sqlalchemy.sql
        # Range vs. instant: non-empty, bounded below, starting after it.
        return sql_.and_(
            sql_.not_(sql_.func.lower_inf(self.column)),
            sql_.not_(sql_.func.isempty(self.column)),
            sql_.func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A bare column element is treated as an instant, and overlap with an
        # instant is containment.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        target = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(target)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        bound = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        bound = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)

    def apply_any_aggregate(
        self, func: Callable[[ColumnElement[Any]], ColumnElement[Any]]
    ) -> TimespanDatabaseRepresentation:
        # Docstring inherited.
        return _RangeTimespanRepresentation(func(self.column), self.name)