Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 30%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
from contextlib import closing, contextmanager
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union

import psycopg2
import sqlalchemy
import sqlalchemy.dialects.postgresql
from psycopg2.extras import NumericRange

from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
from ..interfaces import Database
from ..nameShrinker import NameShrinker
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to open connections to the database (e.g. one created by
        `makeEngine`).
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the result of ``SELECT current_schema()`` on the connection is used
        (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the connection does not use the psycopg2 driver, or if the
        ``btree_gist`` extension is not enabled in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # The raw DBAPI connection is needed to read the DSN; only
            # psycopg2 connections provide ``get_dsn_parameters``.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Executing a plain string is deprecated in SQLAlchemy 1.4
                # and removed in 2.0; wrap it in text() like the query below.
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
            self.namespace = namespace
            self.dbname = dsn.get("dbname")
            self._writeable = writeable
            self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # ``writeable`` is unused here; read-only mode is applied per
        # transaction from the ``writeable`` flag given to the constructor.
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
    ) -> Iterator[None]:
        # Docstring inherited.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            if not self.isWriteable():
                # Let the server reject any write attempted by a read-only
                # client.
                with closing(self._session_connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            else:
                with closing(self._session_connection.connection.cursor()) as cursor:
                    # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                    # the column type. When we can tolerate a schema change,
                    # we should change that type and remove this line.
                    cursor.execute("SET TIME ZONE 0")
            yield

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must be equal to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan columns conflict when their ranges overlap (&&).
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table. If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._connection() as connection:
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.

    Notes
    -----
    Values are stored as ``int8range`` built from the integer nanosecond
    bounds of the `Timespan`.  Bounds equal to the `time_utils.TimeConverter`
    minimum/maximum are stored as unbounded range ends, and unbounded ends
    are read back as those extremes.  Empty timespans map to empty ranges.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[NumericRange]:
        """Convert a `Timespan` into a psycopg2 `NumericRange` for binding.

        Parameters
        ----------
        value : `Timespan` or `None`
            Timespan to convert.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (unused).

        Returns
        -------
        range : `NumericRange` or `None`
            `None` for `None`, an empty range for an empty timespan,
            otherwise a range whose extreme bounds are left unbounded.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        assert value._nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert value._nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # The converter's extreme values stand for "unbounded"; store them as
        # open (NULL) range ends so PostgreSQL range semantics apply.
        lower = None if value._nsec[0] == converter.min_nsec else value._nsec[0]
        upper = None if value._nsec[1] == converter.max_nsec else value._nsec[1]
        return NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: Optional[NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        """Convert a psycopg2 `NumericRange` read from the database back
        into a `Timespan`.

        Parameters
        ----------
        value : `NumericRange` or `None`
            Range returned by the driver.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (unused).

        Returns
        -------
        timespan : `Timespan` or `None`
            `None` for `None`, an empty timespan for an empty range,
            otherwise a timespan with unbounded ends mapped to the
            converter's extreme values.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = converter.min_nsec if value.lower is None else value.lower
        end_nsec = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        main, but not in the releases we currently use).  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal."""
            if other is None:
                return super().__ne__(other)
            else:
                return self.expr.op("<>", is_comparison=True)(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self.expr.op("@>", is_comparison=True)(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op("<@", is_comparison=True)(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op("&&", is_comparison=True)(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self.expr.op("<<", is_comparison=True)(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self.expr.op(">>", is_comparison=True)(other)

        __rshift__ = strictly_right_of
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name of the logical timespan field; returned by the `name` property.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (
            ddl.FieldSpec(
                name,
                dtype=_RangeTimespanType,
                nullable=nullable,
                # Non-nullable columns default to the fully unbounded range.
                default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
                **kwargs,
            ),
        )

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        if result is None:
            result = {}
        # The Timespan itself is stored; conversion to a range happens in
        # _RangeTimespanType.process_bind_param.
        result[name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return mapping[name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        return cls(
            column=sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType),
            name=cls.NAME,
        )

    @classmethod
    def fromSelectable(
        cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return cls(selectable.columns[name], name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Compare against a scalar: the range must have a finite upper
            # bound, be non-empty, and that (exclusive) bound must not exceed
            # the scalar.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        else:
            # Range vs. range: PostgreSQL "strictly left of" (<<).
            return self.column << other.column

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Compare against a scalar: the range must have a finite lower
            # bound, be non-empty, and that bound must exceed the scalar.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        else:
            # Range vs. range: PostgreSQL "strictly right of" (>>).
            return self.column >> other.column

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.contains(other.column)
        else:
            return self.column.contains(other)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A NULL lower bound (e.g. a NULL column) is coalesced to 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A NULL upper bound (e.g. a NULL column) is coalesced to 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # This representation has exactly one column; label it only when a
        # name is requested.
        if name is None:
            yield self.column
        else:
            yield self.column.label(name)