Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 30%

198 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-14 19:20 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from collections.abc import Iterable, Iterator, Mapping 

26from contextlib import closing, contextmanager 

27from typing import Any 

28 

29import psycopg2 

30import sqlalchemy 

31import sqlalchemy.dialects.postgresql 

32from sqlalchemy import sql 

33 

34from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

35from ...core.named import NamedValueAbstractSet 

36from ..interfaces import Database 

37from ..nameShrinker import NameShrinker 

38 

39 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this database, created by a previous call to
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the connection's current schema.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Fail fast if the btree_gist extension (needed for the timespan
            # exclusion constraints built by _convertExclusionConstraintSpec)
            # is not enabled on this database.
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE: ``writeable`` is not enforced at the engine level; read-only
        # behavior is applied per-transaction in `_transaction`.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: str | None = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: str | None = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction flagged for temp tables, so the read-only
        # guard in `_transaction` allows the CREATE TEMPORARY TABLE.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: tuple[str | type[TimespanDatabaseRepresentation], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate with equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan columns participate with range overlap.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: str | None = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING on conflicts.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: str | None = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)
284 

285 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Timespan | None, dialect: sqlalchemy.engine.Dialect
    ) -> psycopg2.extras.NumericRange | None:
        # Translate a Python-side Timespan into psycopg2's range type for
        # storage in an INT8RANGE column.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        begin, end = value._nsec
        assert begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # The sentinel min/max nanosecond values map to unbounded endpoints.
        return psycopg2.extras.NumericRange(
            lower=None if begin == converter.min_nsec else begin,
            upper=None if end == converter.max_nsec else end,
        )

    def process_result_value(
        self, value: psycopg2.extras.NumericRange | None, dialect: sqlalchemy.engine.Dialect
    ) -> Timespan | None:
        # Translate a range read from the database back into a Timespan.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        # Unbounded endpoints map back to the sentinel min/max values.
        nsec = (
            converter.min_nsec if value.lower is None else value.lower,
            converter.max_nsec if value.upper is None else value.upper,
        )
        return Timespan(begin=None, end=None, _nsec=nsec)
328 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Logical name of the column.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: str | None = None, **kwargs: Any
    ) -> tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # A non-nullable column defaults to the unbounded range.
        default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: str | None = None) -> tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Timespan | None, name: str | None = None, result: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        # Docstring inherited.
        if result is None:
            result = {}
        result[cls.NAME if name is None else name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespan | None:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=sqlalchemy.sql.cast(literal, type_=_RangeTimespanType), name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: str | None = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: PostgreSQL's "strictly left of" operator.
            return self.column << other.column
        # Range-vs-scalar: a non-empty range that is bounded above is entirely
        # before the scalar when its (exclusive) upper bound is <= it.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: PostgreSQL's "strictly right of" operator.
            return self.column >> other.column
        # Range-vs-scalar: a non-empty range that is bounded below is entirely
        # after the scalar when its (inclusive) lower bound is > it.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar "overlaps" a range exactly when the range contains it.
        return self.contains(other)

    def contains(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # COALESCE so a NULL lower bound yields a well-defined 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # COALESCE so a NULL upper bound yields a well-defined 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: str | None = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)