Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 28%

197 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-31 04:05 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import closing, contextmanager 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31 

32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

33from ..interfaces import Database 

34from ..nameShrinker import NameShrinker 

35 

36 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to talk to the database, e.g. one returned by
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection does not provide
        ``get_dsn_parameters`` (i.e. it is not a psycopg2 connection), or if
        the ``btree_gist`` extension is not enabled in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap in `sqlalchemy.text`, like the extension query below;
                # executing a bare string is deprecated in SQLAlchemy 1.4 and
                # rejected in 2.0.
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # The dialect belongs to the engine, so read it from there rather
        # than from the connection that was just closed.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is accepted but does not affect how the engine is
        # created here.
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Simply forwards all arguments to the constructor.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
    ) -> Iterator[None]:
        # Delegates to the base-class transaction, then configures the
        # session through a raw DBAPI cursor before yielding to the caller.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            with closing(self._session_connection.connection.cursor()) as cursor:
                if not self.isWriteable():
                    cursor.execute("SET TRANSACTION READ ONLY")
                # Make timestamps UTC, because we didn't use TIMESTAMPTZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this line.  This is
                # needed for read-only transactions as well, since timestamps
                # are interpreted on reads too.
                cursor.execute("SET TIME ZONE 0")
            yield

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        # NOTE(review): ``table.key`` is interpolated into the statement
        # without identifier quoting; this presumably relies on table names
        # that do not need quoting -- confirm before using unusual names.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the `NameShrinker` sized from the dialect's maximum
        # identifier length.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same `NameShrinker`.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Plain column names are compared with "=", the timespan column with
        # the range-overlap operator "&&"; the constraint name is built from
        # "excl" plus every participating column name, then shrunk.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Inserts ``rows``, overwriting the non-primary-key columns of any
        # existing row with the same primary key.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._connection() as connection:
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            # Only conflicts on the primary key are ignored.
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            # No conflict target: a conflict on any constraint is ignored.
            query = base_insert.on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount

212 

213 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A PostgreSQL-only, single-column database type for `Timespan`.

    Values are stored as ``INT8RANGE``, so PostgreSQL's built-in range
    operators, as well as the indexes and EXCLUSION constraints built on
    them, should be usable with columns of this type.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) into a psycopg2 numeric range.

        A bound equal to the converter's minimum / maximum nanosecond value
        is written as an open (`None`) bound of the range.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        assert begin_nsec >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        lower = None if begin_nsec == limits.min_nsec else begin_nsec
        upper = None if end_nsec == limits.max_nsec else end_nsec
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        """Convert a psycopg2 numeric range (or `None`) into a `Timespan`.

        An open (`None`) bound of the range is replaced by the converter's
        minimum / maximum nanosecond value.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        begin_nsec = value.lower
        if begin_nsec is None:
            begin_nsec = limits.min_nsec
        end_nsec = value.upper
        if end_nsec is None:
            end_nsec = limits.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        This nested class exists only to work around a bug reported upstream
        as https://github.com/sqlalchemy/sqlalchemy/issues/5476 (fixed on
        main, but not in the releases we currently use).  It is a limited
        copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, with
        ``is_comparison=True`` added to every call.
        """

        def __ne__(self, other: Any) -> Any:
            "Boolean expression. Returns true if two ranges are not equal"
            # Comparison with None is left to the base class.
            if other is not None:
                not_equal = self.expr.op("<>", is_comparison=True)
                return not_equal(other)
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            contains_op = self.expr.op("@>", is_comparison=True)
            return contains_op(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            contained_op = self.expr.op("<@", is_comparison=True)
            return contained_op(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            overlap_op = self.expr.op("&&", is_comparison=True)
            return overlap_op(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            left_op = self.expr.op("<<", is_comparison=True)
            return left_op(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            right_op = self.expr.op(">>", is_comparison=True)
            return right_op(other)

        __rshift__ = strictly_right_of

311 

312 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that stores a timespan in a single
    PostgreSQL-specific field of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name associated with the column; returned by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable fields default to the fully unbounded range.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal, name=cls.NAME)

    @classmethod
    def fromSelectable(
        cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(selectable.columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: use the "strictly left of" operator.
            return self.column << other.column
        # Range-vs-scalar: the range must have a finite upper bound, be
        # non-empty, and have its upper bound at or below the scalar.
        bounded_above = sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column))
        not_empty = sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column))
        ends_before = sqlalchemy.sql.func.upper(self.column) <= other
        return sqlalchemy.sql.and_(bounded_above, not_empty, ends_before)

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range: use the "strictly right of" operator.
            return self.column >> other.column
        # Range-vs-scalar: the range must have a finite lower bound, be
        # non-empty, and have its lower bound strictly above the scalar.
        bounded_below = sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column))
        not_empty = sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column))
        starts_after = sqlalchemy.sql.func.lower(self.column) > other
        return sqlalchemy.sql.and_(bounded_below, not_empty, starts_after)

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; anything else is
        # passed through unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A NULL result from the SQL ``lower`` function is replaced by 0.
        bound = sqlalchemy.sql.func.lower(self.column)
        fallback = sqlalchemy.sql.literal(0)
        return sqlalchemy.sql.functions.coalesce(bound, fallback)

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A NULL result from the SQL ``upper`` function is replaced by 0.
        bound = sqlalchemy.sql.func.upper(self.column)
        fallback = sqlalchemy.sql.literal(0)
        return sqlalchemy.sql.functions.coalesce(bound, fallback)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Exactly one column is yielded, relabeled when a name is given.
        yield self.column if name is None else self.column.label(name)