Coverage for python/lsst/daf/butler/registry/databases/postgresql.py : 29%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
from contextlib import contextmanager, closing
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union

import psycopg2
import psycopg2.extras
import sqlalchemy.dialects.postgresql

from ..interfaces import Database
from ..nameShrinker import NameShrinker
from ...core import DatabaseTimespanRepresentation, ddl, Timespan, time_utils
class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the connection's current schema (``SELECT current_schema()``) is used
        (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.  If `False`, every transaction opened through
        `transaction` is marked ``READ ONLY``.

    Raises
    ------
    RuntimeError
        Raised if the connection does not use the psycopg2 driver, or if the
        ``btree_gist`` extension is not enabled in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` package to be installed, which we assume indicates
    that a PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, connection=connection, namespace=namespace)
        dbapi = connection.connection
        try:
            # ``get_dsn_parameters`` is how the psycopg2 DBAPI connection
            # exposes its DSN; its absence means an unsupported driver.
            dsn = dbapi.get_dsn_parameters()
        except (AttributeError, KeyError) as err:
            raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
        if namespace is None:
            namespace = connection.execute("SELECT current_schema();").scalar()
        # Exclusion constraints built by `_convertExclusionConstraintSpec`
        # combine ``=`` and ``&&`` operators, which needs btree_gist.
        if not connection.execute("SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';").scalar():
            raise RuntimeError(
                "The Butler PostgreSQL backend requires the btree_gist extension. "
                "As extensions are enabled per-database, this may require an administrator to run "
                "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is initialized."
            )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Shrink long generated names to the server's identifier limit.
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        # Docstring inherited.
        # NullPool: connections are opened on demand and not kept in a pool.
        # ``writeable`` is not used here; it is enforced per-transaction.
        return sqlalchemy.engine.create_engine(uri, poolclass=sqlalchemy.pool.NullPool).connect()

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Docstring inherited.
        return cls(connection=connection, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        # Docstring inherited.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if not self.isWriteable():
                # Let the server itself reject writes for read-only clients,
                # using a raw DBAPI cursor on the underlying connection.
                with closing(self._connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        # Render each table name with the dialect's identifier preparer so
        # that schema-qualified, mixed-case or reserved names are quoted the
        # same way SQLAlchemy quotes them elsewhere; the raw ``table.key`` is
        # not quoted at all.
        preparer = self._connection.engine.dialect.identifier_preparer
        for table in tables:
            self._connection.execute(f"LOCK TABLE {preparer.format_table(table)} IN SHARE MODE")

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[DatabaseTimespanRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must match exactly.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, DatabaseTimespanRepresentation):
                # Timespan columns must not overlap (range ``&&`` operator).
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(DatabaseTimespanRepresentation.NAME), "&&"))
                names.append(DatabaseTimespanRepresentation.NAME)
            else:
                # Previously such items were silently dropped, producing a
                # weaker constraint than requested.
                raise TypeError(
                    f"Unexpected item {item!r} in exclusion constraint spec for table {table!r}."
                )
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[DatabaseTimespanRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign every non-primary-key column from the
        # special ``excluded`` pseudo-table, i.e. the row proposed for
        # insertion.  A column absent from the INSERT list takes its default
        # value in ``excluded`` (NULL if it has none), so the existing value
        # is overwritten in that case too.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        if data:
            query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        else:
            # Every column belongs to the primary key, so a conflicting row is
            # already identical to the new one.  ``on_conflict_do_update``
            # rejects an empty SET clause, so just skip conflicts instead.
            query = query.on_conflict_do_nothing(constraint=table.primary_key)
        self._connection.execute(query, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        return self._connection.execute(query, *rows).rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.

    Timespan bounds are stored as integer nanoseconds (converted with
    `time_utils.astropy_to_nsec` / `time_utils.nsec_to_astropy`) in an
    ``INT8RANGE`` column; a `None` bound maps to an unbounded side of the
    range.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    # This type has no constructor arguments or other state that changes the
    # SQL it produces, so compiled statements that use it are safe to cache.
    # SQLAlchemy 1.4+ warns (and disables caching) for a TypeDecorator that
    # does not declare this; older versions do not consult the attribute.
    cache_ok = True

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` into a range value for the database.

        Parameters
        ----------
        value : `Timespan` or `None`
            Timespan to bind; `None` binds SQL NULL.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (not used by this conversion).

        Returns
        -------
        range : `psycopg2.extras.NumericRange` or `None`
            Range of integer nanoseconds (psycopg2's default bounds), or
            `None` if ``value`` is `None`.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        lower = None if value.begin is None else time_utils.astropy_to_nsec(value.begin)
        upper = None if value.end is None else time_utils.astropy_to_nsec(value.end)
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a range value read from the database into a `Timespan`.

        Parameters
        ----------
        value : `psycopg2.extras.NumericRange` or `None`
            Range of integer nanoseconds, or `None` for SQL NULL.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (not used by this conversion).

        Returns
        -------
        timespan : `Timespan` or `None`
            The corresponding timespan; `None` for NULL and for an empty range.
        """
        if value is None or value.isempty:
            return None
        begin = None if value.lower is None else time_utils.nsec_to_astropy(value.lower)
        end = None if value.upper is None else time_utils.nsec_to_astropy(value.upper)
        return Timespan(begin=begin, end=end)

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for `_RangeTimespanType` columns.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476.  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal.
            """
            if other is None:
                return super().__ne__(other)
            else:
                return self.expr.op("<>", is_comparison=True)(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self.expr.op("@>", is_comparison=True)(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op("<@", is_comparison=True)(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op("&&", is_comparison=True)(other)
class _RangeTimespanRepresentation(DatabaseTimespanRepresentation):
    """Timespan representation backed by one PostgreSQL range column.

    This `DatabaseTimespanRepresentation` keeps an entire timespan in a
    single field of type `_RangeTimespanType`, so overlap tests and
    exclusion constraints can use PostgreSQL's native range operators.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    __slots__ = ("column",)

    def __init__(self, column: sqlalchemy.sql.ColumnElement):
        self.column = column

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, **kwargs: Any) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        # Non-nullable columns default to the fully unbounded range '(,)'.
        if nullable:
            default_value = None
        else:
            default_value = sqlalchemy.sql.text("'(,)'::int8range")
        field = ddl.FieldSpec(
            cls.NAME,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default_value,
            **kwargs
        )
        return (field,)

    @classmethod
    def getFieldNames(cls) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME,)

    @classmethod
    def update(cls, timespan: Optional[Timespan], *,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        # The Timespan is stored as-is; `_RangeTimespanType` converts it.
        target = {} if result is None else result
        target[cls.NAME] = timespan
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any]) -> Optional[Timespan]:
        # Docstring inherited.
        timespan = mapping[cls.NAME]
        return timespan

    @classmethod
    def hasExclusionConstraint(cls) -> bool:
        # Docstring inherited.
        return True

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        available = selectable.columns
        return cls(available[cls.NAME])

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def overlaps(self, other: Union[Timespan, _RangeTimespanRepresentation]) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A literal Timespan is bound directly; another representation
        # contributes its column.
        operand = other if isinstance(other, Timespan) else other.column
        return self.column.overlaps(operand)