Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 26%
205 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 23:49 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 23:49 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from contextlib import closing, contextmanager
26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union
28import psycopg2
29import sqlalchemy
30import sqlalchemy.dialects.postgresql
32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
33from ..interfaces import Database
34from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to talk to the database, e.g. one returned by
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the result of ``SELECT current_schema()`` on a new connection is
        used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection does not provide
        ``get_dsn_parameters`` (i.e. it is not a psycopg2 connection), or if
        the ``btree_gist`` extension is not present in ``pg_extension``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # ``get_dsn_parameters`` is looked up on the raw DBAPI connection;
            # its absence is reported as an unsupported driver.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrapped in ``sqlalchemy.text`` like the query below; passing
                # a bare string to ``execute`` is deprecated in SQLAlchemy 1.4
                # and rejected in 2.0.
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Take the identifier length limit from the engine directly rather
        # than going through the connection opened above.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is accepted but not used here; the engine's pool is
        # created with ``pool_size=1``.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Thin alternate constructor: forwards everything to ``__init__``.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        """Wrap the base-class transaction with PostgreSQL session setup.

        Parameters
        ----------
        interrupting : `bool`, optional
            Forwarded unchanged to the base-class ``_transaction``.
        savepoint : `bool`, optional
            Forwarded unchanged to the base-class ``_transaction``.
        lock : `~collections.abc.Iterable` [ `sqlalchemy.schema.Table` ], optional
            Forwarded unchanged to the base-class ``_transaction``.
        for_temp_tables : `bool`, optional
            If `True`, do not mark a new transaction read-only even when this
            database is not writeable.  Not forwarded to the base class.

        Yields
        ------
        is_new : `bool`
            First element of the pair yielded by the base-class context
            manager; session setup is only performed when it is `True`.
        connection : `sqlalchemy.engine.Connection`
            Second element of the pair yielded by the base-class context
            manager.
        """
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                read_only = not (self.isWriteable() or for_temp_tables)
                with closing(connection.connection.cursor()) as cursor:
                    if read_only:
                        # PostgreSQL permits writing to temporary tables
                        # inside read-only transactions, but it doesn't permit
                        # creating them.
                        cursor.execute("SET TRANSACTION READ ONLY")
                    # Make timestamps UTC, because we didn't use TIMESTAMPTZ
                    # for the column type.  When we can tolerate a schema
                    # change, we should change that type and remove this
                    # line.
                    cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: Optional[str] = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Enter a transaction flagged ``for_temp_tables`` first so that
        # `_transaction` does not make it read-only.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        # One LOCK statement is issued per table, identified by ``table.key``.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the `NameShrinker` built from the dialect's maximum
        # identifier length.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse delegation of `shrinkDatabaseEntityName`.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # ``table`` and ``metadata`` are not used by this implementation.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain column names take part in the constraint via "=".
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # The timespan column takes part via the "&&" operator.
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Insert ``rows``; on a primary-key conflict, overwrite the existing
        # row's non-key columns.  Does nothing when no rows are given.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Only conflicts on the primary key constraint are skipped.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            # No conflict target is given to ON CONFLICT DO NOTHING.
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A PostgreSQL-only column type holding a `Timespan` in a single
    ``INT8RANGE`` column.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    # NOTE(review): the methods below use ``psycopg2.extras`` while only
    # ``import psycopg2`` is visible at the top of this file; presumably the
    # submodule has been imported elsewhere by the time they run -- confirm.

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` into the range object passed to the driver.

        `None` is passed through, an empty timespan becomes an empty range,
        and a bound equal to the converter's ``min_nsec`` / ``max_nsec`` is
        sent as an unbounded (`None`) end of the range.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        assert value._nsec[0] >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert value._nsec[1] <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        range_lower = value._nsec[0]
        if range_lower == limits.min_nsec:
            range_lower = None
        range_upper = value._nsec[1]
        if range_upper == limits.max_nsec:
            range_upper = None
        return psycopg2.extras.NumericRange(lower=range_lower, upper=range_upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        """Convert a range object returned by the driver into a `Timespan`.

        `None` is passed through, an empty range becomes an empty timespan,
        and an unbounded (`None`) end of the range is replaced by the
        converter's ``min_nsec`` / ``max_nsec``.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        nsec_pair = (
            value.lower if value.lower is not None else limits.min_nsec,
            value.upper if value.upper is not None else limits.max_nsec,
        )
        return Timespan(begin=None, end=None, _nsec=nsec_pair)

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        main, but not in the releases we currently use).  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression.  Returns true if two ranges are not
            equal.

            Comparison against `None` is delegated to the parent class.
            """
            if other is not None:
                not_equal = self.expr.op("<>", is_comparison=True)
                return not_equal(other)
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression.  Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.  Extra keyword arguments are accepted and ignored.
            """
            contains_op = self.expr.op("@>", is_comparison=True)
            return contains_op(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is contained
            within the right hand operand.
            """
            contained_by_op = self.expr.op("<@", is_comparison=True)
            return contained_by_op(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            overlaps_op = self.expr.op("&&", is_comparison=True)
            return overlaps_op(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is strictly
            left of the right hand operand.
            """
            left_of_op = self.expr.op("<<", is_comparison=True)
            return left_of_op(other)

        # ``column << other`` is an alias for `strictly_left_of`.
        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is strictly
            right of the right hand operand.
            """
            right_of_op = self.expr.op(">>", is_comparison=True)
            return right_of_op(other)

        # ``column >> other`` is an alias for `strictly_right_of`.
        __rshift__ = strictly_right_of
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        if nullable:
            field_default = None
        else:
            # Non-nullable columns default to the fully unbounded range.
            field_default = sqlalchemy.sql.text("'(,)'::int8range")
        field = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=field_default,
            **kwargs,
        )
        return (field,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        # A caller-supplied ``result`` is modified in place and returned;
        # otherwise a fresh dictionary is created.
        target = {} if result is None else result
        key = cls.NAME if name is None else name
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        bound_literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=bound_literal, name=cls.NAME)

    @classmethod
    def fromSelectable(
        cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(selectable.columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        null_check = self.column.is_(None)
        return null_check

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        empty_check = sqlalchemy.sql.func.isempty(self.column)
        return empty_check

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: use the "<<" operator.
            return self.column << other.column
        # Range-vs-scalar: the range must have a finite upper bound, must not
        # be empty, and that upper bound must not exceed ``other``.
        sql = sqlalchemy.sql
        has_upper_bound = sql.not_(sql.func.upper_inf(self.column))
        is_not_empty = sql.not_(sql.func.isempty(self.column))
        ends_at_or_before = sql.func.upper(self.column) <= other
        return sql.and_(has_upper_bound, is_not_empty, ends_at_or_before)

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: use the ">>" operator.
            return self.column >> other.column
        # Range-vs-scalar: the range must have a finite lower bound, must not
        # be empty, and that lower bound must be greater than ``other``.
        sql = sqlalchemy.sql
        has_lower_bound = sql.not_(sql.func.lower_inf(self.column))
        is_not_empty = sql.not_(sql.func.isempty(self.column))
        starts_after = sql.func.lower(self.column) > other
        return sql.and_(has_lower_bound, is_not_empty, starts_after)

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        other_column = other.column
        return self.column.overlaps(other_column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; anything else is
        # handed to the column's ``contains`` as-is.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # COALESCE substitutes literal 0 when ``lower(column)`` is NULL.
        range_lower = sqlalchemy.sql.func.lower(self.column)
        fallback = sqlalchemy.sql.literal(0)
        return sqlalchemy.sql.functions.coalesce(range_lower, fallback)

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # COALESCE substitutes literal 0 when ``upper(column)`` is NULL.
        range_upper = sqlalchemy.sql.func.upper(self.column)
        fallback = sqlalchemy.sql.literal(0)
        return sqlalchemy.sql.functions.coalesce(range_upper, fallback)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Yields exactly one column, relabelled when ``name`` is given.
        yield self.column if name is None else self.column.label(name)