Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 24% (187 statements)
coverage.py v6.5.0, created at 2022-10-12 09:00 +0000

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["PostgresqlDatabase"]

from contextlib import closing, contextmanager
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union

import psycopg2
import sqlalchemy
import sqlalchemy.dialects.postgresql

from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
from ...core.named import NamedValueAbstractSet
from ..interfaces import Database
from ..nameShrinker import NameShrinker


class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        An existing SQLAlchemy engine, typically created by a previous call
        to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with. If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy. Running the tests for this class requires the
    ``testing.postgresql`` package to be installed, which we assume indicates
    that a PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`)
    requires the ``btree_gist`` PostgreSQL server extension to be installed
    and enabled on the database being connected to; this is checked at
    connection time.
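
    For example, a writeable client might be constructed as follows; the URI,
    ``origin`` value, and schema name below are purely illustrative::

        engine = PostgresqlDatabase.makeEngine("postgresql://user@localhost/butler")
        db = PostgresqlDatabase.fromEngine(engine, origin=0, namespace="my_schema")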
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                namespace = connection.execute("SELECT current_schema();").scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
    ) -> Iterator[None]:
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            if not self.isWriteable():
                with closing(self._session_connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            else:
                with closing(self._session_connection.connection.cursor()) as cursor:
                    # Make timestamps UTC, because we didn't use TIMESTAMPTZ
                    # for the column type. When we can tolerate a schema
                    # change, we should change that type and remove this line.
                    cursor.execute("SET TIME ZONE 0")
            yield
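
    # Editor's note: an illustrative sketch (not part of this module) of how
    # the context manager above is typically used; ``db`` and ``table`` are
    # hypothetical.
    #
    #     with db.transaction(lock=[table]):
    #         db.replace(table, {"id": 1, "name": "a"})
    #
    # On a client constructed with ``writeable=False`` the transaction is
    # instead opened with ``SET TRANSACTION READ ONLY``, so the server rejects
    # any write that slips past the Python-level checks.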

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )
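
    # Editor's note: an illustrative sketch of the method above. Given a spec
    # such as ``("instrument", "detector", _RangeTimespanRepresentation)``
    # (column names hypothetical), the resulting constraint is roughly
    #
    #     EXCLUDE USING gist (instrument WITH =, detector WITH =, timespan WITH &&)
    #
    # i.e. no two rows may have equal scalar columns and overlapping
    # timespans. Mixing scalar (``=``) and range (``&&``) columns in one GiST
    # constraint is what requires the btree_gist extension checked in __init__.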

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses the PostgreSQL backend's special support for UPSERT:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause, assign all columns using the special `excluded`
        # pseudo-table. If some column in the table does not appear in the
        # INSERT list, this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._connection() as connection:
            connection.execute(query, rows)
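
    # Editor's note: for a hypothetical table ``foo(id PRIMARY KEY, name)``,
    # the statement built above compiles to roughly
    #
    #     INSERT INTO foo (id, name) VALUES (%(id)s, %(name)s)
    #     ON CONFLICT (id) DO UPDATE SET name = excluded.name
    #
    # (or ``ON CONFLICT ON CONSTRAINT ...`` if the primary key constraint is
    # named), so an existing row with the same primary key is overwritten.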

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount
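
    # Editor's note: for the same hypothetical ``foo`` table, `ensure` instead
    # ends its INSERT with ``ON CONFLICT (id) DO NOTHING`` (when
    # ``primary_key_only=True``) or a bare ``ON CONFLICT DO NOTHING``, so
    # conflicting rows are skipped and the returned rowcount counts only the
    # rows actually inserted.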

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: Optional[str] = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)


class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        else:
            converter = time_utils.TimeConverter()
            assert value._nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
            assert value._nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
            lower = None if value._nsec[0] == converter.min_nsec else value._nsec[0]
            upper = None if value._nsec[1] == converter.max_nsec else value._nsec[1]
            return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = converter.min_nsec if value.lower is None else value.lower
        end_nsec = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))
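

# Editor's note: an illustrative summary of the mapping implemented by
# `_RangeTimespanType` above (the nanosecond values are hypothetical):
#
#     Timespan with both bounds    <->  int8range '[1000, 2000)'
#     Timespan(None, None)         <->  int8range '(,)'
#     Timespan.makeEmpty()         <->  'empty'
#     None                         <->  NULL
#
# Bounds equal to TimeConverter's min/max sentinels are stored as open range
# bounds, so PostgreSQL's range operators treat them as unbounded.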


class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the column, as returned by the `name` property.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (
            ddl.FieldSpec(
                name,
                dtype=_RangeTimespanType,
                nullable=nullable,
                default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
                **kwargs,
            ),
        )

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        if result is None:
            result = {}
        result[name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return mapping[name]

    @classmethod
    def fromLiteral(cls, timespan: Optional[Timespan]) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        return cls(
            column=sqlalchemy.sql.cast(
                sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType
            ),
            name=cls.NAME,
        )

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return cls(columns[name], name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        else:
            return self.column << other.column

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        else:
            return self.column >> other.column

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, _RangeTimespanRepresentation):
            return self.contains(other)
        return self.column.overlaps(other.column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.contains(other.column)
        else:
            return self.column.contains(other)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: Optional[str] = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        if name is None:
            return (self.column,)
        else:
            return (self.column.label(name),)
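

# Editor's note: the sketch below is not part of the original module; it shows
# how the classes above combine to build a range-overlap query. The table is
# assumed to store its timespan in a single column with the default
# ``TimespanDatabaseRepresentation.NAME`` name.
def _example_overlap_query(
    table: sqlalchemy.schema.Table, timespan: Timespan
) -> sqlalchemy.sql.Select:
    """Return a SELECT for rows whose stored timespan overlaps ``timespan``."""
    # Wrap the stored column and the Python literal in the same representation
    # class so they can be compared symmetrically.
    stored = _RangeTimespanRepresentation.from_columns(table.columns)
    literal = _RangeTimespanRepresentation.fromLiteral(timespan)
    # `overlaps` compiles to the PostgreSQL range operator ``&&``, which can be
    # served by a GiST index or enforced through an exclusion constraint.
    return table.select().where(stored.overlaps(literal))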