Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21"""Classes for representing SQL data-definition language (DDL; "CREATE TABLE", 

22etc.) in Python. 

23 

24This provides an extra layer on top of SQLAlchemy's classes for these concepts, 

25because we need a level of indirection between logical tables and the actual 

26SQL, and SQLAlchemy's DDL classes always map 1-1 to SQL. 

27 

28We've opted for the rather more obscure "ddl" as the name of this module 

29instead of "schema" because the latter is too overloaded; in most SQL 

30databases, a "schema" is also another term for a namespace. 

31""" 

32from __future__ import annotations 

33 

__all__ = (
    "TableSpec",
    "FieldSpec",
    "ForeignKeySpec",
    "Base64Bytes",
    "Base64Region",
    "AstropyTimeNsecTai",
)

36 

37from base64 import b64encode, b64decode 

38import logging 

39from math import ceil 

40from dataclasses import dataclass 

41from typing import Optional, Tuple, Sequence, Set 

42 

43import sqlalchemy 

44import astropy.time 

45 

46from lsst.sphgeom import ConvexPolygon 

47from .config import Config 

48from .exceptions import ValidationError 

49from .utils import iterable, stripIfNotNone, NamedValueSet 

50 

51 

_LOG = logging.getLogger(__name__)

# The time-range constants below are public and may be used by client code.
EPOCH = astropy.time.Time("1970-01-01 00:00:00", scale="tai", format="iso")
"""Reference epoch (TAI) from which stored time deltas are measured; it is
also the earliest time that can be stored in the database.
"""

MAX_TIME = astropy.time.Time("2100-01-01 00:00:00", scale="tai", format="iso")
"""Latest time value that can be stored.

A 64-bit integer column could hold larger values, but the limit is
deliberately capped at an arbitrary yet comfortably distant date. Because this
value is persisted in the registry database indefinitely, it must not be
changed without careful consideration.
"""

67 

68 

class SchemaValidationError(ValidationError):
    """Exceptions used to indicate problems in Registry schema configuration.
    """

    @classmethod
    def translate(cls, caught, message):
        """A decorator that re-raises exceptions as `SchemaValidationError`.

        Decorated functions must be class or instance methods, with a
        ``config`` parameter as their first argument. This will be passed
        to ``message.format()`` as a keyword argument, along with ``err``,
        the original exception, and any keyword arguments the decorated
        function was called with.

        Parameters
        ----------
        caught : `type` (`Exception` subclass)
            The type of exception to catch.
        message : `str`
            A `str.format` string that may contain named placeholders for
            ``config``, ``err``, or any keyword-only argument accepted by
            the decorated function.

        Returns
        -------
        decorate : callable
            A decorator that wraps a method as described above. The new
            exception is chained to the caught one (``raise ... from err``).
        """
        from functools import wraps

        def decorate(func):
            # ``wraps`` keeps the wrapped method's name and docstring, so
            # decorated ``fromConfig`` methods still document themselves.
            @wraps(func)
            def decorated(self, config, *args, **kwds):
                try:
                    return func(self, config, *args, **kwds)
                except caught as err:
                    # Caller keyword arguments are available as placeholders,
                    # but ``config`` and ``err`` always take precedence.
                    fields = dict(kwds, config=str(config), err=err)
                    raise cls(message.format(**fields)) from err
            return decorated
        return decorate

99 

100 

class Base64Bytes(sqlalchemy.TypeDecorator):
    """SQLAlchemy column type that stores Python `bytes` as base64 text.

    The underlying column is a `sqlalchemy.String` whose length is derived
    from ``nbytes``, the number of raw (pre-encoding) bytes.
    """

    impl = sqlalchemy.String

    def __init__(self, nbytes, *args, **kwds):
        # Base64 produces four output characters for every started group of
        # three input bytes.
        encodedLength = 4*ceil(nbytes/3)
        super().__init__(*args, length=encodedLength, **kwds)
        self.nbytes = nbytes

    def process_bind_param(self, value, dialect):
        # The incoming value is raw `bytes`. SQLAlchemy expects a `str` for
        # String columns, so base64-encode it and decode the result as ASCII.
        if value is None:
            return None
        if isinstance(value, bytes):
            return b64encode(value).decode("ascii")
        raise TypeError(
            f"Base64Bytes fields require 'bytes' values; got '{value}' with type {type(value)}."
        )

    def process_result_value(self, value, dialect):
        # The stored value is an ASCII `str` holding base64 text; turn it back
        # into the original raw `bytes`.
        if value is None:
            return None
        return b64decode(value.encode("ascii"))

130 

131 

class Base64Region(Base64Bytes):
    """SQLAlchemy column type that stores a `sphgeom.ConvexPolygon` as
    base64 text, built on the region's own binary encoding.
    """

    def process_bind_param(self, value, dialect):
        # Serialize the region to bytes first, then let the parent class
        # apply the base64 encoding.
        if value is not None:
            return super().process_bind_param(value.encode(), dialect)
        return None

    def process_result_value(self, value, dialect):
        # The parent class turns the base64 text back into bytes; the
        # polygon is then rebuilt from those bytes.
        if value is not None:
            raw = super().process_result_value(value, dialect)
            return ConvexPolygon.decode(raw)
        return None

