Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 24%
197 statements
coverage.py v7.2.7, created at 2023-06-07 02:10 -0700

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ["PostgresqlDatabase"]

from contextlib import closing, contextmanager
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union

import psycopg2
import sqlalchemy
import sqlalchemy.dialects.postgresql
from sqlalchemy import sql

from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils
from ...core.named import NamedValueAbstractSet
from ..interfaces import Database
from ..nameShrinker import NameShrinker


class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        An existing engine created by a previous call to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use an (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` module to be installed, which we assume indicates
    that a PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`)
    requires the ``btree_gist`` PostgreSQL server extension to be installed
    and enabled on the database being connected to; this is checked at
    connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below; can't use type: ignore.
            dbapi: Any = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

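    # `makeEngine` below passes pool_size=1 so SQLAlchemy's QueuePool keeps a
    # single persistent connection (the pool's default overflow still allows
    # extra, temporary connections).  Note that `writeable` is accepted here
    # only for interface compatibility; read-only enforcement is applied
    # per-transaction via "SET TRANSACTION READ ONLY" in _transaction() below.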

    @classmethod
    def makeEngine(
        cls, uri: str | sqlalchemy.engine.URL, *, writeable: bool = True
    ) -> sqlalchemy.engine.Engine:
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use
                        # TIMESTAMPTZ for the column type.  When we can
                        # tolerate a schema change, we should change that type
                        # and remove this line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

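    # Temporary-table creation is not allowed inside the read-only
    # transactions this class otherwise opens, so `temporary_table` wraps the
    # base-class implementation in a transaction opened with
    # for_temp_tables=True.  A rough usage sketch (hypothetical spec; the
    # exact ddl.FieldSpec arguments may differ):
    #
    #     spec = ddl.TableSpec(fields=[ddl.FieldSpec("id", dtype=sqlalchemy.BigInteger, primaryKey=True)])
    #     with db.temporary_table(spec) as table:
    #         ...  # use `table` in queries; it is dropped on exit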

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: Optional[str] = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

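    # "LOCK TABLE ... IN EXCLUSIVE MODE" blocks concurrent writers (and other
    # EXCLUSIVE locks) while still allowing plain SELECTs; the lock is held
    # until the enclosing transaction ends.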

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

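    # For a spec like ("instrument", TimespanDatabaseRepresentation), the
    # constraint built below corresponds roughly to
    #     EXCLUDE USING gist (instrument WITH =, timespan WITH &&)
    # which is why the btree_gist extension (checked in __init__) is needed:
    # it lets GiST exclusion constraints use equality on scalar columns.
    # (The "instrument"/"timespan" names here are illustrative only.)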

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited.
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will
        # guarantee that we drop the table at the end of the transaction even
        # if the connection lasts longer, and that's good citizenship when
        # connections may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

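    # The UPSERT built in `replace` corresponds roughly to
    #     INSERT INTO <table> (...) VALUES (...)
    #     ON CONFLICT (<primary key>) DO UPDATE SET col = excluded.col, ...
    # with one SET entry for every non-primary-key column.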

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses the special support for UPSERT in SQLAlchemy's PostgreSQL
        # backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using the special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

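    # `ensure` differs from `replace` in that conflicting rows are left
    # untouched (ON CONFLICT DO NOTHING); the returned rowcount is therefore
    # the number of rows actually inserted.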

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: Optional[str] = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

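# The type below maps a Timespan onto a PostgreSQL int8range over the integer
# nanosecond values produced by time_utils.TimeConverter.  Roughly: a
# fully-bounded Timespan becomes a half-open range like '[1234,5678)', an
# unbounded side becomes an open bound (e.g. '(,5678)'), and an empty Timespan
# becomes the 'empty' range.  (The literal values here are illustrative only.)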

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        else:
            converter = time_utils.TimeConverter()
            assert value._nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
            assert value._nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
            lower = None if value._nsec[0] == converter.min_nsec else value._nsec[0]
            upper = None if value._nsec[1] == converter.max_nsec else value._nsec[1]
            return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = converter.min_nsec if value.lower is None else value.lower
        end_nsec = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

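# A rough sketch of how the representation below is typically used (names are
# illustrative; the real call sites live in the registry query code):
#
#     tsRepr = db.getTimespanRepresentation()
#     field, = tsRepr.makeFieldSpecs(nullable=True)   # one int8range column
#     ts = tsRepr.from_columns(table.columns)
#     query = table.select().where(ts.overlaps(tsRepr.fromLiteral(some_timespan)))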

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name of the column.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

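    # For non-nullable fields the server-side default below is
    # "'(,)'::int8range", i.e. a range unbounded on both sides, which plays
    # the same role as a Timespan with no begin/end bounds.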

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (
            ddl.FieldSpec(
                name,
                dtype=_RangeTimespanType,
                nullable=nullable,
                default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
                **kwargs,
            ),
        )

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return (name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        if result is None:
            result = {}
        result[name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return mapping[name]

    @classmethod
    def fromLiteral(cls, timespan: Optional[Timespan]) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        return cls(
            column=sqlalchemy.sql.cast(
                sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType
            ),
            name=cls.NAME,
        )

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return cls(columns[name], name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.func.isempty(self.column)

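    # In the range-vs-range branches below, "<<" and ">>" are SQLAlchemy's
    # spellings of PostgreSQL's "strictly left of" / "strictly right of" range
    # operators.  The scalar branches need the upper_inf/lower_inf and isempty
    # guards because an unbounded or empty range has no finite endpoint to
    # compare against.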

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        else:
            return self.column << other.column

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        else:
            return self.column >> other.column

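    # `overlaps` maps to the range "&&" operator and `contains` to "@>"; the
    # latter also accepts a scalar element on the right-hand side, which is
    # why a bare column (a single time point) is handled by delegating
    # `overlaps` to `contains`.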

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, _RangeTimespanRepresentation):
            return self.contains(other)
        return self.column.overlaps(other.column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.contains(other.column)
        else:
            return self.column.contains(other)

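    # PostgreSQL's lower()/upper() return NULL for an unbounded (or NULL)
    # range, so the coalesce below substitutes a literal 0 to keep the result
    # a concrete integer expression.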

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: Optional[str] = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        if name is None:
            return (self.column,)
        else:
            return (self.column.label(name),)