Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 24%
195 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 15:13 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from contextlib import closing, contextmanager
26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union
28import psycopg2
29import sqlalchemy
30import sqlalchemy.dialects.postgresql
32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
33from ...core.named import NamedValueAbstractSet
34from ..interfaces import Database
35from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this database, created by a previous call to
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed an enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # Only the psycopg2 DBAPI provides get_dsn_parameters(); failing
            # here gives a clear error instead of an obscure one later.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap in sqlalchemy.text() for consistency with the
                # btree_gist query below and for compatibility with
                # SQLAlchemy 2.x, which rejects plain-string statements in
                # Connection.execute().
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
            self.namespace = namespace
            self.dbname = dsn.get("dbname")
            self._writeable = writeable
            # Used to map long SQLAlchemy entity names onto names that fit in
            # PostgreSQL's identifier-length limit.
            self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE: `writeable` is accepted for interface compatibility but is not
        # used here; writeability is enforced per-Database instance instead.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: Optional[str] = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction with for_temp_tables=True so table creation is
        # permitted even when the database is read-only (see _transaction).
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain column name: use equality in the exclusion constraint.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan column: use the range-overlap operator.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING so existing
        # rows are left untouched.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: Optional[str] = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        # Convert a Timespan into the psycopg2 range object stored in the DB.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        nsec_low, nsec_high = value._nsec
        assert nsec_low >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec_high <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # Map the sentinel min/max bounds onto open (unbounded) range ends.
        return psycopg2.extras.NumericRange(
            lower=(nsec_low if nsec_low != converter.min_nsec else None),
            upper=(nsec_high if nsec_high != converter.max_nsec else None),
        )

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        # Convert a psycopg2 range object from the DB back into a Timespan.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        # Open (unbounded) range ends map back onto the sentinel min/max.
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self._name = name
        self.column = column

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to an unbounded range on the server.
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        out = {} if result is None else result
        out[key] = extent
        return out

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Optional[Timespan]) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=sqlalchemy.sql.cast(literal, type_=_RangeTimespanType), name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison: PostgreSQL strictly-left-of operator.
            return self.column << other.column
        # Range-vs-scalar: the range must be non-empty, bounded above, and
        # end at or before the scalar instant.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison: PostgreSQL strictly-right-of operator.
            return self.column >> other.column
        # Range-vs-scalar: the range must be non-empty, bounded below, and
        # start strictly after the scalar instant.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar instant "overlaps" a range iff the range contains it.
        return self.contains(other)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # An unbounded (NULL) lower bound is reported as 0.
        bound = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # An unbounded (NULL) upper bound is reported as 0.
        bound = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def flatten(self, name: Optional[str] = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)