Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import contextmanager, closing 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy.dialects.postgresql 

30 

31from ..interfaces import Database 

32from ..nameShrinker import NameShrinker 

33from ...core import ddl, time_utils, Timespan, TimespanDatabaseRepresentation 

34 

35 

class PostgresqlDatabase(Database):
    """A `Database` implementation that talks to a PostgreSQL server.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection does not provide
        ``get_dsn_parameters`` (i.e. it is not a psycopg2 connection), or if
        the ``btree_gist`` extension is not present in ``pg_extension``.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`)
    requires the ``btree_gist`` PostgreSQL server extension to be installed
    and enabled on the database being connected to; this is checked at
    connection time.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, connection=connection, namespace=namespace)
        # The DSN lookup doubles as the driver check: a failure here is
        # reported as an unsupported driver.
        rawConnection = connection.connection
        try:
            dsnParameters = rawConnection.get_dsn_parameters()
        except (AttributeError, KeyError) as err:
            raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
        if namespace is None:
            # Fall back to whatever schema the server reports as current.
            schemaResult = connection.execute("SELECT current_schema();")
            namespace = schemaResult.scalar()
        extensionResult = connection.execute(
            "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
        )
        if not extensionResult.scalar():
            raise RuntimeError(
                "The Butler PostgreSQL backend requires the btree_gist extension. "
                "As extensions are enabled per-database, this may require an administrator to run "
                "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is initialized."
            )
        self.namespace = namespace
        self.dbname = dsnParameters.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        # ``writeable`` is accepted but not used when opening the connection.
        engine = sqlalchemy.engine.create_engine(uri, poolclass=sqlalchemy.pool.NullPool)
        return engine.connect()

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Thin alternate constructor: forwards everything to ``__init__``.
        return cls(
            connection=connection,
            origin=origin,
            namespace=namespace,
            writeable=writeable,
        )

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if self.isWriteable():
                # Make timestamps UTC, because the timestamp columns were not
                # declared with a time-zone-aware type (TIMESTAMPTZ).  When
                # we can tolerate a schema change, we should change that type
                # and drop this statement.
                setupStatement = "SET TIME ZONE 0"
            else:
                setupStatement = "SET TRANSACTION READ ONLY"
            # Issue the statement on the raw DBAPI cursor, closing it before
            # control is handed to the caller.
            with closing(self._connection.connection.cursor()) as cursor:
                cursor.execute(setupStatement)
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        for lockedTable in tables:
            statement = f"LOCK TABLE {lockedTable.key} IN EXCLUSIVE MODE"
            self._connection.execute(statement)

    def isWriteable(self) -> bool:
        # Reports the ``writeable`` flag given at construction.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the `NameShrinker` built from the dialect's maximum
        # identifier length.
        shrinker = self._shrinker
        return shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same `NameShrinker`.
        shrinker = self._shrinker
        return shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[TimespanDatabaseRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        columnOperators = []
        nameParts = ["excl"]
        for item in spec:
            if isinstance(item, str):
                # Plain columns take part in the constraint via equality.
                columnName, operator = item, "="
            elif issubclass(item, TimespanDatabaseRepresentation):
                # Timespan columns take part via the range-overlap operator.
                assert item is self.getTimespanRepresentation()
                columnName, operator = TimespanDatabaseRepresentation.NAME, "&&"
            else:
                # Any other class is left out of the constraint.
                continue
            columnOperators.append((sqlalchemy.schema.Column(columnName), operator))
            nameParts.append(columnName)
        constraintName = self.shrinkDatabaseEntityName("_".join(nameParts))
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *columnOperators,
            name=constraintName,
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This relies on the PostgreSQL dialect's UPSERT support:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        insertStatement = sqlalchemy.dialects.postgresql.dml.insert(table)
        # The SET clause assigns every non-primary-key column from the
        # special `excluded` pseudo-table.  A column missing from the INSERT
        # list will therefore be set to NULL.
        excluded = insertStatement.excluded
        primaryKey = table.primary_key
        assignments = {}
        for column in table.columns:
            if column.name in primaryKey:
                continue
            assignments[column.name] = getattr(excluded, column.name)
        upsertStatement = insertStatement.on_conflict_do_update(constraint=primaryKey, set_=assignments)
        self._connection.execute(upsertStatement, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this is an UPSERT, but a simpler one: we neither
        # care which constraint is violated nor name any columns to update.
        insertStatement = sqlalchemy.dialects.postgresql.dml.insert(table)
        result = self._connection.execute(insertStatement.on_conflict_do_nothing(), *rows)
        return result.rowcount

179 

180 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Values are stored in an ``INT8RANGE`` column, so this type should be able
    to take advantage of PostgreSQL's built-in range operators, and the
    indexing and EXCLUSION table constraints built off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        # Convert a `Timespan` (or `None`) into the range object handed to
        # the driver.
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        if value.isEmpty():
            return psycopg2.extras.NumericRange(empty=True)
        converter = time_utils.TimeConverter()
        begin_nsec = value._nsec[0]
        end_nsec = value._nsec[1]
        assert begin_nsec >= converter.min_nsec, "Guaranteed by Timespan.__init__."
        assert end_nsec <= converter.max_nsec, "Guaranteed by Timespan.__init__."
        # The converter's extreme values are written as unbounded (`None`)
        # range ends.
        lower = begin_nsec
        if lower == converter.min_nsec:
            lower = None
        upper = end_nsec
        if upper == converter.max_nsec:
            upper = None
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        # Convert a range object read from the database back to a `Timespan`
        # (or `None`).
        if value is None:
            return None
        if value.isempty:
            return Timespan.makeEmpty()
        converter = time_utils.TimeConverter()
        # Unbounded range ends map back onto the converter's extreme values.
        begin_nsec = value.lower
        if begin_nsec is None:
            begin_nsec = converter.min_nsec
        end_nsec = value.upper
        if end_nsec is None:
            end_nsec = converter.max_nsec
        return Timespan(begin=None, end=None, _nsec=(begin_nsec, end_nsec))

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476 (now fixed on
        master, but not in the releases we currently use).  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def _rangeOp(self, operator: str, other: Any) -> Any:
            # Apply a custom SQL operator, flagged as a comparison, between
            # this column expression and ``other``.
            return self.expr.op(operator, is_comparison=True)(other)

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal.
            """
            # Comparisons against `None` are left to the base comparator.
            if other is None:
                return super().__ne__(other)
            return self._rangeOp("<>", other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            # Extra keyword arguments are accepted but not used.
            return self._rangeOp("@>", other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self._rangeOp("<@", other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self._rangeOp("&&", other)

        def strictly_left_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            left of the right hand operand.
            """
            return self._rangeOp("<<", other)

        __lshift__ = strictly_left_of

        def strictly_right_of(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is strictly
            right of the right hand operand.
            """
            return self._rangeOp(">>", other)

        __rshift__ = strictly_right_of

276 

277 

class _RangeTimespanRepresentation(TimespanDatabaseRepresentation):
    """An implementation of `TimespanDatabaseRepresentation` that uses
    `_RangeTimespanType` to store a timespan in a single
    PostgreSQL-specific field.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    name : `str`
        Name reported by the `name` property.
    """

    __slots__ = ("column", "_name")

    def __init__(self, column: sqlalchemy.sql.ColumnElement, name: str):
        self.column = column
        self._name = name

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, name: Optional[str] = None, **kwargs: Any
                       ) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        fieldName = cls.NAME if name is None else name
        if nullable:
            default = None
        else:
            # Non-nullable fields default to the fully unbounded range.
            default = sqlalchemy.sql.text("'(,)'::int8range")
        fieldSpec = ddl.FieldSpec(
            fieldName,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs
        )
        return (fieldSpec,)

    @classmethod
    def getFieldNames(cls, name: Optional[str] = None) -> Tuple[str, ...]:
        # Docstring inherited.
        fieldName = cls.NAME if name is None else name
        return (fieldName,)

    @classmethod
    def update(cls, extent: Optional[Timespan], name: Optional[str] = None,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        fieldName = cls.NAME if name is None else name
        # A caller-provided dict is modified in place and returned.
        target = {} if result is None else result
        target[fieldName] = extent
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any], name: Optional[str] = None) -> Optional[Timespan]:
        # Docstring inherited.
        fieldName = cls.NAME if name is None else name
        return mapping[fieldName]

    @classmethod
    def fromLiteral(cls, timespan: Timespan) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        literalColumn = sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType)
        return cls(column=literalColumn, name=cls.NAME)

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause, name: Optional[str] = None
                       ) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        fieldName = cls.NAME if name is None else name
        return cls(selectable.columns[fieldName], fieldName)

    @property
    def name(self) -> str:
        # Docstring inherited.
        return self._name

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def isEmpty(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return sqlalchemy.sql.func.isempty(self.column)

    def __lt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: use the "strictly left of" operator.
            return self.column << other.column
        # Range vs. scalar expression: the range must have a finite upper
        # bound, be non-empty, and that bound must not exceed ``other``.
        sql = sqlalchemy.sql
        hasFiniteUpper = sql.not_(sql.func.upper_inf(self.column))
        isNotEmpty = sql.not_(sql.func.isempty(self.column))
        endsBefore = sql.func.upper(self.column) <= other
        return sql.and_(hasFiniteUpper, isNotEmpty, endsBefore)

    def __gt__(
        self,
        other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
    ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        if not isinstance(other, sqlalchemy.sql.ColumnElement):
            # Range vs. range: use the "strictly right of" operator.
            return self.column >> other.column
        # Range vs. scalar expression: the range must have a finite lower
        # bound, be non-empty, and that bound must exceed ``other``.
        sql = sqlalchemy.sql
        hasFiniteLower = sql.not_(sql.func.lower_inf(self.column))
        isNotEmpty = sql.not_(sql.func.isempty(self.column))
        startsAfter = sql.func.lower(self.column) > other
        return sql.and_(hasFiniteLower, isNotEmpty, startsAfter)

    def overlaps(self, other: _RangeTimespanRepresentation) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        otherColumn = other.column
        return self.column.overlaps(otherColumn)

    def contains(self, other: Union[_RangeTimespanRepresentation, sqlalchemy.sql.ColumnElement]
                 ) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # Unwrap another representation to its column; anything else is
        # passed to the operator unchanged.
        operand = other.column if isinstance(other, _RangeTimespanRepresentation) else other
        return self.column.contains(operand)

    def flatten(self, name: Optional[str] = None) -> Iterator[sqlalchemy.sql.ColumnElement]:
        # Docstring inherited.
        # Always yields exactly one column, relabeled only when a name is
        # given.
        yield self.column if name is None else self.column.label(name)