Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 31%
198 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-05 11:05 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from ... import ddl, time_utils
31__all__ = ["PostgresqlDatabase"]
33from collections.abc import Iterable, Iterator, Mapping
34from contextlib import closing, contextmanager
35from typing import Any
37import psycopg2
38import sqlalchemy
39import sqlalchemy.dialects.postgresql
40from sqlalchemy import sql
42from ..._named import NamedValueAbstractSet
43from ..._timespan import Timespan, TimespanDatabaseRepresentation
44from ..interfaces import Database
45from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection, created by a previous call to
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                # Only psycopg2 exposes get_dsn_parameters; other drivers are
                # rejected explicitly rather than failing later in odd ways.
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the server-side default schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # The btree_gist extension is required for the exclusion
            # constraints built by _convertExclusionConstraintSpec; check it
            # up front so failures are clear and early.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Used to keep generated identifiers within PostgreSQL's limit
        # (normally 63 bytes).
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # ``writeable`` is part of the interface but is enforced
        # per-transaction (see _transaction), not at engine creation.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction flagged for_temp_tables so a read-only
        # database still permits the CREATE TEMPORARY TABLE.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate with equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan columns participate with range overlap.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING on conflict.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        # Translate a Timespan into psycopg2's range type for the INT8RANGE
        # column; `None` maps to SQL NULL.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        tc = time_utils.TimeConverter()
        nsec_begin, nsec_end = value._nsec
        assert nsec_begin >= tc.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec_end <= tc.max_nsec, "Guaranteed by Timespan.__init__."
        # The converter's min/max sentinels become unbounded range endpoints.
        return psycopg2.extras.NumericRange(
            lower=(nsec_begin if nsec_begin != tc.min_nsec else None),
            upper=(nsec_end if nsec_end != tc.max_nsec else None),
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        # Translate a database range value back into a Timespan.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        tc = time_utils.TimeConverter()
        # Unbounded endpoints come back as None; restore the nsec sentinels.
        begin_nsec = tc.min_nsec if value.lower is None else value.lower
        end_nsec = tc.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the column.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to an unbounded range instead of NULL.
        range_default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=range_default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        if result is None:
            result = {}
        result[cls.NAME if name is None else name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        # The explicit CAST ensures the literal is rendered as a range even
        # where SQLAlchemy cannot infer the type from context.
        typed_literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(
            column=sqlalchemy.sql.cast(typed_literal, type_=_RangeTimespanType),
            name=cls.NAME,
        )

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Delegates to PostgreSQL's isempty() range function.
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: PostgreSQL's "strictly left of" operator.
            return self.column << other.column
        # Range-vs-scalar: the range must be bounded above, non-empty, and
        # end at or before the scalar.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: PostgreSQL's "strictly right of" operator.
            return self.column >> other.column
        # Range-vs-scalar: the range must be bounded below, non-empty, and
        # start strictly after the scalar.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar "overlaps" the range iff the range contains it.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Range containment works for both ranges and scalar operands.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unbounded/empty lower endpoints are reported as 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unbounded/empty upper endpoints are reported as 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)