Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 33%
222 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-26 02:47 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from sqlalchemy.sql.expression import ColumnElement as ColumnElement
31from ... import ddl, time_utils
33__all__ = ["PostgresqlDatabase"]
35import re
36from collections.abc import Callable, Iterable, Iterator, Mapping
37from contextlib import closing, contextmanager
38from typing import Any
40import psycopg2
41import sqlalchemy
42import sqlalchemy.dialects.postgresql
43from sqlalchemy import sql
45from ..._named import NamedValueAbstractSet
46from ..._timespan import Timespan
47from ...name_shrinker import NameShrinker
48from ...timespan_database_representation import TimespanDatabaseRepresentation
49from ..interfaces import Database
# Extracts the "major.minor" numeric version from a PostgreSQL
# ``SHOW server_version`` string (which may carry a vendor suffix,
# e.g. "14.5 (Debian 14.5-1)").
_SERVER_VERSION_REGEX = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)")
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the connection's current schema when no
                # namespace was given explicitly.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Timespan columns use exclusion constraints that need the
            # btree_gist extension; check up front so we fail with a clear
            # message instead of an obscure DDL error later.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
            raw_pg_version = connection.execute(sqlalchemy.text("SHOW server_version")).scalar()
            if raw_pg_version is not None and (m := _SERVER_VERSION_REGEX.search(raw_pg_version)):
                pg_version = (int(m.group("major")), int(m.group("minor")))
            else:
                raise RuntimeError("Failed to get PostgreSQL server version.")
            self._init(
                engine=engine,
                origin=origin,
                namespace=namespace,
                writeable=writeable,
                dbname=dsn.get("dbname"),
                metadata=None,
                pg_version=pg_version,
            )

    def _init(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
        dbname: str | None,
        metadata: sqlalchemy.schema.MetaData | None,
        pg_version: tuple[int, int],
    ) -> None:
        # Initialization logic shared between ``__init__`` and ``clone``.
        # ``dbname`` may be None because psycopg2's DSN parameters are not
        # guaranteed to include it.
        super().__init__(origin=origin, engine=engine, namespace=namespace, metadata=metadata)
        self._writeable = writeable
        self.dbname = dbname
        self._pg_version = pg_version
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    def clone(self) -> PostgresqlDatabase:
        # Create a new instance sharing this one's engine and configuration,
        # bypassing __init__ (and hence its server round-trips) via __new__.
        clone = self.__new__(type(self))
        clone._init(
            origin=self.origin,
            engine=self._engine,
            namespace=self.namespace,
            writeable=self._writeable,
            dbname=self.dbname,
            metadata=self._metadata,
            pg_version=self._pg_version,
        )
        return clone

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        return sqlalchemy.engine.create_engine(
            uri,
            # Prevent stale database connections from throwing exceptions, at
            # the expense of a round trip to the database server each time we
            # check out a session.  Many services using the Butler operate in
            # networks where connections are dropped when idle for some time.
            pool_pre_ping=True,
            # This engine and database connection pool can be shared between
            # multiple Butler instances created via Butler.clone() or
            # LabeledButlerFactory, and typically these will be used from
            # multiple threads simultaneously.  So we need to configure
            # SQLAlchemy to pool connections for multi-threaded usage.
            #
            # This is not the maximum number of active connections --
            # SQLAlchemy allows some additional overflow configured via the
            # max_overflow parameter.  pool_size is only the maximum number
            # saved in the pool during periods of lower concurrency.
            #
            # This specific value for pool size was chosen somewhat arbitrarily
            # -- there has not been any formal testing done to profile database
            # concurrency.  The value chosen may be somewhat lower than is
            # optimal for service use cases.  Some considerations:
            #
            # 1. Connections are only created as they are needed, so in typical
            #    single-threaded Butler use only one connection will ever be
            #    created.  Services with low peak concurrency may never create
            #    this many connections.
            # 2. Most services using the Butler (including Butler
            #    server) are using FastAPI, which uses a thread pool of 40 by
            #    default.  So when running at max concurrency we may have:
            #    * 10 connections checked out from the pool
            #    * 10 "overflow" connections re-created each time they are
            #      used.
            #    * 20 threads queued up, waiting for a connection, and
            #      potentially timing out if the other threads don't release
            #      their connections in a timely manner.
            # 3. The main Butler databases at SLAC are run behind pgbouncer,
            #    so we can support a larger number of simultaneous connections
            #    than if we were connecting directly to Postgres.
            #
            # See
            # https://docs.sqlalchemy.org/en/20/core/pooling.html#sqlalchemy.pool.QueuePool.__init__
            # for more information on the behavior of this parameter.
            pool_size=10,
            # In combination with pool_pre_ping, prevent SQLAlchemy from
            # unnecessarily reviving pooled connections that have gone stale.
            # Setting this to true makes it always re-use the most recent
            # known-good connection when possible, instead of cycling to other
            # connections in the pool that we may no longer need.
            pool_use_lifo=True,
        )

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        # Table names here come from our own metadata, not user input, so
        # direct interpolation into the LOCK statement is safe.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate with equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan (range) columns participate with overlap (&&).
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

    @property
    def has_distinct_on(self) -> bool:
        # Docstring inherited.
        return True

    @property
    def has_any_aggregate(self) -> bool:
        # Docstring inherited.
        # any_value() is a PostgreSQL 16+ aggregate.
        return self._pg_version >= (16, 0)

    def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalchemy.ColumnElement[Any]:
        # Docstring inherited.
        return sqlalchemy.func.any_value(column)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    # Store the timespan as a PostgreSQL int8range of nanoseconds.
    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        # Convert a Timespan (or None) into the psycopg2 range object that
        # backs the INT8RANGE column.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        nsec_begin, nsec_end = value.nsec
        assert nsec_begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec_end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # Sentinel min/max nanosecond values map to unbounded range endpoints.
        range_lower = nsec_begin if nsec_begin != converter.min_nsec else None
        range_upper = nsec_end if nsec_end != converter.max_nsec else None
        return psycopg2.extras.NumericRange(lower=range_lower, upper=range_upper)

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        # Convert a psycopg2 range read from the database back into a
        # Timespan, mapping unbounded endpoints to the sentinel values.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        nsec = (
            converter.min_nsec if value.lower is None else value.lower,
            converter.max_nsec if value.upper is None else value.upper,
        )
        return Timespan(begin=None, end=None, _nsec=nsec)
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the column.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        # Non-nullable columns default to the fully unbounded range.
        spec = ddl.FieldSpec(
            cls.NAME if name is None else name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        out = {} if result is None else result
        out[key] = extent
        return out

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        bound_literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(
            column=sqlalchemy.sql.cast(bound_literal, type_=_RangeTimespanType),
            name=cls.NAME,
        )

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly left of" operator.
            return self.column << other.column
        # Range vs. scalar: true only for a non-empty range whose finite
        # (exclusive) upper bound is at or below the scalar.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly right of" operator.
            return self.column >> other.column
        # Range vs. scalar: true only for a non-empty range whose finite
        # lower bound is above the scalar.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar "overlaps" a range exactly when the range contains it.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # An unbounded (NULL) lower endpoint is reported as 0.
        endpoint = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(endpoint, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # An unbounded (NULL) upper endpoint is reported as 0.
        endpoint = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(endpoint, sqlalchemy.sql.literal(0))

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)

    def apply_any_aggregate(
        self, func: Callable[[ColumnElement[Any]], ColumnElement[Any]]
    ) -> TimespanDatabaseRepresentation:
        # Docstring inherited.
        return _RangeTimespanRepresentation(func(self.column), self._name)