Coverage for python/lsst/daf/butler/registry/databases/postgresql.py : 27%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from contextlib import contextmanager, closing
26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union
28import psycopg2
29import sqlalchemy.dialects.postgresql
31from ..interfaces import Database
32from ..nameShrinker import NameShrinker
33from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine for the target database, e.g. one created by a previous call
        to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection does not provide
        ``get_dsn_parameters`` (only the psycopg2 driver is supported), or if
        the ``btree_gist`` extension is not enabled in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, engine: sqlalchemy.engine.Engine, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        # A short-lived connection is used only to validate the driver, look
        # up the default schema, and check for the required extension.
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                namespace = connection.execute("SELECT current_schema();").scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(query).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # The dialect belongs to the engine, so there is no need to reach it
        # through the (already released) validation connection.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is accepted for interface compatibility but is not
        # used when creating the engine; read-only behavior is applied per
        # transaction instead (see `transaction`).
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(cls, engine: sqlalchemy.engine.Engine, *, origin: int,
                   namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Thin alternate constructor: forwards every argument to ``__init__``.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        # Wraps the base-class transaction and issues one session-setup
        # statement on the raw DBAPI cursor before handing control to the
        # caller.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if not self.isWriteable():
                statement = "SET TRANSACTION READ ONLY"
            else:
                # Make timestamps UTC, because we didn't use TIMESTAMPTZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this line.
                statement = "SET TIME ZONE 0"
            with closing(self._connection.connection.cursor()) as cursor:
                cursor.execute(statement)
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        # One LOCK statement is issued per table, identified by ``table.key``.
        for table in tables:
            self._connection.execute(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE")

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the NameShrinker configured with the dialect's maximum
        # identifier length.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same NameShrinker.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # ``table`` and ``metadata`` are not needed by this implementation.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must be equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan columns conflict when their ranges overlap.
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Inserts ``rows``, overwriting the non-key columns of any existing
        # row that has the same primary key.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        self._connection.execute(query, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        # The value returned is the ``rowcount`` reported for the statement.
        return self._connection.execute(query, *rows).rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A PostgreSQL-only column type that stores a `Timespan` in a single
    ``INT8RANGE`` column.

    Storing the timespan as a native range is intended to make PostgreSQL's
    built-in range operators available, along with the indexes and EXCLUSION
    table constraints that are built on top of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        # Python -> database: `None` passes through, an empty timespan maps
        # to an empty range, and bounds sitting at the converter's min/max
        # nanosecond limits are sent as unbounded (`None`).
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        nsec = value._nsec
        assert nsec[0] >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec[1] <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        lower = nsec[0] if nsec[0] != limits.min_nsec else None
        upper = nsec[1] if nsec[1] != limits.max_nsec else None
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        # Database -> Python: the inverse of `process_bind_param`; unbounded
        # range ends are replaced by the converter's min/max nanoseconds.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        if value.lower is None:
            begin_nsec = limits.min_nsec
        else:
            begin_nsec = value.lower
        if value.upper is None:
            end_nsec = limits.max_nsec
        else:
            end_nsec = value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        This nested class exists only to work around a bug reported upstream
        as https://github.com/sqlalchemy/sqlalchemy/issues/5476 (fixed on
        master, but not in the releases we currently use).  It is a limited
        copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, with
        ``is_comparison=True`` added to every call.
        """

        def _range_op(self, symbol: str, other: Any) -> Any:
            # Shared implementation: apply a custom SQL operator flagged as a
            # comparison to this expression and ``other``.
            return self.expr.op(symbol, is_comparison=True)(other)

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal"""
            if other is not None:
                return self._range_op("<>", other)
            # Comparison against `None` is left to the base comparator.
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self._range_op("@>", other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self._range_op("<@", other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self._range_op("&&", other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self._range_op("<<", other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self._range_op(">>", other)

        __rshift__ = strictly_right_of
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that keeps the whole timespan in
    one PostgreSQL-specific field of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        if nullable:
            default = None
        else:
            # Non-nullable columns default to the fully unbounded range.
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name, dtype=_RangeTimespanType, nullable=nullable,
            default=default,
            **kwargs
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(selectable.columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: use the range "strictly left of" operator.
            return self.column << other.column
        # Range vs. scalar column: the range must have a finite upper bound,
        # be non-empty, and that bound must not exceed the scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: use the range "strictly right of" operator.
            return self.column >> other.column
        # Range vs. scalar column: the range must have a finite lower bound,
        # be non-empty, and that bound must be greater than the scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; anything else is
        # passed to the column's ``contains`` unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Exactly one column is yielded, relabeled only when a name is given.
        yield self.column if name is None else self.column.label(name)