Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 31%
197 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-10-02 07:59 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ["PostgresqlDatabase"]
31from collections.abc import Iterable, Iterator, Mapping
32from contextlib import closing, contextmanager
33from typing import Any
35import psycopg2
36import sqlalchemy
37import sqlalchemy.dialects.postgresql
38from sqlalchemy import sql
40from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
41from ...core.named import NamedValueAbstractSet
42from ..interfaces import Database
43from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        An existing engine created by a previous call to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with. If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy. Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        # Open a short-lived connection just to validate the driver, resolve
        # the default schema, and check for the btree_gist extension.
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                # get_dsn_parameters is psycopg2-specific; its absence means
                # some other DBAPI driver is in use, which we do not support.
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the server-side default schema (usually
                # "public") when no namespace was given explicitly.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # btree_gist is required for the exclusion constraints built in
            # _convertExclusionConstraintSpec; fail fast if it is missing.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Used to keep generated identifiers (e.g. constraint names) within
        # PostgreSQL's identifier-length limit.
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        """Create a SQLAlchemy engine for the given PostgreSQL URI.

        NOTE(review): ``writeable`` is accepted for interface compatibility
        but is not used when creating the engine here; writeability is
        enforced per-transaction in `_transaction` instead.
        """
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        """Construct a `PostgresqlDatabase` from an existing engine."""
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0). But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently. And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type. When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # for_temp_tables=True is required so read-only databases can still
        # open a transaction that creates temporary tables (see _transaction).
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        """Return `True` if this database instance allows writes."""
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        """Shrink ``original`` to fit within PostgreSQL's identifier limit."""
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        """Recover the original identifier from a shrunk one."""
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Build an EXCLUDE constraint: equality on plain columns and range
        # overlap (&&) on the timespan column.  This is what requires the
        # btree_gist extension checked in __init__.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first. But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        """Insert ``rows``, replacing any existing rows with the same
        primary key (UPSERT).
        """
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table. If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Ignore conflicts on the primary key only; conflicts on other
            # unique constraints will still raise.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        # The base-class implementation is sufficient for PostgreSQL.
        return super().constant_rows(fields, *rows, name=name)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        """Translate a `Timespan` into a psycopg2 range for binding."""
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        clock = time_utils.TimeConverter()
        begin_nsec, end_nsec = value._nsec
        assert begin_nsec >= clock.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= clock.max_nsec, "Guaranteed by Timespan.__init__."
        # Unbounded edges are represented by None in NumericRange.
        return psycopg2.extras.NumericRange(
            lower=(begin_nsec if begin_nsec != clock.min_nsec else None),
            upper=(end_nsec if end_nsec != clock.max_nsec else None),
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        """Translate a psycopg2 range from a query result into a `Timespan`."""
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        clock = time_utils.TimeConverter()
        # Map unbounded edges back onto the converter's sentinel limits.
        begin_nsec = value.lower if value.lower is not None else clock.min_nsec
        end_nsec = value.upper if value.upper is not None else clock.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to the unbounded range, so rows can be
        # inserted without an explicit value.
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        out = {} if result is None else result
        out[cls.NAME if name is None else name] = extent
        return out

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        # Bind as a literal and cast so the server sees a proper int8range.
        literal_column = sqlalchemy.sql.cast(
            sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType
        )
        return cls(column=literal_column, name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # A bare column is a point in time: the whole range must end at or
            # before it, and infinite/empty ranges never compare as less-than.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        # PostgreSQL's "strictly left of" range operator.
        return self.column << other.column

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Mirror image of __lt__: the whole range must begin after the
            # point in time.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        # PostgreSQL's "strictly right of" range operator.
        return self.column >> other.column

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A bare column is treated as a point in time: overlap == containment.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # PostgreSQL reports an unbounded lower edge as NULL; map it to 0.
        raw = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(raw, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # PostgreSQL reports an unbounded upper edge as NULL; map it to 0.
        raw = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(raw, sqlalchemy.sql.literal(0))

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)