Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import contextmanager, closing 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy.dialects.postgresql 

30 

31from ..interfaces import Database 

32from ..nameShrinker import NameShrinker 

33from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation 

34 

35 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to talk to the database, e.g. one returned by
        `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the schema reported by ``SELECT current_schema()`` is used (which may
        be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the DBAPI connection does not provide
        ``get_dsn_parameters`` (i.e. it is not a psycopg2 connection), or if
        the ``btree_gist`` extension is not present in ``pg_extension``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, engine: sqlalchemy.engine.Engine, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        # A short-lived connection is used only to inspect the server; it is
        # released as soon as the checks below are done.
        with engine.connect() as connection:
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                namespace = connection.execute("SELECT current_schema();").scalar()
            query = "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
            if not connection.execute(query).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # Read the dialect from the engine itself rather than going through
        # the connection object, whose context has already been exited.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is accepted but not used when creating the engine;
        # read-only behaviour is applied per transaction (see `transaction`).
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(cls, engine: sqlalchemy.engine.Engine, *, origin: int,
                   namespace: Optional[str] = None, writeable: bool = True) -> Database:
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if self.isWriteable():
                # Make timestamps UTC, because we didn't use TIMESTAMPTZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this line.
                statement = "SET TIME ZONE 0"
            else:
                statement = "SET TRANSACTION READ ONLY"
            # Exactly one session-level statement is issued on the raw DBAPI
            # cursor before control is handed to the caller.
            with closing(self._connection.connection.cursor()) as cursor:
                cursor.execute(statement)
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        for table in tables:
            self._connection.execute(f"LOCK TABLE {table.key} IN EXCLUSIVE MODE")

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # ``table`` and ``metadata`` are part of the inherited signature but
        # are not needed to build the PostgreSQL constraint.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must be equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                assert item is self.getTimespanRepresentation()
                # Timespan columns conflict when the ranges overlap.
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        self._connection.execute(query, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        return self._connection.execute(query, *rows).rowcount

182 

183 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A PostgreSQL-only column type that stores a `Timespan` in one
    ``INT8RANGE`` column.

    Using a native range column is meant to let queries use PostgreSQL's
    built-in range operators, as well as the indexes and EXCLUSION table
    constraints that are built on top of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    cache_ok = True

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) into the psycopg2 range object
        sent to the database.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        assert begin_nsec >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        # A bound sitting exactly on the converter's limit is sent as an
        # unbounded (None) side of the range.
        if begin_nsec == limits.min_nsec:
            begin_nsec = None
        if end_nsec == limits.max_nsec:
            end_nsec = None
        return psycopg2.extras.NumericRange(lower=begin_nsec, upper=end_nsec)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a psycopg2 range object (or `None`) read from the database
        back into a `Timespan`.
        """
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        # Unbounded sides of the range map back onto the converter's limits.
        lower_nsec = value.lower if value.lower is not None else limits.min_nsec
        upper_nsec = value.upper if value.upper is not None else limits.max_nsec
        return Timespan(begin=None, end=None, _nsec=(lower_nsec, upper_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        This nested class exists only to work around a bug reported upstream
        as https://github.com/sqlalchemy/sqlalchemy/issues/5476 (fixed on
        master, but not in the releases we currently use).  It reproduces a
        subset of the operators from
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, passing
        ``is_comparison=True`` in every call.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression that is true when the two ranges differ."""
            if other is not None:
                not_equal = self.expr.op("<>", is_comparison=True)
                return not_equal(other)
            # Comparison against None is left to the base comparator.
            return super().__ne__(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression that is true when the column contains the
            right hand operand (an element or a range).
            """
            contains_op = self.expr.op("@>", is_comparison=True)
            return contains_op(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression that is true when the column lies within
            the right hand operand.
            """
            contained_op = self.expr.op("<@", is_comparison=True)
            return contained_op(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression that is true when the column and the right
            hand operand have points in common.
            """
            overlap_op = self.expr.op("&&", is_comparison=True)
            return overlap_op(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression that is true when the column is strictly
            left of the right hand operand.
            """
            left_op = self.expr.op("<<", is_comparison=True)
            return left_op(other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression that is true when the column is strictly
            right of the right hand operand.
            """
            right_op = self.expr.op(">>", is_comparison=True)
            return right_op(other)

        __rshift__ = strictly_right_of

281 

282 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that keeps the whole timespan in a
    single PostgreSQL-specific `_RangeTimespanType` field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self._name = name
        self.column = column

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable fields default to the fully unbounded range.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            field_name, dtype=_RangeTimespanType, nullable=nullable,
            default=default,
            **kwargs
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        # A caller-supplied mapping is modified in place and returned.
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literal = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literal, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column_name = cls.NAME if name is None else name
        return cls(selectable.columns[column_name], column_name)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: use the "strictly left of" operator.
            return self.column << other.column
        # Range vs. scalar column: the range must be non-empty, bounded
        # above, and its upper bound must not exceed the scalar.
        func = sqlalchemy.sql.func
        bounded_above = sqlalchemy.sql.not_(func.upper_inf(self.column))
        not_empty = sqlalchemy.sql.not_(func.isempty(self.column))
        ends_before = func.upper(self.column) <= other
        return sqlalchemy.sql.and_(bounded_above, not_empty, ends_before)

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: use the "strictly right of" operator.
            return self.column >> other.column
        # Range vs. scalar column: the range must be non-empty, bounded
        # below, and its lower bound must be greater than the scalar.
        func = sqlalchemy.sql.func
        bounded_below = sqlalchemy.sql.not_(func.lower_inf(self.column))
        not_empty = sqlalchemy.sql.not_(func.isempty(self.column))
        starts_after = func.lower(self.column) > other
        return sqlalchemy.sql.and_(bounded_below, not_empty, starts_after)

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unwrap another representation to its column; anything else is
        # passed through unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        yield self.column if name is None else self.column.label(name)