Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

from contextlib import contextmanager, closing
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union

import psycopg2
import psycopg2.extras
import sqlalchemy.dialects.postgresql

from ..interfaces import Database
from ..nameShrinker import NameShrinker
from ...core import DatabaseTimespanRepresentation, ddl, Timespan, time_utils

34 

35 

class PostgresqlDatabase(Database):
    """An implementation of the `Database` interface for PostgreSQL.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the connection does not use the psycopg2 driver, or if the
        ``btree_gist`` extension is not enabled in the target database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked when an instance is
    constructed.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, connection=connection, namespace=namespace)
        # Only psycopg2 DBAPI connections provide ``get_dsn_parameters``; its
        # absence is how a different driver is detected.
        dbapi = connection.connection
        try:
            dsn = dbapi.get_dsn_parameters()
        except (AttributeError, KeyError) as err:
            raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
        if namespace is None:
            namespace = connection.execute("SELECT current_schema();").scalar()
        # NOTE: btree_gist is presumably needed by the mixed ``=`` / ``&&``
        # exclusion constraints built in `_convertExclusionConstraintSpec`.
        if not connection.execute("SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';").scalar():
            raise RuntimeError(
                "The Butler PostgreSQL backend requires the btree_gist extension. "
                "As extensions are enabled per-database, this may require an administrator to run "
                "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is initialized."
            )
        self.namespace = namespace
        self.dbname = dsn.get("dbname")
        self._writeable = writeable
        self._shrinker = NameShrinker(connection.engine.dialect.max_identifier_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        # ``NullPool`` disables connection pooling for the created engine.
        return sqlalchemy.engine.create_engine(uri, poolclass=sqlalchemy.pool.NullPool).connect()

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        return cls(connection=connection, origin=origin, namespace=namespace, writeable=writeable)

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            if not self.isWriteable():
                # Have the server itself reject writes for read-only clients.
                with closing(self._connection.connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        # Let the dialect render the (possibly schema-qualified) table name so
        # that identifiers requiring quotes (e.g. mixed case) are quoted the
        # same way they were when the table was created.
        preparer = self._connection.dialect.identifier_preparer
        for table in tables:
            self._connection.execute(f"LOCK TABLE {preparer.format_table(table)} IN EXCLUSIVE MODE")

    def isWriteable(self) -> bool:
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        return self._shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        return self._shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[DatabaseTimespanRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Plain column names are constrained with equality; the timespan
        # column is constrained with the range-overlap operator ``&&``.
        args = []
        names = ["excl"]
        for item in spec:
            if isinstance(item, str):
                args.append((sqlalchemy.schema.Column(item), "="))
                names.append(item)
            elif issubclass(item, DatabaseTimespanRepresentation):
                assert item is self.getTimespanRepresentation()
                args.append((sqlalchemy.schema.Column(DatabaseTimespanRepresentation.NAME), "&&"))
                names.append(DatabaseTimespanRepresentation.NAME)
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(
            *args,
            name=self.shrinkDatabaseEntityName("_".join(names)),
        )

    @classmethod
    def getTimespanRepresentation(cls) -> Type[DatabaseTimespanRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        # Insert ``rows`` into ``table``, overwriting the non-primary-key
        # columns of any existing row with the same primary key.
        self.assertTableWriteable(table, f"Cannot replace into read-only table {table}.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        query = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        excluded = query.excluded
        data = {column.name: getattr(excluded, column.name)
                for column in table.columns
                if column.name not in table.primary_key}
        if data:
            query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
        else:
            # Every column belongs to the primary key, so a conflicting row is
            # already identical to the new one.  ``on_conflict_do_update``
            # rejects an empty SET clause, so skip conflicts instead.
            query = query.on_conflict_do_nothing(constraint=table.primary_key)
        self._connection.execute(query, *rows)

    def ensure(self, table: sqlalchemy.schema.Table, *rows: dict) -> int:
        # Docstring inherited.
        self.assertTableWriteable(table, f"Cannot ensure into read-only table {table}.")
        if not rows:
            return 0
        # Like `replace`, this uses UPSERT, but it's a bit simpler because
        # we don't care which constraint is violated or specify which columns
        # to update.
        query = sqlalchemy.dialects.postgresql.dml.insert(table).on_conflict_do_nothing()
        return self._connection.execute(query, *rows).rowcount

174 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    Values are stored in an ``INT8RANGE`` column whose bounds are the
    integers produced by `time_utils.astropy_to_nsec`; a `None` begin or end
    becomes an unbounded side of the range.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) into a range for the database.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        lower = None if value.begin is None else time_utils.astropy_to_nsec(value.begin)
        upper = None if value.end is None else time_utils.astropy_to_nsec(value.end)
        # NOTE: ``psycopg2.extras`` must be imported explicitly at module
        # level; ``import psycopg2`` alone does not load that submodule.
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a range read from the database back into a `Timespan`.

        Both SQL NULL and empty ranges are returned as `None`.
        """
        # NOTE(review): PostgreSQL normalizes a half-open range with equal
        # bounds to ``empty``, so a zero-length Timespan reads back as None.
        if value is None or value.isempty:
            return None
        begin = None if value.lower is None else time_utils.nsec_to_astropy(value.lower)
        end = None if value.upper is None else time_utils.nsec_to_astropy(value.upper)
        return Timespan(begin=begin, end=end)

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for `_RangeTimespanType` columns.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476.  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal.
            """
            if other is None:
                return super().__ne__(other)
            else:
                return self.expr.op("<>", is_comparison=True)(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            return self.expr.op("@>", is_comparison=True)(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op("<@", is_comparison=True)(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op("&&", is_comparison=True)(other)

244 

245 

class _RangeTimespanRepresentation(DatabaseTimespanRepresentation):
    """`DatabaseTimespanRepresentation` that keeps a whole timespan in one
    PostgreSQL-specific column of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    __slots__ = ("column",)

    def __init__(self, column: sqlalchemy.sql.ColumnElement):
        self.column = column

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, **kwargs: Any) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        if nullable:
            default = None
        else:
            # Non-nullable columns default to the range unbounded on both sides.
            default = sqlalchemy.sql.text("'(,)'::int8range")
        spec = ddl.FieldSpec(
            cls.NAME,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs
        )
        return (spec,)

    @classmethod
    def getFieldNames(cls) -> Tuple[str, ...]:
        # Docstring inherited.
        return (cls.NAME,)

    @classmethod
    def update(cls, timespan: Optional[Timespan], *,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        # The single column holds the Timespan itself; conversion to a range
        # happens in `_RangeTimespanType`.
        target = {} if result is None else result
        target[cls.NAME] = timespan
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any]) -> Optional[Timespan]:
        # Docstring inherited.
        return mapping[cls.NAME]

    @classmethod
    def hasExclusionConstraint(cls) -> bool:
        # Docstring inherited.
        return True

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        column = selectable.columns[cls.NAME]
        return cls(column)

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def overlaps(self, other: Union[Timespan, _RangeTimespanRepresentation]) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A literal Timespan is bound directly; another representation
        # contributes its column.
        target = other if isinstance(other, Timespan) else other.column
        return self.column.overlaps(target)