Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import contextmanager, closing 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy.dialects.postgresql 

30 

31from ..interfaces import Database 

32from ..nameShrinker import NameShrinker 

33from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation 

34 

35 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to open connections to the database, e.g. one created
        by a previous call to `makeEngine`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the DBAPI connection does not look like a psycopg2
        connection, or if the ``btree_gist`` extension is not present in the
        database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, engine: sqlalchemy.engine.Engine, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, engine=engine, namespace=namespace)
        with engine.connect() as connection:
            # ``get_dsn_parameters`` is looked up on the DBAPI-level
            # connection; failure to find it is treated as "not psycopg2".
            dbapi = connection.connection
            try:
                dsn = dbapi.get_dsn_parameters()
            except (AttributeError, KeyError) as err:
                raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
            if namespace is None:
                # Fall back to whatever schema the server reports as current.
                namespace = connection.execute(
                    sqlalchemy.sql.text("SELECT current_schema();")
                ).scalar()
            # Raw SQL is wrapped in text() so that it is an executable
            # construct rather than a bare string.
            query = sqlalchemy.sql.text("SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';")
            if not connection.execute(query).scalar():
                raise RuntimeError(
                    "The Butler PostgreSQL backend requires the btree_gist extension. "
                    "As extensions are enabled per-database, this may require an administrator to run "
                    "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is "
                    " initialized."
                )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        # The dialect belongs to the engine, so there is no need to go through
        # the (by now closed) connection to reach it.
        self._shrinker = NameShrinker(engine.dialect.max_identifier_length)

    @classmethod
    def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
        # ``writeable`` is not used when creating the engine; read-only
        # behaviour is applied per transaction in `transaction`.
        return sqlalchemy.engine.create_engine(uri)

    @classmethod
    def fromEngine(cls, engine: sqlalchemy.engine.Engine, *, origin: int,
                   namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Alternate constructor that simply forwards to ``__init__``.
        return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        # Wraps the base-class transaction and issues one session-level
        # statement on the raw DBAPI cursor before handing control back.
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if not self.isWriteable():
                statement = "SET TRANSACTION READ ONLY"
            else:
                # Make timestamps UTC, because we didn't use TIMESTAMPZ for
                # the column type.  When we can tolerate a schema change,
                # we should change that type and remove this line.
                statement = "SET TIME ZONE 0"
            with closing(self._connection.connection.cursor()) as cursor:
                cursor.execute(statement)
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        for table in tables:
            # Let the dialect render (and, where required, quote) the
            # schema-qualified name so that it matches the identifier
            # SQLAlchemy emitted when the table was created.
            preparer = self._connection.engine.dialect.identifier_preparer
            self._connection.execute(
                sqlalchemy.sql.text(f"LOCK TABLE {preparer.format_table(table)} IN EXCLUSIVE MODE")
            )

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the NameShrinker sized from the dialect's maximum
        # identifier length.
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same NameShrinker.
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # ``table`` and ``metadata`` are not needed by this implementation.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns must compare equal for two rows to conflict.
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan columns conflict when their ranges overlap.
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(TimespanDatabaseRepresentation.NAME), "&&"))
                names.append(TimespanDatabaseRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Inserts ``rows``, overwriting the non-key columns of any existing
        # row that has the same primary key.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        self._connection.execute(query, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        return self._connection.execute(query, *rows).rowcount

182 

183 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """Column type that stores a `Timespan` in one PostgreSQL range value.

    The underlying storage is ``INT8RANGE``, which is intended to let
    PostgreSQL's native range operators, range indexes, and EXCLUSION table
    constraints be applied to timespans.  It is usable only with PostgreSQL.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        # Python -> database.  `None` passes through unchanged, an empty
        # timespan becomes an empty range, and a bound equal to the
        # converter's limit is sent as an unbounded (`None`) range end.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        limits = time_utils.TimeConverter()
        assert value._nsec[0] >= limits.min_nsec, "Guaranteed by Timespan.__init__."
        assert value._nsec[1] <= limits.max_nsec, "Guaranteed by Timespan.__init__."
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        return psycopg2.extras.NumericRange(
            lower=(begin_nsec if begin_nsec != limits.min_nsec else None),
            upper=(end_nsec if end_nsec != limits.max_nsec else None),
        )

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        # Database -> Python; the mirror image of `process_bind_param`.
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        limits = time_utils.TimeConverter()
        # A missing range bound maps back onto the converter's limit.
        if value.lower is None:
            begin_nsec = limits.min_nsec
        else:
            begin_nsec = value.lower
        if value.upper is None:
            end_nsec = limits.max_nsec
        else:
            end_nsec = value.upper
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Range comparison operators for columns of this type.

        Notes
        -----
        This nested class works around the bug reported upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (fixed on
        master, but not in the releases we currently use).  It is a limited
        copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, with
        ``is_comparison=True`` added to every call.
        """

        def __ne__(self, other: Any) -> Any:
            "Boolean expression. Returns true if two ranges are not equal"
            # Comparison against `None` is left to the parent comparator.
            if other is None:
                return super().__ne__(other)
            not_equal = self.expr.op("<>", is_comparison=True)
            return not_equal(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression.  Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            # Extra keyword arguments are accepted but not used.
            contains_op = self.expr.op("@>", is_comparison=True)
            return contains_op(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is contained
            within the right hand operand.
            """
            contained_op = self.expr.op("<@", is_comparison=True)
            return contained_op(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            overlap_op = self.expr.op("&&", is_comparison=True)
            return overlap_op(other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is strictly
            left of the right hand operand.
            """
            left_op = self.expr.op("<<", is_comparison=True)
            return left_op(other)

        # ``column << other`` is spelled the same as `strictly_left_of`.
        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression.  Returns true if the column is strictly
            right of the right hand operand.
            """
            right_op = self.expr.op(">>", is_comparison=True)
            return right_op(other)

        # ``column >> other`` is spelled the same as `strictly_right_of`.
        __rshift__ = strictly_right_of

279 

280 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """A `TimespanDatabaseRepresentation` that keeps the whole timespan in a
    single PostgreSQL-specific `_RangeTimespanType` field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self._name = name
        self.column = column

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        field_name = cls.NAME if name is None else name
        # Non-nullable fields default to the range with no lower or upper
        # bound; nullable fields get no default.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(field_name, dtype=_RangeTimespanType, nullable=nullable,
                             default=default, **kwargs)
        return (spec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME if name is None else name,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        # A caller-supplied mapping is modified in place and returned.
        target = {} if result is None else result
        target[key] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return mapping[key]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        bound = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=bound, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        key = cls.NAME if name is None else name
        return cls(selectable.columns[key], key)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison uses the ``<<`` range operator.
            return self.column << other.column
        # Range-vs-scalar: the range must have a finite upper bound, be
        # non-empty, and that upper bound must not exceed ``other``.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.upper_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.upper(self.column) <= other,
        )

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range-vs-range comparison uses the ``>>`` range operator.
            return self.column >> other.column
        # Range-vs-scalar: the range must have a finite lower bound, be
        # non-empty, and that lower bound must be greater than ``other``.
        func = sqlalchemy.sql.func
        return sqlalchemy.sql.and_(
            sqlalchemy.sql.not_(func.lower_inf(self.column)),
            sqlalchemy.sql.not_(func.isempty(self.column)),
            func.lower(self.column) > other,
        )

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.overlaps(other.column)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unwrap another representation to its column; anything else is
        # passed to the range operator as given.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Yields exactly one element: the column, relabeled if a name is given.
        yield self.column if name is None else self.column.label(name)