146 

147 

class AstropyTimeNsecTai(sqlalchemy.TypeDecorator):
    """A SQLAlchemy custom type that maps Python `astropy.time.Time` to a
    number of nanoseconds since Unix epoch in TAI scale.

    Times earlier than `EPOCH` or later than `MAX_TIME` are clipped to those
    bounds (with a warning) before being stored.
    """

    impl = sqlalchemy.BigInteger

    # Number of nanoseconds in one day (86400 s).
    _NSEC_PER_DAY = 1_000_000_000 * 86400

    def process_bind_param(self, value, dialect):
        # value is astropy.time.Time or None
        if value is None:
            return None
        if not isinstance(value, astropy.time.Time):
            raise TypeError(f"Unsupported type: {type(value)}, expected astropy.time.Time")
        # Comparisons sometimes produce warnings if the input value is in UTC
        # scale, so transform it to TAI before doing anything else.
        value = value.tai
        # Anything before EPOCH or after MAX_TIME is clipped to that bound.
        if value < EPOCH:
            _LOG.warning("%s is earlier than epoch time %s, epoch time will be used instead",
                         value, EPOCH)
            value = EPOCH
        elif value > MAX_TIME:
            _LOG.warning("%s is later than max. time %s, max. time will be used instead",
                         value, MAX_TIME)
            value = MAX_TIME
        # Collapsing the delta to a single float64 number of seconds (and then
        # scaling by 1e9) cannot resolve nanoseconds at ~4e18 ns. Use the
        # two-double (jd1, jd2) day representation of the TimeDelta instead:
        # whole days become exact integer nanoseconds, and only the small
        # fractional-day remainder goes through floating point.
        delta = value - EPOCH
        wholeDays, extraDays = divmod(delta.jd1, 1)
        fracNsec = round(float(delta.jd2 + extraDays) * self._NSEC_PER_DAY)
        return int(wholeDays) * self._NSEC_PER_DAY + int(fracNsec)

    def process_result_value(self, value, dialect):
        # value is nanoseconds since epoch, or None
        if value is None:
            return None
        # Split into integer whole days and a nanosecond remainder, so that
        # neither float passed to TimeDelta loses nanosecond precision.
        days, nsec = divmod(int(value), self._NSEC_PER_DAY)
        delta = astropy.time.TimeDelta(float(days), nsec / self._NSEC_PER_DAY,
                                       format="jd", scale="tai")
        return EPOCH + delta

182 

183 

# Maps each column "type" string accepted in schema configuration to the
# SQLAlchemy (or custom) column type class it selects.
VALID_CONFIG_COLUMN_TYPES = dict(
    string=sqlalchemy.String,
    int=sqlalchemy.Integer,
    float=sqlalchemy.Float,
    region=Base64Region,
    bool=sqlalchemy.Boolean,
    blob=sqlalchemy.LargeBinary,
    datetime=AstropyTimeNsecTai,
    hash=Base64Bytes,
)

194 

195 

@dataclass
class FieldSpec:
    """A struct-like class used to define a column in a logical `Registry`
    table.

    Equality and hashing are based on `name` alone.
    """

    name: str
    """Name of the column."""

    dtype: type
    """Type of the column; usually a `type` subclass provided by SQLAlchemy
    that defines both a Python type and a corresponding precise SQL type.
    """

    length: Optional[int] = None
    """Length of the type in the database, for variable-length types."""

    nbytes: Optional[int] = None
    """Natural length used for hash and encoded-region columns, to be converted
    into the post-encoding length.
    """

    primaryKey: bool = False
    """Whether this field is (part of) its table's primary key."""

    autoincrement: bool = False
    """Whether the database should insert automatically incremented values when
    no value is provided in an INSERT.
    """

    nullable: bool = True
    """Whether this field is allowed to be NULL."""

    doc: Optional[str] = None
    """Documentation for this field."""

    def __eq__(self, other):
        # Only the column name participates in equality (consistent with
        # __hash__). For non-FieldSpec operands, defer to the other object
        # instead of failing on a missing ``name`` attribute.
        if isinstance(other, FieldSpec):
            return self.name == other.name
        return NotImplemented

    def __hash__(self):
        # Must agree with __eq__, which compares names only.
        return hash(self.name)

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in column config '{config}'.")
    def fromConfig(cls, config: Config, **kwds) -> FieldSpec:
        """Create a `FieldSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config: `Config`
            Configuration describing the column. Nested configuration keys
            correspond to `FieldSpec` attributes; ``name`` and ``type`` are
            required, and ``type`` must be a key of
            `VALID_CONFIG_COLUMN_TYPES`.
        kwds
            Additional keyword arguments that provide defaults for values
            not present in config.

        Returns
        -------
        spec: `FieldSpec`
            Specification structure for the column.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values,
            if the column name is not all lowercase, or if both ``length``
            and ``nbytes`` are provided.
        """
        dtype = VALID_CONFIG_COLUMN_TYPES.get(config["type"])
        if dtype is None:
            raise SchemaValidationError(f"Invalid field type string: '{config['type']}'.")
        if not config["name"].islower():
            raise SchemaValidationError(f"Column name '{config['name']}' is not all lowercase.")
        self = cls(name=config["name"], dtype=dtype, **kwds)
        self.length = config.get("length", self.length)
        self.nbytes = config.get("nbytes", self.nbytes)
        if self.length is not None and self.nbytes is not None:
            raise SchemaValidationError(f"Both length and nbytes provided for field '{self.name}'.")
        self.primaryKey = config.get("primaryKey", self.primaryKey)
        self.autoincrement = config.get("autoincrement", self.autoincrement)
        # Primary-key columns default to NOT NULL unless the config says
        # otherwise explicitly.
        self.nullable = config.get("nullable", False if self.primaryKey else self.nullable)
        self.doc = stripIfNotNone(config.get("doc", None))
        return self

    def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine:
        """Return a sized version of the column type, utilizing either (or
        neither) of ``self.length`` and ``self.nbytes``.

        Returns
        -------
        dtype : `sqlalchemy.types.TypeEngine`
            A SQLAlchemy column type. When ``length`` or ``nbytes`` is set,
            this is ``self.dtype`` instantiated with that argument. Otherwise
            it is the unsized ``self.dtype`` itself, which may be a type class
            rather than an instance.
        """
        if self.length is not None:
            return self.dtype(length=self.length)
        if self.nbytes is not None:
            return self.dtype(nbytes=self.nbytes)
        return self.dtype

