Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["PostgresqlDatabase"] 

24 

25from contextlib import contextmanager, closing 

26from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Type, Union 

27 

28import psycopg2 

29import sqlalchemy.dialects.postgresql 

30 

31from ..interfaces import Database, ReadOnlyDatabaseError 

32from ..nameShrinker import NameShrinker 

33from ...core import DatabaseTimespanRepresentation, ddl, Timespan, time_utils 

34 

35 

class PostgresqlDatabase(Database):
    """A `Database` implementation backed by a PostgreSQL server.

    Parameters
    ----------
    connection : `sqlalchemy.engine.Connection`
        An existing connection created by a previous call to `connect`.
    origin : `int`
        An integer ID that should be used as the default for any datasets,
        quanta, or other entities that use a (autoincrement, origin) compound
        primary key.
    namespace : `str`, optional
        The namespace (schema) this database is associated with.  If `None`,
        the default schema for the connection is used (which may be `None`).
    writeable : `bool`, optional
        If `True`, allow write operations on the database, including
        ``CREATE TABLE``.

    Raises
    ------
    RuntimeError
        Raised if the underlying DBAPI connection does not provide
        ``get_dsn_parameters`` (i.e. it is not a psycopg2 connection), or if
        the ``btree_gist`` extension is not present in the database.

    Notes
    -----
    This currently requires the psycopg2 driver to be used as the backend for
    SQLAlchemy.  Running the tests for this class requires the
    ``testing.postgresql`` be installed, which we assume indicates that a
    PostgreSQL server is installed and can be run locally in userspace.

    Some functionality provided by this class (and used by `Registry`) requires
    the ``btree_gist`` PostgreSQL server extension to be installed and enabled
    on the database being connected to; this is checked at connection time.
    """

    def __init__(self, *, connection: sqlalchemy.engine.Connection, origin: int,
                 namespace: Optional[str] = None, writeable: bool = True):
        super().__init__(origin=origin, connection=connection, namespace=namespace)
        raw_connection = connection.connection
        try:
            dsn_parameters = raw_connection.get_dsn_parameters()
        except (AttributeError, KeyError) as err:
            raise RuntimeError("Only the psycopg2 driver for PostgreSQL is supported.") from err
        # Fall back to whatever schema the server reports as current.
        if namespace is None:
            schema_query = connection.execute("SELECT current_schema();")
            namespace = schema_query.scalar()
        # The extension is enabled per-database, so it has to be checked for
        # every new connection target.
        extension_count = connection.execute(
            "SELECT COUNT(*) FROM pg_extension WHERE extname='btree_gist';"
        ).scalar()
        if not extension_count:
            raise RuntimeError(
                "The Butler PostgreSQL backend requires the btree_gist extension. "
                "As extensions are enabled per-database, this may require an administrator to run "
                "`CREATE EXTENSION btree_gist;` in a database before a butler client for it is initialized."
            )
        self.namespace = namespace
        self.dbname = dsn_parameters.get("dbname")
        self._writeable = writeable
        max_length = connection.engine.dialect.max_identifier_length
        self._shrinker = NameShrinker(max_length)

    @classmethod
    def connect(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Connection:
        # ``writeable`` is part of the signature but is not consulted here;
        # the engine is created without connection pooling (NullPool).
        engine = sqlalchemy.engine.create_engine(uri, poolclass=sqlalchemy.pool.NullPool)
        return engine.connect()

    @classmethod
    def fromConnection(cls, connection: sqlalchemy.engine.Connection, *, origin: int,
                       namespace: Optional[str] = None, writeable: bool = True) -> Database:
        # Thin alternate constructor: forwards everything to ``__init__``.
        return cls(
            connection=connection,
            origin=origin,
            namespace=namespace,
            writeable=writeable,
        )

    @contextmanager
    def transaction(self, *, interrupting: bool = False, savepoint: bool = False,
                    lock: Iterable[sqlalchemy.schema.Table] = ()) -> Iterator[None]:
        """Wrap the base-class transaction context, additionally marking the
        transaction read-only when this database is not writeable.
        """
        with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
            read_only = not self.isWriteable()
            if read_only:
                # Issued on the raw DBAPI cursor rather than through
                # SQLAlchemy.
                raw_connection = self._connection.connection
                with closing(raw_connection.cursor()) as cursor:
                    cursor.execute("SET TRANSACTION READ ONLY")
            yield

    def _lockTables(self, tables: Iterable[sqlalchemy.schema.Table] = ()) -> None:
        # Docstring inherited.
        # One LOCK statement is issued per table, in iteration order.
        for locked in tables:
            self._connection.execute(f"LOCK TABLE {locked.key} IN SHARE MODE")

    def isWriteable(self) -> bool:
        # Reports the flag captured at construction time.
        return self._writeable

    def __str__(self) -> str:
        return f"PostgreSQL@{self.dbname}:{self.namespace}"

    def shrinkDatabaseEntityName(self, original: str) -> str:
        # Delegates to the NameShrinker built in ``__init__``.
        shrinker = self._shrinker
        return shrinker.shrink(original)

    def expandDatabaseEntityName(self, shrunk: str) -> str:
        # Inverse of `shrinkDatabaseEntityName`, via the same NameShrinker.
        shrinker = self._shrinker
        return shrinker.expand(shrunk)

    def _convertExclusionConstraintSpec(self, table: str,
                                        spec: Tuple[Union[str, Type[DatabaseTimespanRepresentation]], ...],
                                        metadata: sqlalchemy.MetaData) -> sqlalchemy.schema.Constraint:
        # Docstring inherited.
        # Plain column names are compared with "=", the timespan column with
        # the range-overlap operator "&&".  The constraint name is built from
        # "excl" plus every participating column name, then shrunk.
        name_parts = ["excl"]
        constraint_args = []
        for element in spec:
            if isinstance(element, str):
                column_name, operator = element, "="
            elif issubclass(element, DatabaseTimespanRepresentation):
                assert element is self.getTimespanRepresentation()
                column_name, operator = DatabaseTimespanRepresentation.NAME, "&&"
            else:
                # Classes of any other kind contribute nothing.
                continue
            constraint_args.append((sqlalchemy.schema.Column(column_name), operator))
            name_parts.append(column_name)
        constraint_name = self.shrinkDatabaseEntityName("_".join(name_parts))
        return sqlalchemy.dialects.postgresql.ExcludeConstraint(*constraint_args, name=constraint_name)

    @classmethod
    def getTimespanRepresentation(cls) -> Type[DatabaseTimespanRepresentation]:
        # Docstring inherited.
        return _RangeTimespanRepresentation

    def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
        """Insert ``rows`` into ``table``, overwriting the non-primary-key
        columns of any existing row with the same primary key.

        Raises `ReadOnlyDatabaseError` if the database is not writeable and
        ``table`` is not one of its temporary tables.  Does nothing when no
        rows are given.
        """
        may_write = self.isWriteable() or table.key in self._tempTables
        if not may_write:
            raise ReadOnlyDatabaseError(f"Attempt to replace into read-only database '{self}'.")
        if not rows:
            return
        # This uses special support for UPSERT in PostgreSQL backend:
        # https://docs.sqlalchemy.org/en/13/dialects/postgresql.html#insert-on-conflict-upsert
        insert_stmt = sqlalchemy.dialects.postgresql.dml.insert(table)
        # In the SET clause assign all columns using special `excluded`
        # pseudo-table.  If some column in the table does not appear in the
        # INSERT list this will set it to NULL.
        pseudo_table = insert_stmt.excluded
        assignments = {}
        for column in table.columns:
            if column.name not in table.primary_key:
                assignments[column.name] = getattr(pseudo_table, column.name)
        upsert = insert_stmt.on_conflict_do_update(constraint=table.primary_key, set_=assignments)
        self._connection.execute(upsert, *rows)

163 

164 

class _RangeTimespanType(sqlalchemy.TypeDecorator):
    """A single-column `Timespan` representation usable only with
    PostgreSQL.

    This type should be able to take advantage of PostgreSQL's built-in
    range operators, and the indexing and EXCLUSION table constraints built
    off of them.

    Notes
    -----
    The underlying column type is ``INT8RANGE``.  Non-`None` timespan bounds
    are converted to integers with `time_utils.astropy_to_nsec` when binding
    and back with `time_utils.nsec_to_astropy` when loading; `None` bounds
    are passed through unchanged in both directions.
    """

    impl = sqlalchemy.dialects.postgresql.INT8RANGE

    def process_bind_param(self, value: Optional[Timespan],
                           dialect: sqlalchemy.engine.Dialect
                           ) -> Optional[psycopg2.extras.NumericRange]:
        """Convert a `Timespan` (or `None`) into the value sent to the
        database.

        Parameters
        ----------
        value : `Timespan` or `None`
            Value being bound to a statement parameter.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (not consulted).

        Returns
        -------
        converted : `psycopg2.extras.NumericRange` or `None`
            `None` if ``value`` is `None`, otherwise a range built from the
            timespan's ``begin`` and ``end``.

        Raises
        ------
        TypeError
            Raised if ``value`` is neither `None` nor a `Timespan`.
        """
        # ``import psycopg2`` at module scope does not load the ``extras``
        # submodule; import it explicitly so this does not depend on some
        # other module having imported it first.
        import psycopg2.extras

        if value is None:
            return None
        if not isinstance(value, Timespan):
            raise TypeError(f"Unsupported type: {type(value)}, expected Timespan.")
        lower = None if value.begin is None else time_utils.astropy_to_nsec(value.begin)
        upper = None if value.end is None else time_utils.astropy_to_nsec(value.end)
        return psycopg2.extras.NumericRange(lower=lower, upper=upper)

    def process_result_value(self, value: Optional[psycopg2.extras.NumericRange],
                             dialect: sqlalchemy.engine.Dialect
                             ) -> Optional[Timespan]:
        """Convert a range loaded from the database into a `Timespan`.

        Parameters
        ----------
        value : `psycopg2.extras.NumericRange` or `None`
            Value read from a result row.
        dialect : `sqlalchemy.engine.Dialect`
            Dialect in use (not consulted).

        Returns
        -------
        timespan : `Timespan` or `None`
            `None` if ``value`` is `None` or an empty range, otherwise a
            `Timespan` built from the range's ``lower`` and ``upper``.
        """
        # NULL and empty ranges are both reported as "no timespan".
        if value is None or value.isempty:
            return None
        begin = None if value.lower is None else time_utils.nsec_to_astropy(value.lower)
        end = None if value.upper is None else time_utils.nsec_to_astropy(value.upper)
        return Timespan(begin=begin, end=end)

    class comparator_factory(sqlalchemy.types.Concatenable.Comparator):  # noqa: N801
        """Comparison operators for TimespanColumnRanges.

        Notes
        -----
        The existence of this nested class is a workaround for a bug
        submitted upstream as
        https://github.com/sqlalchemy/sqlalchemy/issues/5476.  The code is
        a limited copy of the operators in
        ``sqlalchemy.dialects.postgresql.ranges.RangeOperators``, but with
        ``is_comparison=True`` added to all calls.
        """

        def __ne__(self, other: Any) -> Any:
            """Boolean expression. Returns true if two ranges are not equal.
            """
            # Comparison against None is left to the base comparator.
            if other is None:
                return super().__ne__(other)
            else:
                return self.expr.op("<>", is_comparison=True)(other)

        def contains(self, other: Any, **kw: Any) -> Any:
            """Boolean expression. Returns true if the right hand operand,
            which can be an element or a range, is contained within the
            column.
            """
            # ``kw`` is accepted but not used.
            return self.expr.op("@>", is_comparison=True)(other)

        def contained_by(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column is contained
            within the right hand operand.
            """
            return self.expr.op("<@", is_comparison=True)(other)

        def overlaps(self, other: Any) -> Any:
            """Boolean expression. Returns true if the column overlaps
            (has points in common with) the right hand operand.
            """
            return self.expr.op("&&", is_comparison=True)(other)

234 

235 

class _RangeTimespanRepresentation(DatabaseTimespanRepresentation):
    """A `DatabaseTimespanRepresentation` that keeps the whole timespan in
    one PostgreSQL-specific column of type `_RangeTimespanType`.

    Parameters
    ----------
    column : `sqlalchemy.sql.ColumnElement`
        SQLAlchemy object representing the column.
    """

    __slots__ = ("column",)

    def __init__(self, column: sqlalchemy.sql.ColumnElement):
        self.column = column

    @classmethod
    def makeFieldSpecs(cls, nullable: bool, **kwargs: Any) -> Tuple[ddl.FieldSpec, ...]:
        # Docstring inherited.
        # Non-nullable columns default to the range literal '(,)'; nullable
        # columns get no default.
        if nullable:
            default = None
        else:
            default = sqlalchemy.sql.text("'(,)'::int8range")
        field = ddl.FieldSpec(
            cls.NAME,
            dtype=_RangeTimespanType,
            nullable=nullable,
            default=default,
            **kwargs
        )
        return (field,)

    @classmethod
    def getFieldNames(cls) -> Tuple[str, ...]:
        # Docstring inherited.
        # There is exactly one field, named by the class-level NAME.
        return (cls.NAME,)

    @classmethod
    def update(cls, timespan: Optional[Timespan], *,
               result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # Docstring inherited.
        # Reuse the caller's dict when one is given; otherwise start fresh.
        target = {} if result is None else result
        target[cls.NAME] = timespan
        return target

    @classmethod
    def extract(cls, mapping: Mapping[str, Any]) -> Optional[Timespan]:
        # Docstring inherited.
        # The stored value is returned as-is from the mapping.
        return mapping[cls.NAME]

    @classmethod
    def hasExclusionConstraint(cls) -> bool:
        # Docstring inherited.
        return True

    @classmethod
    def fromSelectable(cls, selectable: sqlalchemy.sql.FromClause) -> _RangeTimespanRepresentation:
        # Docstring inherited.
        timespan_column = selectable.columns[cls.NAME]
        return cls(timespan_column)

    def isNull(self) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        return self.column.is_(None)

    def overlaps(self, other: Union[Timespan, _RangeTimespanRepresentation]) -> sqlalchemy.sql.ColumnElement:
        # Docstring inherited.
        # A literal Timespan is compared directly; another representation
        # contributes its column expression.
        operand = other if isinstance(other, Timespan) else other.column
        return self.column.overlaps(operand)