Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 31%

197 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-10-02 07:59 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ["PostgresqlDatabase"] 

30 

31from collections.abc import Iterable, Iterator, Mapping 

32from contextlib import closing, contextmanager 

33from typing import Any 

34 

35import psycopg2 

36import sqlalchemy 

37import sqlalchemy.dialects.postgresql 

38from sqlalchemy import sql 

39 

40from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

41from ...core.named import NamedValueAbstractSet 

42from ..interfaces import Database 

43from ..nameShrinker import NameShrinker 

44 

45 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine created by a previous call to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        # Open a short-lived connection to validate the driver and server
        # configuration before the database is used for real work.
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                # get_dsn_parameters is psycopg2-specific; its absence
                # indicates an unsupported DBAPI driver.
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the connection's default schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # btree_gist is needed for the exclusion constraints generated by
            # _convertExclusionConstraintSpec; fail fast if it is missing.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Shrinks generated identifiers to fit PostgreSQL's identifier-length
        # limit (63 bytes by default).
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE(review): ``writeable`` is unused here; read-only enforcement
        # happens per-transaction in _transaction instead.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        # Docstring inherited.
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIMESPAN" and pgbouncer
                # will take care of clients that share connections being set
                # consistently.  And if that assumption is wrong, we should
                # still probably be okay, since all clients should be Butler
                # clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction with for_temp_tables=True so the CREATE
        # TEMPORARY TABLE is allowed even when the database is read-only.
        with self.transaction(for_temp_tables=True), super().temporary_table(spec, name) as table:
            yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        # Human-readable identification of the server/schema this object
        # points at.
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Build an EXCLUDE constraint: '=' comparisons for plain columns and
        # the range-overlap operator '&&' for the timespan column.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            # Constraint name may exceed the identifier-length limit; shrink.
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Ignore conflicts only on the primary key; conflicts on other
            # unique constraints still raise.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

289 

290 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Maps `Timespan` values onto PostgreSQL's ``INT8RANGE`` type, so range
    operators and the indexing / EXCLUSION constraints built on them can
    be used directly.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        # Convert an outgoing Timespan into a psycopg2 numeric range.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        begin, end = value._nsec
        assert begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # Sentinel min/max nanosecond values become unbounded range edges.
        return psycopg2.extras.NumericRange(
            lower=(None if begin == converter.min_nsec else begin),
            upper=(None if end == converter.max_nsec else end),
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        # Convert an incoming psycopg2 numeric range back into a Timespan.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        # Unbounded edges map back to the sentinel min/max nanosecond values.
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

332 

333 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that stores a timespan in one
    PostgreSQL-specific `_RangeTimespanType` column.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the column.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self._name = name
        self.column = column

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        spec = ddl.FieldSpec(
            cls.NAME if name is None else name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            # Non-nullable columns default to the unbounded range (,).
            default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        target = {} if result is None else result
        target[cls.NAME if name is None else name] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            column: sqlalchemy.sql.ColumnElement = sqlalchemy.sql.null()
        else:
            literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
            column = sqlalchemy.sql.cast(literal, type_=_RangeTimespanType)
        return cls(column=column, name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: PostgreSQL's strictly-left-of operator.
            return self.column << other.column
        func = sqlalchemy.sql.func
        # Range-vs-point: true iff the range is bounded above, non-empty,
        # and ends at or before the point.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: PostgreSQL's strictly-right-of operator.
            return self.column >> other.column
        func = sqlalchemy.sql.func
        # Range-vs-point: true iff the range is bounded below, non-empty,
        # and starts after the point.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar time "overlaps" a timespan iff the timespan contains it.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # COALESCE maps an unbounded lower edge (NULL) to 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # COALESCE maps an unbounded upper edge (NULL) to 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)