# Coverage report header (coverage.py v7.2.7, created at 2023-06-23 09:29 +0000):
# python/lsst/daf/butler/registry/databases/postgresql.py — 24% of 198 statements covered.
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from collections.abc import Iterable, Iterator, Mapping
26from contextlib import closing, contextmanager
27from typing import Any
29import psycopg2
30import sqlalchemy
31import sqlalchemy.dialects.postgresql
32from sqlalchemy import sql
34from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
35from ...core.named import NamedValueAbstractSet
36from ..interfaces import Database
37from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for all connections to this database.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the DBAPI driver is not psycopg2, or if the ``btree_gist``
        extension is not enabled on the target database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                # Only psycopg2 exposes get_dsn_parameters; any other driver
                # is unsupported.
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the server-side default schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Extensions are enabled per-database, so check up front rather
            # than failing obscurely later when a range constraint is made.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE(review): ``writeable`` is accepted for interface compatibility
        # but is not used here; read-only enforcement happens per-transaction
        # in `_transaction`.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        # Docstring inherited.
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction flagged for temp tables, so creation is
        # allowed even when the overall database handle is read-only.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate with equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan (range) columns participate with overlap.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING on conflict.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        # Pure passthrough to the base class; kept as an explicit override
        # point for PostgreSQL-specific behavior.
        return super().constant_rows(fields, *rows, name=name)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        """Convert a `Timespan` into the psycopg2 range value stored in the
        INT8RANGE column, mapping the converter's sentinel min/max bounds to
        unbounded range endpoints.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        begin = value._nsec[0]
        end = value._nsec[1]
        assert begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        return psycopg2.extras.NumericRange(
            lower=(begin if begin != converter.min_nsec else None),
            upper=(end if end != converter.max_nsec else None),
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        """Inverse of `process_bind_param`: unbounded range endpoints become
        the converter's sentinel min/max nanosecond values.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the timespan column.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        # Non-nullable columns default to the fully unbounded range.
        field = ddl.FieldSpec(
            cls.NAME if name is None else name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
            **kwargs,
        )
        return (field,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        target = {} if result is None else result
        target[cls.NAME if name is None else name] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            column = sqlalchemy.sql.null()
        else:
            column = sqlalchemy.sql.cast(
                sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType
            )
        return cls(column=column, name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Scalar comparand: true when this range is non-empty, bounded
            # above, and ends at or before the scalar.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        # Range comparand: PostgreSQL "strictly left of" operator.
        return self.column << other.column

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Scalar comparand: true when this range is non-empty, bounded
            # below, and begins after the scalar.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        # Range comparand: PostgreSQL "strictly right of" operator.
        return self.column >> other.column

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar "overlaps" a range exactly when the range contains it.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unbounded/empty lower endpoints are reported as 0 rather than NULL.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unbounded/empty upper endpoints are reported as 0 rather than NULL.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        column = self.column if name is None else self.column.label(name)
        return (column,)