Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 30%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from contextlib import contextmanager, closing
26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union
28import psycopg2
29import sqlalchemy
30import sqlalchemy.dialects.postgresql
32from ..interfaces import Database
33from ..nameShrinker import NameShrinker
34from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to connect to the database (for example, one returned by
        `makeEngine`).
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised at construction if the connection is not made through the
        psycopg2 driver, or if the ``btree_gist`` extension is not enabled
        in the target database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, engine: sqlalchemy.engine.Engine, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # Query the raw DBAPI connection for its DSN; a missing
            # ``get_dsn_parameters`` is taken to mean a non-psycopg2 driver.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap raw SQL in ``text()`` (as for the query below);
                # executing plain strings is deprecated in SQLAlchemy 1.4 and
                # removed in 2.0.
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
            self.namespace = namespace
            self.dbname = dsn.get("dbname")
            self._writeable = writeable
            # Shrink long entity names to fit this dialect's identifier limit.
            self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is not needed to build the engine; read-only behavior
        # is handled elsewhere in this class (e.g. `transaction` issues
        # ``SET TRANSACTION READ ONLY``).
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(cls, engine: sqlalchemy.engine.Engine, *, origin: int,
                   namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Thin alternate constructor that forwards to ``__init__``.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        # Extends the base-class transaction with PostgreSQL session settings
        # applied at the start of every transaction.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            with closing(self._session_connection.connection.cursor()) as cursor:
                if not self.isWriteable():
                    # Let the server reject any write attempted on a
                    # read-only database.
                    cursor.execute("SET TRANSACTION READ ONLY")
                else:
                    # Make timestamps UTC, because we didn't use TIMESTAMPTZ
                    # for the column type.  When we can tolerate a schema
                    # change, we should change that type and remove this line.
                    cursor.execute("SET TIME ZONE 0")
            yield

    def _lockTables(self, connection: sqlalchemy.engine.Connection,
                    tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        """Return the ``writeable`` flag this object was constructed with."""
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the NameShrinker built in ``__init__``.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same NameShrinker.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Plain column names are compared with ``=`` and the timespan column
        # with ``&&`` (range overlap), so the resulting EXCLUDE constraint
        # forbids rows that agree on the plain columns and overlap in time.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Insert ``rows``, overwriting any existing row with the same primary
        # key.  No-op when ``rows`` is empty.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        if data:
            query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        else:
            # Every column is part of the primary key, so a conflicting row is
            # already identical to the new one.  SQLAlchemy also rejects an
            # empty ``set_`` dictionary with ValueError, so skip the update.
            query = query.on_conflict_do_nothing(constraint=table.primary_key)
        with self._connection() as connection:
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.

    Notes
    -----
    Values are stored as ``int8range`` built from the integer nanosecond
    bounds held in ``Timespan._nsec``.  A bound equal to the
    `time_utils.TimeConverter` minimum/maximum is stored as an unbounded
    (NULL) range end, and converted back on read.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    # Declares to SQLAlchemy that statements using this type may be cached.
    cache_ok = True

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` to a psycopg2 range for a bound parameter.

        Parameters
        ----------
        value : `Timespan` or `None`
            Timespan to convert.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (unused).

        Returns
        -------
        range : `psycopg2.extras.NumericRange` or `None`
            `None` for a `None` input, an empty range for an empty timespan,
            otherwise a range whose unbounded ends are `None`.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        # ``import psycopg2`` does not load the ``extras`` submodule; import it
        # explicitly instead of relying on another module having done so.
        import psycopg2.extras

        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        assert value._nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert value._nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # Sentinel min/max values become unbounded range ends.
        lower = None if value._nsec[0] == converter.min_nsec else value._nsec[0]
        upper = None if value._nsec[1] == converter.max_nsec else value._nsec[1]
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a psycopg2 range read from the database to a `Timespan`.

        Unbounded range ends are mapped back to the `time_utils.TimeConverter`
        minimum/maximum nanosecond values; an empty range becomes an empty
        `Timespan`, and `None` is passed through.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = converter.min_nsec if value.lower is None else value.lower
        end_nsec = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        main, but not in the releases we currently use).  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal."""
            if other is None:
                return super().__ne__(other)
            else:
                return self.expr.op("<>", is_comparison=True)(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self.expr.op("@>", is_comparison=True)(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op("<@", is_comparison=True)(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op("&&", is_comparison=True)(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self.expr.op("<<", is_comparison=True)(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self.expr.op(">>", is_comparison=True)(other)

        __rshift__ = strictly_right_of
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """`TimespanDatabaseRepresentation` stored in one PostgreSQL range column.

    The timespan lives in a single `_RangeTimespanType` (``int8range``)
    field, so comparisons are expressed with PostgreSQL range operators and
    range functions (``upper``, ``lower``, ``isempty``, ...).

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the timespan field this object represents.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        fieldName = cls.NAME if name is None else name
        # Non-nullable fields default to the unbounded range '(,)'.
        defaultValue = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            fieldName,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=defaultValue,
            **kwargs
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        # A caller-supplied dict (even an empty one) is updated in place.
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literalColumn = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literalColumn, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(selectable.columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL "strictly left of".
            return self.column << other.column
        # Range vs. scalar expression (presumably in the same integer units
        # as the range bounds): the range must be non-empty, bounded above,
        # and end at or before ``other`` (the upper bound is exclusive).
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL "strictly right of".
            return self.column >> other.column
        # Range vs. scalar expression: the range must be non-empty, bounded
        # below, and start strictly after ``other``.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; pass other expressions
        # (e.g. a scalar time) through unchanged.
        target = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(target)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Exactly one column is produced, labeled only when a name is given.
        yield self.column if name is None else self.column.label(name)