Coverage for python/lsst/daf/butler/registry/databases/postgresql.py : 27%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from contextlib import contextmanager, closing
26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union
28import psycopg2
29import sqlalchemy.dialects.postgresql
31from ..interfaces import Database
32from ..nameShrinker import NameShrinker
33from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to connect to the database, e.g. one created by
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the connection's ``current_schema()`` is queried and used (which may
        itself be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection is not a psycopg2
        connection, or if the ``btree_gist`` extension is not enabled in the
        database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, engine: sqlalchemy.engine.Engine, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # Raw DBAPI connection; ``get_dsn_parameters`` is only available
            # on psycopg2 connections, so its absence identifies other drivers.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                namespace = connection.execute("SELECT current_schema();").scalar()
            # Exclusion constraints built by _convertExclusionConstraintSpec
            # need btree_gist; fail early rather than at CREATE TABLE time.
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(query).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Shrinks/expands entity names that exceed the dialect's identifier
        # length limit.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        """Create a SQLAlchemy engine for ``uri``.

        ``writeable`` is accepted for interface compatibility but is not used
        here; read-only behavior is enforced per transaction instead.
        """
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(cls, engine: sqlalchemy.engine.Engine, *, origin: int,
                   namespace: Optional[str] = None, writeable: bool = True) -> Database:
        """Construct a `PostgresqlDatabase` from an existing engine."""
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        """Open a transaction via the base class, then configure the session.

        Read-only databases mark the transaction ``READ ONLY``; writeable ones
        set the session time zone to UTC.
        """
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if not self.isWriteable():
                with closing(self._connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            else:
                with closing(self._connection.connection.cursor()) as cursor:
                    # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                    # the column type. When we can tolerate a schema change,
                    # we should change that type and remove this line.
                    cursor.execute("SET TIME ZONE 0")
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        for table in tables:
            self._connection.execute(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE")

    def isWriteable(self) -> bool:
        """Return the ``writeable`` flag given at construction."""
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        """Shrink ``original`` to fit the dialect's identifier length limit."""
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        """Reverse `shrinkDatabaseEntityName`."""
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Plain column names are compared with "=", the timespan column with
        # the range-overlap operator "&&" (the btree_gist extension checked in
        # __init__ is presumably what allows "=" inside a GiST exclusion).
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table. If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        if data:
            query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        else:
            # Every column is part of the primary key, so a conflicting row is
            # already identical to the new one.  SQLAlchemy rejects an empty
            # SET dictionary, so skip primary-key conflicts instead.
            query = query.on_conflict_do_nothing(constraint=table.primary_key)
        self._connection.execute(query, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        return self._connection.execute(query, *rows).rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Values are stored in an ``INT8RANGE`` column holding the timespan's
    nanosecond bounds.  A bound equal to the converter's minimum/maximum
    nanosecond value is stored as an unbounded (NULL) range endpoint.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) into a psycopg2 range.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        # ``import psycopg2`` alone does not load the ``extras`` submodule;
        # import it explicitly instead of relying on another module having
        # done so first.
        from psycopg2.extras import NumericRange

        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        assert value._nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert value._nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # The converter's extreme values are stored as unbounded endpoints.
        lower = None if value._nsec[0] == converter.min_nsec else value._nsec[0]
        upper = None if value._nsec[1] == converter.max_nsec else value._nsec[1]
        return NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a range read from the database back into a `Timespan`.

        Unbounded endpoints map back to the converter's minimum/maximum
        nanosecond values.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = converter.min_nsec if value.lower is None else value.lower
        end_nsec = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for `_RangeTimespanType` columns.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        master, but not in the releases we currently use).  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            "Boolean expression. Returns true if two ranges are not equal"
            if other is None:
                return super().__ne__(other)
            else:
                return self.expr.op("<>", is_comparison=True)(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self.expr.op("@>", is_comparison=True)(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op("<@", is_comparison=True)(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op("&&", is_comparison=True)(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self.expr.op("<<", is_comparison=True)(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self.expr.op(">>", is_comparison=True)(other)

        __rshift__ = strictly_right_of
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """`TimespanDatabaseRepresentation` that keeps a timespan in a single
    PostgreSQL-specific column of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the timespan, reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable columns default to the fully unbounded range '(,)'.
        default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        # A caller-supplied dict is updated in place and returned.
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal_column = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal_column, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(selectable.columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly left of" operator.
            return self.column << other.column
        # Range vs. scalar: require a finite upper bound, a nonempty range,
        # and upper(range) <= scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly right of" operator.
            return self.column >> other.column
        # Range vs. scalar: require a finite lower bound, a nonempty range,
        # and lower(range) > scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        target = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(target)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Exactly one column is produced, labeled only when a name is given.
        yield self.column if name is None else self.column.label(name)