Coverage for python/lsst/daf/butler/registry/databases/postgresql.py: 24%

187 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-12 02:19 -0800

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import closing, contextmanager 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy 

30import sqlalchemy.dialects.postgresql 

31 

32from ...core import Timespan, TimespanDatabaseRepresentation, ddl, time_utils 

33from ...core.named import NamedValueAbstractSet 

34from ..interfaces import Database 

35from ..nameShrinker import NameShrinker 

36 

37 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine to use for this connection.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed an enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(
        self,
        *,
        engine: sqlalchemy.engine.Engine,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # Only psycopg2 exposes get_dsn_parameters on the DBAPI
            # connection; use that to verify the driver early, with a clear
            # error instead of an obscure failure later.
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to the server's default schema for this
                # connection.  Wrapped in sqlalchemy.text() for consistency
                # with the query below and SQLAlchemy 2.0-style execution.
                namespace = connection.execute(sqlalchemy.text("SELECT current_schema();")).scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(sqlalchemy.text(query)).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    "initialized."
                )
            self.namespace = namespace
            self.dbname = dsn.get("dbname")
            self._writeable = writeable
            self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # Docstring inherited.
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(
        cls,
        engine: sqlalchemy.engine.Engine,
        *,
        origin: int,
        namespace: Optional[str] = None,
        writeable: bool = True,
    ) -> Database:
        # Docstring inherited.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(
        self,
        *,
        interrupting: bool = False,
        savepoint: bool = False,
        lock: Iterable[sqlalchemy.schema.Table] = (),
    ) -> Iterator[None]:
        # Docstring inherited.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
            if not self.isWriteable():
                # Enforce read-only at the server level as well, so that any
                # accidental write attempt fails loudly.
                with closing(self._session_connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            else:
                with closing(self._session_connection.connection.cursor()) as cursor:
                    # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                    # the column type.  When we can tolerate a schema change,
                    # we should change that type and remove this line.
                    cursor.execute("SET TIME ZONE 0")
            yield

    def _lockTables(
        self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
    ) -> None:
        # Docstring inherited.
        for table in tables:
            connection.execute(sqlalchemy.text(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE"))

    def isWriteable(self) -> bool:
        # Docstring inherited.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Docstring inherited.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Docstring inherited.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(
        self,
        table: str,
        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
        metadata: sqlalchemy.MetaData,
    ) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain column: must be equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan column: rows conflict when the ranges overlap.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {
            column.name: getattr(excluded, column.name)
            for column in table.columns
            if column.name not in table.primary_key
        }
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        with self._connection() as connection:
            connection.execute(query, rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT.
        base_insert = sqlalchemy.dialects.postgresql.dml.insert(table)
        if primary_key_only:
            query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
        else:
            query = base_insert.on_conflict_do_nothing()
        with self._connection() as connection:
            return connection.execute(query, rows).rowcount

    def constant_rows(
        self,
        fields: NamedValueAbstractSet[ddl.FieldSpec],
        *rows: dict,
        name: Optional[str] = None,
    ) -> sqlalchemy.sql.FromClause:
        # Docstring inherited.
        return super().constant_rows(fields, *rows, name=name)

223 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(
        self, value: Optional[Timespan], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` to the psycopg2 range object stored in the
        database, mapping the sentinel min/max nanosecond values to
        unbounded range endpoints.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        nsec_begin, nsec_end = value._nsec
        assert nsec_begin >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert nsec_end <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # A bound equal to the converter's sentinel means "unbounded", which
        # psycopg2 spells as None.
        range_lower = nsec_begin if nsec_begin != converter.min_nsec else None
        range_upper = nsec_end if nsec_end != converter.max_nsec else None
        return psycopg2.extras.NumericRange(lower=range_lower, upper=range_upper)

    def process_result_value(
        self, value: Optional[psycopg2.extras.NumericRange], dialect: sqlalchemy.engine.Dialect
    ) -> Optional[Timespan]:
        """Convert a psycopg2 range object read from the database back into
        a `Timespan`, mapping unbounded endpoints to the converter's
        sentinel nanosecond values.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        begin_nsec = value.lower if value.lower is not None else converter.min_nsec
        end_nsec = value.upper if value.upper is not None else converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

265 

266 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    __slots__ = ("column", "_name")

    @classmethod
    def makeFieldSpecs(
        cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
    ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        # Non-nullable columns default to the fully unbounded range rather
        # than NULL.
        server_default = None if nullable else sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            name,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=server_default,
            **kwargs,
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME,) if name is None else (name,)

    @classmethod
    def update(
        cls, extent: Optional[Timespan], name: Optional[str] = None, result: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        target = {} if result is None else result
        target[name] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Optional[Timespan]) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if timespan is None:
            return cls(column=sqlalchemy.sql.null(), name=cls.NAME)
        # Cast the bound literal so the database sees an int8range rather
        # than an untyped parameter.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(
            column=sqlalchemy.sql.cast(literal, type_=_RangeTimespanType),
            name=cls.NAME,
        )

    @classmethod
    def from_columns(
        cls, columns: sqlalchemy.sql.ColumnCollection, name: Optional[str] = None
    ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        if name is None:
            name = cls.NAME
        return cls(columns[name], name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL strictly-left-of operator.
            return self.column << other.column
        # Range vs. scalar: entire (non-empty, bounded-above) range must end
        # at or before the scalar.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.upper_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.upper(self.column) <= other,
        )

    def __gt__(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: PostgreSQL strictly-right-of operator.
            return self.column >> other.column
        # Range vs. scalar: entire (non-empty, bounded-below) range must
        # begin after the scalar.
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(sqlalchemy.sql.func.lower_inf(self.column)),
            sqlalchemy.sql.not_(sqlalchemy.sql.func.isempty(self.column)),
            sqlalchemy.sql.func.lower(self.column) > other,
        )

    def overlaps(
        self, other: _RangeTimespanRepresentation | sqlalchemy.sql.ColumnElement
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if isinstance(other, _RangeTimespanRepresentation):
            return self.column.overlaps(other.column)
        # A scalar "overlaps" a range iff the range contains it.
        return self.contains(other)

    def contains(
        self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def lower(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unbounded lower ends read back as NULL; coalesce to 0.
        lower_bound = sqlalchemy.sql.func.lower(self.column)
        return sqlalchemy.sql.functions.coalesce(lower_bound, sqlalchemy.sql.literal(0))

    def upper(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unbounded upper ends read back as NULL; coalesce to 0.
        upper_bound = sqlalchemy.sql.func.upper(self.column)
        return sqlalchemy.sql.functions.coalesce(upper_bound, sqlalchemy.sql.literal(0))

    def flatten(self, name: Optional[str] = None) -> tuple[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        if name is None:
            return (self.column,)
        return (self.column.label(name),)