Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 30%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

194 statements  

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import closing, contextmanager 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31 

32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

33from ..interfaces import Database 

34from ..nameShrinker import NameShrinker 

35 

36 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine connected to the database, e.g. as returned by `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the DBAPI connection does not provide
        ``get_dsn_parameters`` (i.e. it is not a psycopg2 connection), or if
        the ``btree_gist`` extension is not enabled in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        # Use a short-lived connection only to validate the driver, resolve
        # the default schema and check for the required server extension.
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap in text(): executing a bare string is deprecated in
                # SQLAlchemy 1.4 and unsupported in 2.0, and every other
                # textual statement in this class already uses text().
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Read the dialect from the engine itself rather than going through
        # the temporary connection opened above.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is not used when creating the engine; read-only
        # access is enforced per transaction in `transaction`.
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
    ) -> Iterator[None]:
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            if self.isWriteable():
                # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this line.
                statement = "SET TIME ZONE 0"
            else:
                statement = "SET TRANSACTION READ ONLY"
            # The statement is issued on the raw DBAPI cursor, which is
            # closed again as soon as it has been executed.
            with closing(self._session_connection.connection.cursor()) as cursor:
                cursor.execute(statement)
            yield

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        # The constraint name is "excl" followed by every column involved;
        # it is shrunk below to fit the identifier length limit.
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must be equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan columns conflict when the ranges overlap.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._connection() as connection:
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount

210 

211 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A `Timespan` column type stored as a single PostgreSQL ``INT8RANGE``.

    Storing the timespan as a native range lets queries use PostgreSQL's
    built-in range operators, as well as the indexes and EXCLUSION table
    constraints that are built on top of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        # Convert a `Timespan` (or `None`) into the psycopg2 range object
        # that is sent to the database.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        assert begin_nsec >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        # A bound sitting exactly at the converter limit is passed as `None`
        # instead of as a number.
        lower = begin_nsec if begin_nsec != limits.min_nsec else None
        upper = end_nsec if end_nsec != limits.max_nsec else None
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        # Convert a psycopg2 range read from the database (or `None`) back
        # into a `Timespan`.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        # A missing bound maps back onto the converter limit on that side.
        if value.lower is None:
            begin_nsec = limits.min_nsec
        else:
            begin_nsec = value.lower
        if value.upper is None:
            end_nsec = limits.max_nsec
        else:
            end_nsec = value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        This nested class exists only to work around an upstream bug,
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        main, but not in the releases we currently use).  It is a limited
        copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators`` in which
        every operator is created with ``is_comparison=True``.
        """

        def __ne__(self, other: Any) -> Any:
            "Boolean expression. Returns true if two ranges are not equal"
            # Comparison against `None` is left to the base class.
            if other is not None:
                not_equal = self.expr.op("<>", is_comparison=True)
                return not_equal(other)
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand
            (an element or a range) is contained within the column.
            """
            contains_op = self.expr.op("@>", is_comparison=True)
            return contains_op(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand
            contains the column.
            """
            contained_op = self.expr.op("<@", is_comparison=True)
            return contained_op(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column and the right
            hand operand have points in common.
            """
            overlap_op = self.expr.op("&&", is_comparison=True)
            return overlap_op(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column lies strictly
            to the left of the right hand operand.
            """
            left_op = self.expr.op("<<", is_comparison=True)
            return left_op(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column lies strictly
            to the right of the right hand operand.
            """
            right_op = self.expr.op(">>", is_comparison=True)
            return right_op(other)

        __rshift__ = strictly_right_of

309 

310 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that keeps the whole timespan in
    one PostgreSQL-specific field of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable fields default to the range that is unbounded on both
        # sides; nullable fields get no default at all.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal, name=cls.NAME)

    @classmethod
    def fromSelectable(
        cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(selectable.columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range against range: use the "strictly left of" operator.
            return self.column << other.column
        # Range against a scalar expression: the range must have a finite
        # upper bound, be non-empty, and end at or before the scalar.
        func = sqlalchemy.sql.func
        has_upper = sqlalchemy.sql.not_(func.upper_inf(self.column))
        not_empty = sqlalchemy.sql.not_(func.isempty(self.column))
        ends_before = func.upper(self.column) <= other
        return sqlalchemy.sql.and_(has_upper, not_empty, ends_before)

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range against range: use the "strictly right of" operator.
            return self.column >> other.column
        # Range against a scalar expression: the range must have a finite
        # lower bound, be non-empty, and begin after the scalar.
        func = sqlalchemy.sql.func
        has_lower = sqlalchemy.sql.not_(func.lower_inf(self.column))
        not_empty = sqlalchemy.sql.not_(func.isempty(self.column))
        begins_after = func.lower(self.column) > other
        return sqlalchemy.sql.and_(has_lower, not_empty, begins_after)

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; anything else is
        # passed through unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        bound = sqlalchemy.sql.func.lower(self.column)
        fallback = sqlalchemy.sql.literal(0)
        return sqlalchemy.sql.functions.coalesce(bound, fallback)

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        bound = sqlalchemy.sql.func.upper(self.column)
        fallback = sqlalchemy.sql.literal(0)
        return sqlalchemy.sql.functions.coalesce(bound, fallback)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        yield self.column if name is None else self.column.label(name)