Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 27%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

190 statements  

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import contextmanager, closing 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31 

32from ..interfaces import Database 

33from ..nameShrinker import NameShrinker 

34from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation 

35 

36 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to connect to the database, e.g. one returned by
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the DBAPI connection does not provide
        ``get_dsn_parameters`` (only the psycopg2 driver is supported), or if
        the ``btree_gist`` extension is not enabled in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, engine: sqlalchemy.engine.Engine, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Wrap raw SQL in `text`, as is done for the extension query
                # below; passing a bare string to `execute` is deprecated in
                # SQLAlchemy 1.4 and rejected in 2.0.
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # The dialect is a property of the engine, so there is no need to go
        # through the (already closed) connection to reach it.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        """Create a SQLAlchemy engine for the given URI.

        Parameters
        ----------
        uri : `str`
            Connection string passed to `sqlalchemy.engine.create_engine`.
        writeable : `bool`, optional
            Accepted for interface compatibility; not used by this
            implementation.

        Returns
        -------
        engine : `sqlalchemy.engine.Engine`
            The newly created engine.
        """
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(cls, engine: sqlalchemy.engine.Engine, *, origin: int,
                   namespace: Optional[str] = None, writeable: bool = True) -> Database:
        """Construct an instance from an existing engine.

        All arguments are forwarded unchanged to the class constructor.
        """
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        """Run the base-class transaction with PostgreSQL session settings.

        After the base-class transaction context is entered, a read-only
        database issues ``SET TRANSACTION READ ONLY`` and a writeable one
        issues ``SET TIME ZONE 0`` before control is yielded to the caller.
        All arguments are forwarded to the base-class implementation.
        """
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            if self.isWriteable():
                # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this line.
                # NOTE(review): the time zone is only set on the writeable
                # path; confirm read-only sessions do not need it.
                statement = "SET TIME ZONE 0"
            else:
                statement = "SET TRANSACTION READ ONLY"
            with closing(self._session_connection.connection.cursor()) as cursor:
                cursor.execute(statement)
            yield

    def _lockTables(self, connection: sqlalchemy.engine.Connection,
                    tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        """Return the ``writeable`` flag this instance was constructed with."""
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        """Return ``original`` as shortened by this database's `NameShrinker`."""
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        """Return the expansion of ``shrunk`` from this database's `NameShrinker`."""
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must be equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan columns conflict when the ranges overlap.
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        """Insert rows, updating the non-key columns of any existing row
        that has the same primary key.

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Table to insert into; must be writeable.
        *rows : `dict`
            Rows to insert.  Nothing is executed if no rows are given.
        """
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._connection() as connection:
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount

187 

188 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Values are stored in an ``INT8RANGE`` column.  This type should be able
    to take advantage of PostgreSQL's built-in range operators, and the
    indexing and EXCLUSION table constraints built off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) into a psycopg2 range.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        assert begin_nsec >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        # A bound sitting exactly on the converter's limit is sent as an
        # unbounded (`None`) end of the range.
        return psycopg2.extras.NumericRange(
            lower=(begin_nsec if begin_nsec != limits.min_nsec else None),
            upper=(end_nsec if end_nsec != limits.max_nsec else None),
        )

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a psycopg2 range (or `None`) back into a `Timespan`."""
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        # Unbounded ends of the range map back onto the converter's limits.
        bounds = (
            limits.min_nsec if value.lower is None else value.lower,
            limits.max_nsec if value.upper is None else value.upper,
        )
        return Timespan(begin=None, end=None, _nsec=bounds)

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        master, but not in the releases we currently use).  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def _comparison(self, symbol: str, other: Any) -> Any:
            # Shared implementation: apply the SQL operator ``symbol`` to
            # this expression and ``other``, flagged as a comparison.
            return self.expr.op(symbol, is_comparison=True)(other)

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal"""
            if other is not None:
                return self._comparison("<>", other)
            # Comparison against `None` is delegated to the base class.
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self._comparison("@>", other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self._comparison("<@", other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self._comparison("&&", other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self._comparison("<<", other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self._comparison(">>", other)

        __rshift__ = strictly_right_of

286 

287 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name of the column; returned by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable fields default to the unbounded range literal.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name, dtype=_RangeTimespanType, nullable=nullable,
            default=default,
            **kwargs
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        # A caller-supplied mapping is updated in place and returned.
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME if name is None else name]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(selectable.columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison: delegate to the "strictly left of"
            # operator of the column type.
            return self.column << other.column
        # Range-vs-scalar comparison: the range must have a finite upper
        # bound, be non-empty, and end at or before ``other``.
        bounded_above = sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column))
        non_empty = sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column))
        ends_before = sqlalchemy.sql.func.upper(self.column) <= other
        return sqlalchemy.sql.and_(bounded_above, non_empty, ends_before)

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison: delegate to the "strictly right of"
            # operator of the column type.
            return self.column >> other.column
        # Range-vs-scalar comparison: the range must have a finite lower
        # bound, be non-empty, and begin after ``other``.
        bounded_below = sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column))
        non_empty = sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column))
        begins_after = sqlalchemy.sql.func.lower(self.column) > other
        return sqlalchemy.sql.and_(bounded_below, non_empty, begins_after)

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        # Unwrap another representation to its column; anything else is
        # passed through as the operand unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Yields exactly one column, relabelled when a name is given.
        yield self.column if name is None else self.column.label(name)