Coverage for python/lsst/daf/butler/registry/databases/postgresql.py : 28%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ["PostgresqlDatabase"]
25from contextlib import contextmanager, closing
26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union
28import psycopg2
29import sqlalchemy.dialects.postgresql
31from ..interfaces import Database
32from ..nameShrinker import NameShrinker
33from ...core import TimespanDatabaseRepresentation, ddl, Timespan, time_utils
class PostgresqlDatabase(Database):
    """A `Database` implementation backed by a PostgreSQL server.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the schema reported by ``SELECT current_schema();`` on the given
        connection is used (which may itself be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`)
    requires the ``btree_gist`` PostgreSQL server extension to be installed
    and enabled on the database being connected to; this is checked when an
    instance is constructed.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, connection=connection, namespace=namespace)
        rawConnection = connection.connection
        # Only psycopg2 connections provide ``get_dsn_parameters``; failure
        # to call it is reported as an unsupported driver.
        try:
            dsnParameters = rawConnection.get_dsn_parameters()
        except (AttributeError, KeyError) as err:
            raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
        if namespace is None:
            namespace = connection.execute("SELECT current_schema();").scalar()
        extensionCount = connection.execute(
            "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
        ).scalar()
        if not extensionCount:
            raise RuntimeError(
                "The Butler PostgreSQL backend requires the btree_gist extension. "
                "As extensions are enabled per-database, this may require an administrator to run "
                "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is initialized."
            )
        self.namespace = namespace
        self.dbname = dsnParameters.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        # ``writeable`` is accepted for interface compatibility but is not
        # consulted when creating the connection.
        engine = sqlalchemy.engine.create_engine(uri, poolclass=sqlalchemy.pool.NullPool)
        return engine.connect()

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Thin alternate constructor that forwards everything to ``__init__``.
        return cls(connection=connection, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if self.isWriteable():
                # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this statement.
                statement = "SET TIME ZONE 0"
            else:
                statement = "SET TRANSACTION READ ONLY"
            # Issue the per-transaction setting through a raw DBAPI cursor.
            with closing(self._connection.connection.cursor()) as cursor:
                cursor.execute(statement)
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        execute = self._connection.execute
        for lockedTable in tables:
            execute(f"LOCK TABLE {lockedTable.key} IN EXCLUSIVE MODE")

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the `NameShrinker` built from the dialect's maximum
        # identifier length.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same `NameShrinker`.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        columnsWithOperators = []
        nameParts = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain column names are compared for equality.
                columnName, operatorSymbol = item, "="
            elif issubclass(item, TimespanDatabaseRepresentation):
                # The timespan column is compared with the range-overlap
                # operator.
                assert item is self.getTimespanRepresentation()
                columnName, operatorSymbol = TimespanDatabaseRepresentation.NAME, "&&"
            else:
                # Anything else contributes nothing to the constraint.
                continue
            columnsWithOperators.append((sqlalchemy.schema.Column(columnName), operatorSymbol))
            nameParts.append(columnName)
        constraintName = self.shrinkDatabaseEntityName("_".join(nameParts))
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(*columnsWithOperators, name=constraintName)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        insertQuery = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all non-primary-key columns using the
        # special `excluded` pseudo-table.  If some column in the table does
        # not appear in the INSERT list this will set it to NULL.
        excluded = insertQuery.excluded
        assignments = {}
        for column in table.columns:
            if column.name not in table.primary_key:
                assignments[column.name] = getattr(excluded, column.name)
        upsertQuery = insertQuery.on_conflict_do_update(constraint=table.primary_key, set_=assignments)
        self._connection.execute(upsertQuery, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        insertQuery = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        result = self._connection.execute(insertQuery, *rows)
        return result.rowcount
class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) to the value sent to the
        database.

        `None` is passed through unchanged; any other non-`Timespan` value
        raises `TypeError`.  Each bound is converted with
        ``time_utils.astropy_to_nsec``, with a `None` bound left as `None`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.begin is None:
            lowerBound = None
        else:
            lowerBound = time_utils.astropy_to_nsec(value.begin)
        if value.end is None:
            upperBound = None
        else:
            upperBound = time_utils.astropy_to_nsec(value.end)
        return psycopg2.extras.NumericRange(lower=lowerBound, upper=upperBound)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a range value read from the database to a `Timespan`.

        Both `None` and an empty range are returned as `None`.  Each bound
        is converted with ``time_utils.nsec_to_astropy``, with a `None`
        bound left as `None`.
        """
        if value is None:
            return None
        if value.isempty:
            return None
        if value.lower is None:
            begin = None
        else:
            begin = time_utils.nsec_to_astropy(value.lower)
        if value.upper is None:
            end = None
        else:
            end = time_utils.nsec_to_astropy(value.upper)
        return Timespan(begin=begin, end=end)

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476.  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression.  Returns true if two ranges are not
            equal.
            """
            if other is not None:
                return self.expr.op("<>", is_comparison=True)(other)
            # Comparison against `None` is left to the base comparator.
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression.  Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            containsOperator = self.expr.op("@>", is_comparison=True)
            return containsOperator(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is contained
            within the right hand operand.
            """
            containedByOperator = self.expr.op("<@", is_comparison=True)
            return containedByOperator(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            overlapsOperator = self.expr.op("&&", is_comparison=True)
            return overlapsOperator(other)
class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    __slots__ = ("column",)

    def __init__(self, column: sqlalchemy.sql.ColumnElement):
        self.column = column

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, **kwargs: Any) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if nullable:
            defaultValue = None
        else:
            # Non-nullable columns default to the unbounded range literal.
            defaultValue = sqlalchemy.sql.text("'(,)'::int8range")
        fieldSpec = ddl.FieldSpec(
            cls.NAME, dtype=_RangeTimespanType, nullable=nullable,
            default=defaultValue,
            **kwargs
        )
        return (fieldSpec,)

    @classmethod
    def getFieldNames(cls) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME,)

    @classmethod
    def update(cls, timespan: Optional[Timespan], *,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        target = {} if result is None else result
        target[cls.NAME] = timespan
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any]) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME]

    @classmethod
    def hasExclusionConstraint(cls) -> bool:
        # Docstring inherited.
        return True

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        timespanColumn = selectable.columns[cls.NAME]
        return cls(timespanColumn)

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def overlaps(self, other: Union[Timespan, _RangeTimespanRepresentation]) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A literal `Timespan` is compared directly; another representation
        # is compared via its underlying column.
        operand = other if isinstance(other, Timespan) else other.column
        return self.column.overlaps(operand)