Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 24%

195 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-26 15:13 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import closing, contextmanager 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31 

32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

33from ...core.named import NamedValueAbstractSet 

34from ..interfaces import Database 

35from ..nameShrinker import NameShrinker 

36 

37 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed an enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # Only the psycopg2 DBAPI provides get_dsn_parameters; use it both
            # to reject other drivers and to obtain the database name for
            # __str__.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap textual SQL in sqlalchemy.text, consistent with the
                # btree_gist query below (and required by SQLAlchemy 2.0).
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # Docstring inherited.  ``writeable`` is enforced per-transaction (see
        # _transaction), not at engine creation.
        return sqlalchemy.engine.create_engine(uri, pool_size=1)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def _transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
        for_temp_tables: bool = False,
    ) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
        with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
            is_new,
            connection,
        ):
            if is_new:
                # pgbouncer with transaction-level pooling (which we aim to
                # support) says that SET cannot be used, except for a list of
                # "Startup parameters" that includes "timezone" (see
                # https://www.pgbouncer.org/features.html#fnref:0).  But I
                # don't see "timezone" in PostgreSQL's list of parameters
                # passed when creating a new connection
                # (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
                # Given that the pgbouncer docs say, "PgBouncer detects their
                # changes and so it can guarantee they remain consistent for
                # the client", I assume we can use "SET TIME ZONE" and
                # pgbouncer will take care of clients that share connections
                # being set consistently.  And if that assumption is wrong, we
                # should still probably be okay, since all clients should be
                # Butler clients, and they'll all be setting the same thing.
                #
                # The "SET TRANSACTION READ ONLY" should also be safe, because
                # it only ever acts on the current transaction; I think it's
                # not included in pgbouncer's declaration that SET is
                # incompatible with transaction-level pooling because
                # PostgreSQL actually considers SET TRANSACTION to be a
                # fundamentally different statement from SET (they have their
                # own distinct doc pages, at least).
                if not (self.isWriteable() or for_temp_tables):
                    # PostgreSQL permits writing to temporary tables inside
                    # read-only transactions, but it doesn't permit creating
                    # them.
                    with closing(connection.connection.cursor()) as cursor:
                        cursor.execute("SET TRANSACTION READ ONLY")
                        cursor.execute("SET TIME ZONE 0")
                else:
                    with closing(connection.connection.cursor()) as cursor:
                        # Make timestamps UTC, because we didn't use TIMESTAMPZ
                        # for the column type.  When we can tolerate a schema
                        # change, we should change that type and remove this
                        # line.
                        cursor.execute("SET TIME ZONE 0")
            yield is_new, connection

    @contextmanager
    def temporary_table(
        self, spec: ddl.TableSpec, name: Optional[str] = None
    ) -> Iterator[sqlalchemy.schema.Table]:
        # Docstring inherited.
        # Wrap in a transaction flagged for temp tables so a read-only
        # connection is still allowed to CREATE TEMPORARY TABLE.
        with self.transaction(for_temp_tables=True):
            with super().temporary_table(spec, name) as table:
                yield table

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns participate with equality.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan (range) columns participate with overlap (&&).
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    def _make_temporary_table(
        self,
        connection: sqlalchemy.engine.Connection,
        spec: ddl.TableSpec,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> sqlalchemy.schema.Table:
        # Docstring inherited
        # Adding ON COMMIT DROP here is really quite defensive: we already
        # manually drop the table at the end of the temporary_table context
        # manager, and that will usually happen first.  But this will guarantee
        # that we drop the table at the end of the transaction even if the
        # connection lasts longer, and that's good citizenship when connections
        # may be multiplexed by e.g. pgbouncer.
        return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._transaction() as (_, connection):
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but with DO NOTHING on conflict.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._transaction() as (_, connection):
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: Optional[str] = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

278 

279 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        # Convert a Timespan into the psycopg2 range type, mapping the
        # converter's min/max sentinel values to unbounded endpoints.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        nsec = value._nsec
        assert nsec[0] >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec[1] <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        return psycopg2.extras.NumericRange(
            lower=(None if nsec[0] == converter.min_nsec else nsec[0]),
            upper=(None if nsec[1] == converter.max_nsec else nsec[1]),
        )

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        # Inverse of process_bind_param: unbounded endpoints become the
        # converter's min/max sentinel values.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

321 

322 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        # A non-nullable column defaults to the unbounded range.
        spec = ddl.FieldSpec(
            cls.NAME if name is None else name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=(None if nullable else sqlalchemy.sql.text("'(,)'::int8range")),
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        if result is None:
            result = {}
        result[cls.NAME if name is None else name] = extent
        return result

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Optional[Timespan]) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            column = sqlalchemy.sql.null()
        else:
            column = sqlalchemy.sql.cast(
                sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType
            )
        return cls(column=column, name=cls.NAME)

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return cls(columns[name], name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Scalar comparison: the whole range must end at or before the
            # scalar, with infinite/empty ranges excluded explicitly.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.upper(self.column) <= other,
            )
        # Range comparison: PostgreSQL "strictly left of" operator.
        return self.column << other.column

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, sqlalchemy.sql.ColumnElement):
            # Scalar comparison: the whole range must begin after the scalar,
            # with infinite/empty ranges excluded explicitly.
            return sqlalchemy.sql.and_(
                sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
                sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
                sqlalchemy.sql.func.lower(self.column) > other,
            )
        # Range comparison: PostgreSQL "strictly right of" operator.
        return self.column >> other.column

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar "overlaps" a range iff the range contains it.
        return self.contains(other)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # NULL lower bound (unbounded) is reported as 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.lower(self.column), sqlalchemy.sql.literal(0)
        )

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # NULL upper bound (unbounded) is reported as 0.
        return sqlalchemy.sql.functions.coalesce(
            sqlalchemy.sql.func.upper(self.column), sqlalchemy.sql.literal(0)
        )

    def flatten(self, name: Optional[str] = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        return (self.column,) if name is None else (self.column.label(name),)