from __future__ import annotations

__all__ = ("ByDimensionsDatasetRecordStorage",)

from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Set,
    TYPE_CHECKING,
)

import sqlalchemy

from lsst.daf.butler import (
    CollectionType,
    DataCoordinate,
    DataCoordinateSet,
    DatasetRef,
    DatasetType,
    SimpleQuery,
    Timespan,
)
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.daf.butler.registry.interfaces import DatasetRecordStorage

from ...summaries import GovernorDimensionRestriction

if TYPE_CHECKING:
    from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
    from .tables import StaticDatasetTablesTuple
    from .summaries import CollectionSummaryManager


class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
    """Dataset record storage implementation paired with
    `ByDimensionsDatasetRecordStorageManager`; see that class for more
    information.

    Instances of this class should never be constructed directly; use
    `DatasetRecordStorageManager.register` instead.
    """
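    # A minimal usage sketch, assuming an already-configured storage manager;
    # the `manager`, `run_record`, `tagged_record`, `calib_record`, `data_ids`,
    # and `span` names below are illustrative, and the exact return shape of
    # `register` may differ:
    #
    #     storage = manager.register(dataset_type)
    #     refs = list(storage.insert(run_record, data_ids))
    #     found = storage.find(run_record, data_ids[0])
    #     storage.associate(tagged_record, refs)
    #     storage.certify(calib_record, refs, span)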

    def __init__(self, *, datasetType: DatasetType,
                 db: Database,
                 dataset_type_id: int,
                 collections: CollectionManager,
                 static: StaticDatasetTablesTuple,
                 summaries: CollectionSummaryManager,
                 tags: sqlalchemy.schema.Table,
                 calibs: Optional[sqlalchemy.schema.Table]):
        super().__init__(datasetType=datasetType)
        self._dataset_type_id = dataset_type_id
        self._db = db
        self._collections = collections
        self._static = static
        self._summaries = summaries
        self._tags = tags
        self._calibs = calibs
        self._runKeyColumn = collections.getRunForeignKeyName()

    def insert(self, run: RunRecord, dataIds: Iterable[DataCoordinate]) -> Iterator[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
        staticRow = {
            "dataset_type_id": self._dataset_type_id,
            self._runKeyColumn: run.key,
        }
        # Iterate over data IDs, transforming a possibly-single-pass iterable
        # into a list, and remembering any governor dimension values we see.
        governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
        dataIdList = []
        for dataId in dataIds:
            dataIdList.append(dataId)
            governorValues.update_extract(dataId)
        with self._db.transaction():
            # Insert into the static dataset table, generating autoincrement
            # dataset_id values.
            datasetIds = self._db.insert(self._static.dataset, *([staticRow]*len(dataIdList)),
                                         returnIds=True)
            assert datasetIds is not None
            # Update the summary tables for this collection in case this is the
            # first time this dataset type or these governor values will be
            # inserted there.
            self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues)
            # Combine the generated dataset_id values and data ID fields to
            # form rows to be inserted into the tags table.
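            # For example, a hypothetical data ID of {"instrument": "X",
            # "detector": 42} paired with a generated dataset_id of 17 would
            # produce a tags row roughly like {"dataset_type_id": ...,
            # <collection FK column>: run.key, "dataset_id": 17,
            # "instrument": "X", "detector": 42} (illustrative values only).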

            protoTagsRow = {
                "dataset_type_id": self._dataset_type_id,
                self._collections.getCollectionForeignKeyName(): run.key,
            }
            tagsRows = [
                dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
                for dataId, dataset_id in zip(dataIdList, datasetIds)
            ]
            # Insert those rows into the tags table.  This is where we'll
            # get any unique constraint violations.
            self._db.insert(self._tags, *tagsRows)
        for dataId, datasetId in zip(dataIdList, datasetIds):
            yield DatasetRef(
                datasetType=self.datasetType,
                dataId=dataId,
                id=datasetId,
                run=run.name,
            )

    def find(self, collection: CollectionRecord, dataId: DataCoordinate,
             timespan: Optional[Timespan] = None) -> Optional[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
        assert dataId.graph == self.datasetType.dimensions
        if collection.type is CollectionType.CALIBRATION and timespan is None:
            raise TypeError(f"Cannot search for dataset in CALIBRATION collection {collection.name} "
                            f"without an input timespan.")
        sql = self.select(collection=collection, dataId=dataId, id=SimpleQuery.Select,
                          run=SimpleQuery.Select, timespan=timespan).combine()
        results = self._db.query(sql)
        row = results.fetchone()
        if row is None:
            return None
        if collection.type is CollectionType.CALIBRATION:
            # For temporal calibration lookups (only!) our invariants do not
            # guarantee that the number of result rows is <= 1.
            # They would if `select` constrained the given timespan to be
            # _contained_ by the validity range in the self._calibs table,
            # instead of simply _overlapping_ it, because we do guarantee that
            # the validity ranges are disjoint for a particular dataset type,
            # collection, and data ID.  But using an overlap test and a check
            # for multiple result rows here allows us to provide a more useful
            # diagnostic, as well as allowing `select` to support more general
            # queries where multiple results are not an error.
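            # For example, certified rows with disjoint validity ranges
            # [t1, t2) and [t2, t3) for the same data ID both overlap a lookup
            # timespan of [t1, t3), so an overlap-based query can return two
            # rows even though the stored ranges never overlap each other.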

            if results.fetchone() is not None:
                raise RuntimeError(
                    f"Multiple matches found for calibration lookup in {collection.name} for "
                    f"{self.datasetType.name} with {dataId} overlapping {timespan}."
                )
        return DatasetRef(
            datasetType=self.datasetType,
            dataId=dataId,
            id=row["id"],
            run=self._collections[row[self._runKeyColumn]].name
        )

    def delete(self, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        # Only delete from common dataset table; ON DELETE foreign key clauses
        # will handle the rest.
        self._db.delete(
            self._static.dataset,
            ["id"],
            *[{"id": dataset.getCheckedId()} for dataset in datasets],
        )

    def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot associate into collection '{collection.name}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
        governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            governorValues.update_extract(dataset.dataId)
            rows.append(row)
        # Update the summary tables for this collection in case this is the
        # first time this dataset type or these governor values will be
        # inserted there.
        self._summaries.update(collection, self.datasetType, self._dataset_type_id, governorValues)
        # Update the tag table itself.
        self._db.replace(self._tags, *rows)

    def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot disassociate from collection '{collection.name}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        rows = [
            {
                "dataset_id": dataset.getCheckedId(),
                self._collections.getCollectionForeignKeyName(): collection.key
            }
            for dataset in datasets
        ]
        self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()],
                        *rows)

    def _buildCalibOverlapQuery(self, collection: CollectionRecord,
                                dataIds: Optional[DataCoordinateSet],
                                timespan: Timespan) -> SimpleQuery:
        assert self._calibs is not None
        # Start by building a SELECT query for any rows that would overlap
        # this one.
        query = SimpleQuery()
        query.join(self._calibs)
        # Add a WHERE clause matching the dataset type and collection.
        query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id)
        query.where.append(
            self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key
        )
        # Add a WHERE clause matching any of the given data IDs.
        if dataIds is not None:
            dataIds.constrain(
                query,
                lambda name: self._calibs.columns[name],  # type: ignore
            )
        # Add a WHERE clause for timespan overlaps.
        TimespanReprClass = self._db.getTimespanRepresentation()
        query.where.append(
            TimespanReprClass.fromSelectable(self._calibs).overlaps(TimespanReprClass.fromLiteral(timespan))
        )
        return query

    def certify(self, collection: CollectionRecord, datasets: Iterable[DatasetRef],
                timespan: Timespan) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if self._calibs is None:
            raise TypeError(f"Cannot certify datasets of type {self.datasetType.name}, for which "
                            f"DatasetType.isCalibration() is False.")
        if collection.type is not CollectionType.CALIBRATION:
            raise TypeError(f"Cannot certify into collection '{collection.name}' "
                            f"of type {collection.type.name}; must be CALIBRATION.")
        TimespanReprClass = self._db.getTimespanRepresentation()
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
        governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
        dataIds: Optional[Set[DataCoordinate]] = (
            set() if not TimespanReprClass.hasExclusionConstraint() else None
        )
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            TimespanReprClass.update(timespan, result=row)
            governorValues.update_extract(dataset.dataId)
            rows.append(row)
            if dataIds is not None:
                dataIds.add(dataset.dataId)
        # Update the summary tables for this collection in case this is the
        # first time this dataset type or these governor values will be
        # inserted there.
        self._summaries.update(collection, self.datasetType, self._dataset_type_id, governorValues)
        # Update the association table itself.
        if TimespanReprClass.hasExclusionConstraint():
            # Rely on the database constraint to enforce invariants; we just
            # reraise the exception for consistency across DB engines.
            try:
                self._db.insert(self._calibs, *rows)
            except sqlalchemy.exc.IntegrityError as err:
                raise ConflictingDefinitionError(
                    f"Validity range conflict certifying datasets of type {self.datasetType.name} "
                    f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
                ) from err
        else:
            # Have to implement the exclusion constraint ourselves.
            # Start by building a SELECT query for any rows that would overlap
            # this one.
            query = self._buildCalibOverlapQuery(
                collection,
                DataCoordinateSet(dataIds, graph=self.datasetType.dimensions),  # type: ignore
                timespan
            )
            query.columns.append(sqlalchemy.sql.func.count())
            sql = query.combine()
            # Acquire a table lock to ensure there are no concurrent writes
            # that could invalidate our check before we finish the inserts.
            # We use a SAVEPOINT in case there is an outer transaction that a
            # failure here should not roll back.
            with self._db.transaction(lock=[self._calibs], savepoint=True):
                # Run the check SELECT query.
                conflicting = self._db.query(sql).scalar()
                if conflicting > 0:
                    raise ConflictingDefinitionError(
                        f"{conflicting} validity range conflicts certifying datasets of type "
                        f"{self.datasetType.name} into {collection.name} for range "
                        f"[{timespan.begin}, {timespan.end})."
                    )
                # Proceed with the insert.
                self._db.insert(self._calibs, *rows)

    def decertify(self, collection: CollectionRecord, timespan: Timespan, *,
                  dataIds: Optional[Iterable[DataCoordinate]] = None) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if self._calibs is None:
            raise TypeError(f"Cannot decertify datasets of type {self.datasetType.name}, for which "
                            f"DatasetType.isCalibration() is False.")
        if collection.type is not CollectionType.CALIBRATION:
            raise TypeError(f"Cannot decertify from collection '{collection.name}' "
                            f"of type {collection.type.name}; must be CALIBRATION.")
        TimespanReprClass = self._db.getTimespanRepresentation()
        # Construct a SELECT query to find all rows that overlap our inputs.
        dataIdSet: Optional[DataCoordinateSet]
        if dataIds is not None:
            dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions)
        else:
            dataIdSet = None
        query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan)
        query.columns.extend(self._calibs.columns)
        sql = query.combine()
        # Set up collections to populate with the rows we'll want to modify.
        # The insert rows will have the same values for collection and
        # dataset type.
        protoInsertRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rowsToDelete = []
        rowsToInsert = []
        # Acquire a table lock to ensure there are no concurrent writes
        # between the SELECT and the DELETE and INSERT queries based on it.
        with self._db.transaction(lock=[self._calibs], savepoint=True):
            for row in self._db.query(sql):
                rowsToDelete.append({"id": row["id"]})
                # Construct the insert row(s) by copying the prototype row,
                # then adding the dimension column values, then adding what's
                # left of the timespan from that row after we subtract the
                # given timespan.
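                # For example, if the stored validity range is [t0, t3) and we
                # are decertifying [t1, t2) with t0 < t1 < t2 < t3, the
                # difference is [t0, t1) and [t2, t3), so the one deleted row
                # is replaced by two narrower rows.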

                newInsertRow = protoInsertRow.copy()
                newInsertRow["dataset_id"] = row["dataset_id"]
                for name in self.datasetType.dimensions.required.names:
                    newInsertRow[name] = row[name]
                rowTimespan = TimespanReprClass.extract(row)
                assert rowTimespan is not None, "Field should have a NOT NULL constraint."
                for diffTimespan in rowTimespan.difference(timespan):
                    rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()))
            # Run the DELETE and INSERT queries.
            self._db.delete(self._calibs, ["id"], *rowsToDelete)
            self._db.insert(self._calibs, *rowsToInsert)

    def select(self, collection: CollectionRecord,
               dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
               id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select,
               run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
               timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
               ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None,
               ) -> SimpleQuery:
        # Docstring inherited from DatasetRecordStorage.
        assert collection.type is not CollectionType.CHAINED
        query = SimpleQuery()
        # We always include the _static.dataset table, and we can always get
        # the id and run fields from that; passing them as kwargs here tells
        # SimpleQuery to handle them whether they're constraints or results.
        # We always constrain the dataset_type_id here as well.
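        # For example, calling select(collection, dataId=some_data_id,
        # id=SimpleQuery.Select) constrains the data ID columns to
        # some_data_id's values while returning the dataset id column in the
        # results; a concrete kwarg value becomes an equality constraint and
        # the Select sentinel marks a requested result column (some_data_id is
        # just an illustrative name).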

        static_kwargs = {self._runKeyColumn: run}
        if ingestDate is not None:
            static_kwargs["ingest_date"] = SimpleQuery.Select
        query.join(
            self._static.dataset,
            id=id,
            dataset_type_id=self._dataset_type_id,
            **static_kwargs
        )
        # If and only if the collection is a RUN, we constrain it in the static
        # table (and also the tags or calibs table below).
        if collection.type is CollectionType.RUN:
            query.where.append(self._static.dataset.columns[self._runKeyColumn]
                               == collection.key)
        # We get or constrain the data ID from the tags/calibs table, but
        # that's multiple columns, not one, so we need to transform the one
        # Select.Or argument into a dictionary of them.
        kwargs: Dict[str, Any]
        if dataId is SimpleQuery.Select:
            kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
        else:
            kwargs = dict(dataId.byName())
        # We always constrain (never retrieve) the collection from the tags
        # table.
        kwargs[self._collections.getCollectionForeignKeyName()] = collection.key
        # Constrain the ingest time.
        if isinstance(ingestDate, Timespan):
            # Timespan bounds are astropy Time (usually in TAI), while
            # ingest_date is a TIMESTAMP column, so convert the values to
            # Python datetime for SQLAlchemy.
            if ingestDate.isEmpty():
                raise RuntimeError("Empty timespan constraint provided for ingest_date.")
            if ingestDate.begin is not None:
                begin = ingestDate.begin.utc.datetime  # type: ignore
                query.where.append(self._static.dataset.ingest_date >= begin)
            if ingestDate.end is not None:
                end = ingestDate.end.utc.datetime  # type: ignore
                query.where.append(self._static.dataset.ingest_date < end)
        # And now we finally join in the tags or calibs table.
        if collection.type is CollectionType.CALIBRATION:
            assert self._calibs is not None, \
                "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
            TimespanReprClass = self._db.getTimespanRepresentation()
            # Add the timespan column(s) to the result columns, or constrain
            # the timespan via an overlap condition.
            if timespan is SimpleQuery.Select:
                kwargs.update({k: SimpleQuery.Select for k in TimespanReprClass.getFieldNames()})
            elif timespan is not None:
                query.where.append(
                    TimespanReprClass.fromSelectable(self._calibs).overlaps(
                        TimespanReprClass.fromLiteral(timespan)
                    )
                )
            query.join(
                self._calibs,
                onclause=(self._static.dataset.columns.id == self._calibs.columns.dataset_id),
                **kwargs
            )
        else:
            query.join(
                self._tags,
                onclause=(self._static.dataset.columns.id == self._tags.columns.dataset_id),
                **kwargs
            )
        return query

    def getDataId(self, id: int) -> DataCoordinate:
        # Docstring inherited from DatasetRecordStorage.
        # This query could return multiple rows (one for each tagged collection
        # the dataset is in, plus one for its run collection), and we don't
        # care which of those we get.
        sql = self._tags.select().where(
            sqlalchemy.sql.and_(
                self._tags.columns.dataset_id == id,
                self._tags.columns.dataset_type_id == self._dataset_type_id
            )
        ).limit(1)
        row = self._db.query(sql).fetchone()
        assert row is not None, "Should be guaranteed by caller and foreign key constraints."
        return DataCoordinate.standardize(
            {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
            graph=self.datasetType.dimensions
        )