Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 31%

204 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-01 11:19 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

27from __future__ import annotations 

28 

29from ... import ddl, time_utils 

30 

31__all__ = ["PostgresqlDatabase"] 

32 

33from collections.abc import Iterable, Iterator, Mapping 

34from contextlib import closing, contextmanager 

35from typing import Any 

36 

37import psycopg2 

38import sqlalchemy 

39import sqlalchemy.dialects.postgresql 

40from sqlalchemy import sql 

41 

42from ..._named import NamedValueAbstractSet 

43from ..._timespan import Timespan 

44from ...timespan_database_representation import TimespanDatabaseRepresentation 

45from ..interfaces import Database 

46from ..nameShrinker import NameShrinker 

47 

48 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the connection's default schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Fail early if the btree_gist extension (needed for exclusion
            # constraints built on range columns) is not enabled.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self._init(
            engine=engine,
            origin=origin,
            namespace=namespace,
            writeable=writeable,
            dbname=dsn.get("dbname"),
            metadata=None,
        )

    def _init(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
        dbname: str,
        metadata: sqlalchemy.schema.MetaData | None,
    ) -> None:
        # Initialization logic shared between ``__init__`` and ``clone``.
        super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
        self._writeable = writeable
        self.dbname = dbname
        # Reversibly shrinks entity names that exceed PostgreSQL's
        # identifier-length limit.
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    def clone(self) -> PostgresqlDatabase:
        """Return a new instance with the same engine and configuration as
        this one (bypassing the connection-time checks in ``__init__``).
        """
        clone = self.__new__(type(self))
        clone._init(
            origin=self.origin,
            engine=self._engine,
            namespace=self.namespace,
            writeable=self._writeable,
            dbname=self.dbname,
            metadata=self._metadata,
        )
        return clone

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE(review): ``writeable`` is not used here; read-only enforcement
        # happens per-transaction in ``_transaction`` instead.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use
                        # TIMESTAMPTZ for the column type.  When we can
                        # tolerate a schema change, we should change that type
                        # and remove this line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Pass for_temp_tables=True so the transaction is not opened READ
        # ONLY, which would prevent creating the temporary table.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate with equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan (range) columns participate with overlaps (&&).
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will
        # guarantee that we drop the table at the end of the transaction even
        # if the connection lasts longer, and that's good citizenship when
        # connections may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING on conflict.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

323 

324 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        # Convert a Timespan into the psycopg2 range object stored in the
        # database; None passes through as SQL NULL.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        low, high = value._nsec
        assert low >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert high <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # Sentinel min/max nanosecond values become unbounded range ends.
        lower_bound = low if low != converter.min_nsec else None
        upper_bound = high if high != converter.max_nsec else None
        return psycopg2.extras.NumericRange(lower=lower_bound, upper=upper_bound)

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        # Convert a psycopg2 range object read from the database back into a
        # Timespan; SQL NULL passes through as None.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        # Unbounded range ends map back to the sentinel min/max nanoseconds.
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

366 

367 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name of the logical timespan column.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to the fully unbounded range.
        default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        out = {} if result is None else result
        out[key] = extent
        return out

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=sqlalchemy.sql.cast(literal, type_=_RangeTimespanType), name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # A bare column is a single time: this timespan precedes it only
            # when the timespan is bounded above, non-empty, and its upper
            # bound does not exceed that time.
            bounded_above = sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column))
            nonempty = sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column))
            return sqlalchemy.sql.and_(
                bounded_above, nonempty, sqlalchemy.sql.func.upper(self.column) <= other
            )
        # Range vs. range: PostgreSQL's "strictly left of" operator.
        return self.column << other.column

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Mirror of __lt__: bounded below, non-empty, and starting after
            # the given time.
            bounded_below = sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column))
            nonempty = sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column))
            return sqlalchemy.sql.and_(
                bounded_below, nonempty, sqlalchemy.sql.func.lower(self.column) > other
            )
        # Range vs. range: PostgreSQL's "strictly right of" operator.
        return self.column >> other.column

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar time "overlaps" a timespan exactly when it is contained.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        rhs = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(rhs)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # NULL (unbounded) lower bounds are coalesced to 0.
        bound = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # NULL (unbounded) upper bounds are coalesced to 0.
        bound = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)