Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 31%
204 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-25 10:48 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl, time_utils
31__all__ = ["PostgresqlDatabase"]
33from collections.abc import Iterable, Iterator, Mapping
34from contextlib import closing, contextmanager
35from typing import Any
37import psycopg2
38import sqlalchemy
39import sqlalchemy.dialects.postgresql
40from sqlalchemy import sql
42from ..._named import NamedValueAbstractSet
43from ..._timespan import Timespan
44from ...timespan_database_representation import TimespanDatabaseRepresentation
45from ..interfaces import Database
46from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised at construction if the connection is not backed by the
        psycopg2 driver, or if the ``btree_gist`` extension is not enabled
        on the database being connected to.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`)
    requires the ``btree_gist`` PostgreSQL server extension to be installed
    and enabled on the database being connected to; this is checked at
    connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                # get_dsn_parameters is psycopg2-specific; its absence (or a
                # KeyError from a look-alike) means some other DBAPI driver
                # is in use, which this class does not support.
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the connection's current (default) schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Fail fast if btree_gist is not enabled on this database; the
            # exclusion constraints built by _convertExclusionConstraintSpec
            # depend on it.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                # NOTE(review): the last fragment below starts with a space
                # while the previous fragment also ends with one, so the
                # message renders with a double space before "initialized";
                # fixing it would change a runtime string.
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self._init(
            engine=engine,
            origin=origin,
            namespace=namespace,
            writeable=writeable,
            dbname=dsn.get("dbname"),
            metadata=None,
        )

    def _init(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
        dbname: str,
        metadata: sqlalchemy.schema.MetaData | None,
    ) -> None:
        # Initialization logic shared between ``__init__`` and ``clone``.
        super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
        self._writeable = writeable
        self.dbname = dbname
        # Shrinker maps long generated identifiers (e.g. constraint names)
        # down to the dialect's maximum identifier length.
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    def clone(self) -> PostgresqlDatabase:
        """Return a new `PostgresqlDatabase` with the same engine and
        configuration as this one.
        """
        # __new__ deliberately bypasses __init__ so the clone does not
        # re-run the connection-time driver/extension checks.
        clone = self.__new__(type(self))
        clone._init(
            origin=self.origin,
            engine=self._engine,
            namespace=self.namespace,
            writeable=self._writeable,
            dbname=self.dbname,
            metadata=self._metadata,
        )
        return clone

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # ``writeable`` is accepted for interface compatibility but not used
        # here; writeability is enforced by this class (see isWriteable and
        # _transaction), not by the engine.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        # Docstring inherited.
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong,
                # we should still probably be okay, since all clients should
                # be Butler clients, and they'll all be setting the same
                # thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # for_temp_tables=True keeps the transaction writeable enough to
        # CREATE the temp table even on a read-only database.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        # Human-readable identification of the database and schema.
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Scalar columns participate with "=", the timespan column with the
        # range-overlap operator "&&"; together these require btree_gist.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            # Constraint name may exceed the identifier limit; shrink it.
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will
        # guarantee that we drop the table at the end of the transaction even
        # if the connection lasts longer, and that's good citizenship when
        # connections may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Only primary-key conflicts are ignored; conflicts on other
            # unique constraints will still raise.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        # Convert a Timespan into the psycopg2 range object bound to
        # INT8RANGE; NULL and empty are passed through as-is.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        begin, end = value._nsec
        assert begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # The converter's sentinel min/max nanosecond values map to
        # unbounded range endpoints (None).
        return psycopg2.extras.NumericRange(
            lower=(begin if begin != converter.min_nsec else None),
            upper=(end if end != converter.max_nsec else None),
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        # Inverse of process_bind_param: rebuild a Timespan from the
        # psycopg2 range object, mapping unbounded endpoints back to the
        # converter's sentinel min/max nanosecond values.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the single field backing this timespan.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (
            ddl.FieldSpec(
                name,
                dtype=_RangeTimespanType,
                nullable=nullable,
                # Non-nullable columns default to the unbounded int8range,
                # i.e. a timespan covering all time.
                default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
                **kwargs,
            ),
        )

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        # Single-column representation, so exactly one field name.
        if name is None:
            name = cls.NAME
        return (name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        if result is None:
            result = {}
        result[name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return mapping[name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            # NULL timespan maps to SQL NULL.
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        # The CAST makes the literal's type explicit in the emitted SQL.
        return cls(
            column=sqlalchemy.sql.cast(
                sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType
            ),
            name=cls.NAME,
        )

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return cls(columns[name], name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Uses PostgreSQL's isempty() range function.
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Comparison against a single time point: the timespan is before
            # ``other`` iff the range is non-empty, has a finite upper bound,
            # and that upper bound is <= other (assumes the usual
            # upper-exclusive range bounds — see _RangeTimespanType).
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        else:
            # "<<" is PostgreSQL's "strictly left of" range operator.
            return self.column << other.column

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Mirror of __lt__: the timespan is after ``other`` iff the range
            # is non-empty, has a finite lower bound, and that (inclusive)
            # lower bound is > other.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        else:
            # ">>" is PostgreSQL's "strictly right of" range operator.
            return self.column >> other.column

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, _RangeTimespanRepresentation):
            # A scalar operand is treated as a time point; overlap with a
            # point is containment.
            return self.contains(other)
        return self.column.overlaps(other.column)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.contains(other.column)
        else:
            return self.column.contains(other)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # PostgreSQL's lower() returns NULL for an unbounded lower bound;
        # coalesce to 0 so callers always get an integer.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # As in lower(): an unbounded upper bound yields NULL, coalesced to 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        if name is None:
            return (self.column,)
        else:
            return (self.column.label(name),)