Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 24%

197 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-17 09:32 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import closing, contextmanager 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31from sqlalchemy import sql 

32 

33from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

34from ...core.named import NamedValueAbstractSet 

35from ..interfaces import Database 

36from ..nameShrinker import NameShrinker 

37 

38 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine created by a previous call to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` package be installed, which we assume indicates
    that a PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`)
    requires the ``btree_gist`` PostgreSQL server extension to be installed
    and enabled on the database being connected to; this is checked at
    connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # `Any` to make mypy ignore the line below, can't use type: ignore
            dbapi: Any = connection.connection
            try:
                # get_dsn_parameters is psycopg2-specific; its absence means a
                # different DBAPI driver is in use, which we do not support.
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # No schema given; fall back to the server-side default.
                query = sql.select(sql.func.current_schema())
                namespace = connection.execute(query).scalar()
            # Fail fast at connection time if the btree_gist extension is not
            # enabled for this database (needed for exclusion constraints).
            query_text = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query_text)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Shrinker keeps generated entity names within the backend's
        # identifier-length limit.
        self._shrinker = NameShrinker(self.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        # NOTE: read-only enforcement happens per-transaction (see
        # _transaction), not at engine creation, so `writeable` is unused
        # here.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: Optional[str] = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction with for_temp_tables=True so creation is
        # permitted even when the database is otherwise read-only.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        # e.g. "PostgreSQL@mydb:myschema".
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # NOTE(review): `table` and `metadata` are not used here; the
        # constraint name is built only from "excl" plus the column names —
        # confirm that omitting the table name cannot cause collisions.
        args: list[tuple[sqlalchemy.schema.Column, str]] = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain column: participate in the constraint via equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan column: participate via range overlap (&&).
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will
        # guarantee that we drop the table at the end of the transaction even
        # if the connection lasts longer, and that's good citizenship when
        # connections may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING: existing rows
        # are left untouched rather than updated.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: Optional[str] = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        # The base-class implementation is sufficient for PostgreSQL.
        return super().constant_rows(fields, *rows, name=name)

281 

282 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Backed by PostgreSQL's ``INT8RANGE`` type, so it can take advantage of
    the built-in range operators and the indexing and EXCLUSION table
    constraints built off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` into a psycopg2 range for the database."""
        if value is None:
            # NULL passes through unchanged.
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        nsec_begin, nsec_end = value._nsec
        assert nsec_begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec_end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # The converter's sentinel min/max values map to unbounded endpoints.
        low = nsec_begin if nsec_begin != converter.min_nsec else None
        high = nsec_end if nsec_end != converter.max_nsec else None
        return psycopg2.extras.NumericRange(lower=low, upper=high)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        """Convert a psycopg2 range from the database into a `Timespan`."""
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        # Unbounded endpoints map back to the converter's sentinel values.
        nsec_begin = converter.min_nsec if value.lower is None else value.lower
        nsec_end = converter.max_nsec if value.upper is None else value.upper
        return Timespan(begin=None, end=None, _nsec=(nsec_begin, nsec_end))

324 

325 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        # Non-nullable columns default to the unbounded range "(,)".
        default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        out = {} if result is None else result
        out[key] = extent
        return out

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Optional[Timespan]) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        # Bind the Python Timespan as a typed literal, with an explicit cast
        # to the same range type.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=sqlalchemy.sql.cast(literal, type_=_RangeTimespanType), name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly left of" operator.
            return self.column << other.column
        # Range vs. scalar: must be bounded above, nonempty, and end at or
        # before the scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL's "strictly right of" operator.
            return self.column >> other.column
        # Range vs. scalar: must be bounded below, nonempty, and start after
        # the scalar.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A bare column element is treated as a point to be contained.
        return self.contains(other)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        target = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(target)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Coalesce so an unbounded (NULL) lower endpoint reads as 0.
        bound = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Coalesce so an unbounded (NULL) upper endpoint reads as 0.
        bound = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(bound, sqlalchemy.sql.literal(0))

    def flatten(self, name: Optional[str] = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        column = self.column if name is None else self.column.label(name)
        return (column,)