Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 26%

205 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-02 18:18 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import closing, contextmanager 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31 

32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

33from ..interfaces import Database 

34from ..nameShrinker import NameShrinker 

35 

36 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to connect to the database, e.g. one returned by
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection is not a psycopg2
        connection, or if the ``btree_gist`` extension is not enabled in the
        database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                # Only psycopg2 connections provide get_dsn_parameters().
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap the statement in text(): passing a bare string to
                # Connection.execute is deprecated in SQLAlchemy 1.4 and
                # rejected in 2.0, and the query below already uses text().
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # The dialect belongs to the engine, so there is no need to go through
        # the connection that was closed above.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is accepted for interface compatibility but is not
        # used when creating the engine.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Thin alternate constructor that forwards everything to __init__.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        # Wraps the base-class transaction and, when a new transaction has
        # just been started, configures it (read-only mode and time zone).
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                with closing(connection.connection.cursor()) as cursor:
                    if not (self.isWriteable() or for_temp_tables):
                        # PostgreSQL permits writing to temporary tables
                        # inside read-only transactions, but it doesn't permit
                        # creating them.
                        cursor.execute("SET TRANSACTION READ ONLY")
                    # Make timestamps UTC, because we didn't use TIMESTAMPTZ
                    # for the column type.  When we can tolerate a schema
                    # change, we should change that type and remove this
                    # line.
                    cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: Optional[str] = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Open the transaction with for_temp_tables=True so that it is not
        # marked READ ONLY, which would forbid creating the table.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the NameShrinker built from the dialect's maximum
        # identifier length.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of shrinkDatabaseEntityName, via the same NameShrinker.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must be equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan columns must overlap for two rows to conflict.
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Only conflicts on the primary key are ignored.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            # Conflicts on any constraint are ignored.
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

268 

269 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A PostgreSQL-only, single-column representation of `Timespan`.

    Values are stored in an ``INT8RANGE`` column so that PostgreSQL's
    built-in range operators, and the indexes and EXCLUSION constraints
    built on them, can be used.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) to a psycopg2 ``NumericRange``.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        assert begin_nsec >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        # A bound sitting exactly on the converter's limit is sent as an
        # unbounded (None) side of the range.
        lower = begin_nsec if begin_nsec != limits.min_nsec else None
        upper = end_nsec if end_nsec != limits.max_nsec else None
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        """Convert a psycopg2 ``NumericRange`` (or `None`) to a `Timespan`."""
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        # Unbounded sides of the range map back onto the converter's limits.
        if value.lower is None:
            begin_nsec = limits.min_nsec
        else:
            begin_nsec = value.lower
        if value.upper is None:
            end_nsec = limits.max_nsec
        else:
            end_nsec = value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        This nested class exists to work around a bug reported upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        main, but not in the releases we currently use).  It is a limited
        copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, with
        ``is_comparison=True`` added to every call.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression; true if the two ranges are not equal."""
            if other is not None:
                return self.expr.op("<>", is_comparison=True)(other)
            # Comparison with None is left to the base-class implementation.
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression; true if the right hand operand (an
            element or a range) is contained within the column.
            """
            contains_op = self.expr.op("@>", is_comparison=True)
            return contains_op(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression; true if the column is contained within
            the right hand operand.
            """
            contained_op = self.expr.op("<@", is_comparison=True)
            return contained_op(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression; true if the column overlaps (has points
            in common with) the right hand operand.
            """
            overlap_op = self.expr.op("&&", is_comparison=True)
            return overlap_op(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression; true if the column is strictly left of
            the right hand operand.
            """
            left_op = self.expr.op("<<", is_comparison=True)
            return left_op(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression; true if the column is strictly right of
            the right hand operand.
            """
            right_op = self.expr.op(">>", is_comparison=True)
            return right_op(other)

        __rshift__ = strictly_right_of

367 

368 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that stores a timespan in a single
    PostgreSQL-specific field of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable fields default to the fully unbounded range.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        field = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (field,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        if result is None:
            result = {}
        result[key] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal_column = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal_column, name=cls.NAME)

    @classmethod
    def fromSelectable(
        cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(selectable.columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison uses the "strictly left of" operator.
            return self.column << other.column
        # Range-vs-scalar comparison: the range must be bounded above and
        # non-empty, and its upper bound must not exceed the scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison uses the "strictly right of" operator.
            return self.column >> other.column
        # Range-vs-scalar comparison: the range must be bounded below and
        # non-empty, and its lower bound must exceed the scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; anything else is
        # passed through unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        bound = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        bound = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        yield self.column if name is None else self.column.label(name)