Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 31%
199 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:43 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:43 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl, time_utils
# Public names exported by ``from ... import *`` for this module.
__all__ = ["PostgresqlDatabase"]
from collections.abc import Iterable, Iterator, Mapping
from contextlib import closing, contextmanager
from typing import Any

import psycopg2
import psycopg2.extras
import sqlalchemy
import sqlalchemy.dialects.postgresql
from sqlalchemy import sql

from ..._named import NamedValueAbstractSet
from ..._timespan import Timespan
from ...timespan_database_representation import TimespanDatabaseRepresentation
from ..interfaces import Database
from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with. If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised at construction if the connection does not expose psycopg2's
        DSN API, or if the ``btree_gist`` extension is not installed in the
        database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy. Running the tests for this class requires the
    ``testing.postgresql`` package to be installed, which we assume indicates
    that a PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                # A DBAPI connection without ``get_dsn_parameters`` is taken
                # to mean the driver is not psycopg2.
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the connection's current schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        """Create a SQLAlchemy engine for ``uri`` with a single-connection
        pool.

        ``writeable`` is accepted for interface compatibility and is not used
        here; read-only behavior is applied per transaction by instances of
        this class (see `_transaction`).
        """
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        """Construct an instance from an existing engine; all arguments are
        forwarded to the constructor.
        """
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        # Extends the base implementation: when a new transaction is opened,
        # apply PostgreSQL session/transaction settings before yielding.
        # ``for_temp_tables=True`` suppresses READ ONLY so that temporary
        # tables can be created by read-only clients.
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0). But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently. And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                with closing(connection.connection.cursor()) as cursor:
                    if not (self.isWriteable() or for_temp_tables):
                        # PostgreSQL permits writing to temporary tables
                        # inside read-only transactions, but it doesn't
                        # permit creating them.
                        cursor.execute("SET TRANSACTION READ ONLY")
                    # Make timestamps UTC, because we didn't use TIMESTAMPTZ
                    # for the column type. When we can tolerate a schema
                    # change, we should change that type and remove this
                    # line.
                    cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Open the transaction with for_temp_tables=True so that it is not
        # marked READ ONLY, which would forbid creating the temporary table.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        # Let the dialect render the (possibly schema-qualified) table name so
        # that names requiring quoting (mixed case, reserved words) are
        # locked correctly; plain lowercase names render exactly as before.
        preparer = connection.dialect.identifier_preparer
        for table in tables:
            qualified_name = preparer.format_table(table)
            connection.execute(sqlalchemy.text(f"LOCK TABLE {qualified_name} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        """Return the ``writeable`` flag given at construction."""
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        """Shorten ``original`` to fit the dialect's identifier length limit."""
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        """Invert `shrinkDatabaseEntityName`."""
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Plain column names are compared with "=" and the timespan column
        # with the range-overlap operator "&&". The constraint name is built
        # from "excl" plus the participating column names, then shrunk to the
        # identifier length limit.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first. But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        """Insert ``rows`` into ``table``; for rows whose primary key already
        exists, overwrite every non-primary-key column instead.

        Implemented with PostgreSQL's ``INSERT ... ON CONFLICT DO UPDATE``.
        Does nothing if no rows are given.
        """
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all non-key columns from the special
        # `excluded` pseudo-table (the row proposed for insertion). If some
        # column does not appear in the INSERT list, it is overwritten with
        # the value PostgreSQL proposed for it (its default, or NULL).
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but skips conflicting rows. With
        # ``primary_key_only`` only primary-key conflicts are ignored; other
        # constraint violations still raise.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Timespans are stored as ``INT8RANGE`` values built from the timespan's
    nanosecond bounds. The minimum and maximum representable nanosecond
    values are stored as unbounded range ends, and empty timespans as empty
    ranges.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        """Convert a `Timespan` to a psycopg2 range for binding.

        Parameters
        ----------
        value : `Timespan` or `None`
            Timespan to convert.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (not used).

        Returns
        -------
        range : `psycopg2.extras.NumericRange` or `None`
            `None` for `None`, an empty range for an empty timespan, otherwise
            a range over the timespan's nanosecond bounds with the minimum /
            maximum representable values replaced by unbounded (`None`) ends.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        assert value._nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert value._nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # Extreme representable values become unbounded range ends.
        lower = None if value._nsec[0] == converter.min_nsec else value._nsec[0]
        upper = None if value._nsec[1] == converter.max_nsec else value._nsec[1]
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        """Convert a range returned by the database back into a `Timespan`.

        Inverse of `process_bind_param`: unbounded ends map back to the
        minimum / maximum representable nanosecond values and empty ranges
        to an empty timespan. Only the ``isempty``, ``lower`` and ``upper``
        attributes of ``value`` are used.

        NOTE(review): depending on the SQLAlchemy version, ``value`` may be a
        ``sqlalchemy.dialects.postgresql.Range`` rather than a psycopg2
        ``NumericRange``; confirm both expose these attributes.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = converter.min_nsec if value.lower is None else value.lower
        end_nsec = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that keeps a timespan in a single
    PostgreSQL-specific column of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name of the timespan field; returned by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to the unbounded range "(,)".
        default_value = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default_value,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        target = {} if result is None else result
        # The single column holds the Timespan itself; conversion happens in
        # _RangeTimespanType.
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is not None:
            bound_literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
            return cls(
                column=sqlalchemy.sql.cast(bound_literal, type_=_RangeTimespanType),
                name=cls.NAME,
            )
        return cls(column=sqlalchemy.sql.null(), name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL "strictly left of" operator.
            return self.column << other.column
        rng = self.column
        # Range vs. point: a bounded, non-empty range whose (exclusive)
        # upper end is at or before the point.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(rng)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(rng)),
            sqlalchemy.sql.func.upper(rng) <= other,
        )

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL "strictly right of" operator.
            return self.column >> other.column
        rng = self.column
        # Range vs. point: a bounded, non-empty range whose lower end is
        # after the point.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(rng)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(rng)),
            sqlalchemy.sql.func.lower(rng) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A point "overlaps" a range exactly when the range contains it.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        target = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(target)

    @staticmethod
    def _coalesce_to_zero(expression: sqlalchemy.sql.ColumnElement) -> sqlalchemy.sql.ColumnElement:
        """Wrap ``expression`` so that SQL NULL becomes the literal 0."""
        return sqlalchemy.sql.functions.coalesce(expression, sqlalchemy.sql.literal(0))

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self._coalesce_to_zero(sqlalchemy.sql.func.lower(self.column))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # NOTE(review): PostgreSQL's upper() is also NULL for an unbounded
        # upper end, so such ranges map to 0 here as well; confirm that this
        # ordering is intended.
        return self._coalesce_to_zero(sqlalchemy.sql.func.upper(self.column))

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        column = self.column
        return (column,) if name is None else (column.label(name),)