291 return self.dtype 

292 

293 

@dataclass
class ForeignKeySpec:
    """Specification of a foreign key constraint on a logical `Registry`
    table.
    """

    table: str
    """Name of the table the constraint refers to."""

    source: Tuple[str, ...]
    """Column names in the constrained (source) table."""

    target: Tuple[str, ...]
    """Column names in the referenced (target) table."""

    onDelete: Optional[str] = None
    """SQL action to take when a referenced row in the target table is deleted.

    `None` means a constraint-violation exception is raised; otherwise the
    value should be "SET NULL" or "CASCADE".
    """

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in foreignKey config '{config}'.")
    def fromConfig(cls, config: Config) -> ForeignKeySpec:
        """Build a `ForeignKeySpec` from part of a `SchemaConfig`.

        Parameters
        ----------
        config: `Config`
            Configuration for the constraint. ``table``, ``source`` and
            ``target`` are required; ``onDelete`` is optional. ``source`` and
            ``target`` may each be a single column name or a sequence of
            names.

        Returns
        -------
        spec: `ForeignKeySpec`
            The constraint specification.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        tableName = config["table"]
        sourceColumns = tuple(iterable(config["source"]))
        targetColumns = tuple(iterable(config["target"]))
        return cls(table=tableName, source=sourceColumns, target=targetColumns,
                   onDelete=config.get("onDelete", None))

341 

342 

@dataclass
class TableSpec:
    """A struct-like class used to define a table or table-like
    query interface.
    """

    fields: NamedValueSet[FieldSpec]
    """Specifications for the columns in this table."""

    unique: Set[Tuple[str, ...]] = frozenset()
    """Non-primary-key unique constraints for the table."""

    indexes: Set[Tuple[str, ...]] = frozenset()
    """Indexes for the table."""

    foreignKeys: Sequence[ForeignKeySpec] = tuple()
    """Foreign key constraints for the table."""

    doc: Optional[str] = None
    """Documentation for the table."""

    def __post_init__(self):
        # Copy the collection arguments into new containers (sets and a list
        # for the plain collections). Instances then never alias the
        # immutable class-level defaults or the caller's objects.
        self.fields = NamedValueSet(self.fields)
        self.unique = set(self.unique)
        self.indexes = set(self.indexes)
        self.foreignKeys = list(self.foreignKeys)

    @classmethod
    @SchemaValidationError.translate(KeyError, "Missing key {err} in table config '{config}'.")
    def fromConfig(cls, config: Config) -> TableSpec:
        """Create a `TableSpec` from a subset of a `SchemaConfig`.

        Parameters
        ----------
        config: `Config`
            Configuration describing the table. Nested configuration keys
            correspond to `TableSpec` attributes; ``columns`` is required,
            while ``unique``, ``foreignKeys`` and ``doc`` are optional.

        Returns
        -------
        spec: `TableSpec`
            Specification structure for the table.

        Raises
        ------
        SchemaValidationError
            Raised if configuration keys are missing or have invalid values.
        """
        # Only keyword arguments that correspond to TableSpec fields may be
        # passed here; the dataclass has no ``sql`` field, so forwarding one
        # would make every call fail with TypeError.
        return cls(
            fields=NamedValueSet(FieldSpec.fromConfig(c) for c in config["columns"]),
            unique={tuple(u) for u in config.get("unique", ())},
            foreignKeys=[ForeignKeySpec.fromConfig(c) for c in config.get("foreignKeys", ())],
            doc=stripIfNotNone(config.get("doc")),
        )