from __future__ import annotations

__all__ = ("ByDimensionsDatasetRecordStorage",)

from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Set,
    TYPE_CHECKING,
)

import sqlalchemy

from lsst.daf.butler import (
    CollectionType,
    DataCoordinate,
    DataCoordinateSet,
    DatasetRef,
    DatasetType,
    SimpleQuery,
    Timespan,
)
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.daf.butler.registry.interfaces import DatasetRecordStorage

if TYPE_CHECKING:
    from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
    from .tables import StaticDatasetTablesTuple


class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
    """Dataset record storage implementation paired with
    `ByDimensionsDatasetRecordStorageManager`; see that class for more
    information.

    Instances of this class should never be constructed directly; use
    `DatasetRecordStorageManager.register` instead.
    """
    def __init__(self, *, datasetType: DatasetType,
                 db: Database,
                 dataset_type_id: int,
                 collections: CollectionManager,
                 static: StaticDatasetTablesTuple,
                 tags: sqlalchemy.schema.Table,
                 calibs: Optional[sqlalchemy.schema.Table]):
        super().__init__(datasetType=datasetType)
        self._dataset_type_id = dataset_type_id
        self._db = db
        self._collections = collections
        self._static = static
        self._tags = tags
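        # ``calibs`` is None unless DatasetType.isCalibration() is True for
        # this dataset type; only calibration dataset types get a
        # validity-range table (see `certify` and `decertify`).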
        self._calibs = calibs
        self._runKeyColumn = collections.getRunForeignKeyName()

    def insert(self, run: RunRecord, dataIds: Iterable[DataCoordinate]) -> Iterator[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
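        # Note that this method is a generator: nothing below, including the
        # database inserts, runs until the returned iterator is consumed.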
        staticRow = {
            "dataset_type_id": self._dataset_type_id,
            self._runKeyColumn: run.key,
        }
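        # Materialize the iterable: we need its length below, and we iterate
        # over it more than once.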
        dataIds = list(dataIds)
        # Insert into the static dataset table, generating autoincrement
        # dataset_id values.
        with self._db.transaction():
            datasetIds = self._db.insert(self._static.dataset, *([staticRow]*len(dataIds)),
                                         returnIds=True)
            assert datasetIds is not None
            # Combine the generated dataset_id values and data ID fields to
            # form rows to be inserted into the tags table.
            protoTagsRow = {
                "dataset_type_id": self._dataset_type_id,
                self._collections.getCollectionForeignKeyName(): run.key,
            }
            tagsRows = [
                dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
                for dataId, dataset_id in zip(dataIds, datasetIds)
            ]
            # Insert those rows into the tags table.  This is where we'll
            # get any unique constraint violations.
            self._db.insert(self._tags, *tagsRows)
        for dataId, datasetId in zip(dataIds, datasetIds):
            yield DatasetRef(
                datasetType=self.datasetType,
                dataId=dataId,
                id=datasetId,
                run=run.name,
            )

    def find(self, collection: CollectionRecord, dataId: DataCoordinate,
             timespan: Optional[Timespan] = None) -> Optional[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
        assert dataId.graph == self.datasetType.dimensions
        if collection.type is CollectionType.CALIBRATION and timespan is None:
            raise TypeError(f"Cannot search for dataset in CALIBRATION collection {collection.name} "
                            f"without an input timespan.")
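        # Delegate query construction to `select`, requesting the dataset ID
        # and run key as result columns.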
        sql = self.select(collection=collection, dataId=dataId, id=SimpleQuery.Select,
                          run=SimpleQuery.Select, timespan=timespan).combine()
        results = self._db.query(sql)
        row = results.fetchone()
        if row is None:
            return None
        if collection.type is CollectionType.CALIBRATION:
            # For temporal calibration lookups (only!) our invariants do not
            # guarantee that the number of result rows is <= 1.
            # They would if `select` constrained the given timespan to be
            # _contained_ by the validity range in the self._calibs table,
            # instead of simply _overlapping_ it, because we do guarantee that
            # the validity ranges are disjoint for a particular dataset type,
            # collection, and data ID.  But using an overlap test and a check
            # for multiple result rows here allows us to provide a more useful
            # diagnostic, as well as allowing `select` to support more general
            # queries where multiple results are not an error.
            if results.fetchone() is not None:
                raise RuntimeError(
                    f"Multiple matches found for calibration lookup in {collection.name} for "
                    f"{self.datasetType.name} with {dataId} overlapping {timespan}. "
                )
        return DatasetRef(
            datasetType=self.datasetType,
            dataId=dataId,
            id=row["id"],
            run=self._collections[row[self._runKeyColumn]].name
        )

    def delete(self, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        # Only delete from common dataset table; ON DELETE foreign key clauses
        # will handle the rest.
        self._db.delete(
            self._static.dataset,
            ["id"],
            *[{"id": dataset.getCheckedId()} for dataset in datasets],
        )

    def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot associate into collection '{collection}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            rows.append(row)
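        # Use replace rather than insert so that tagging a dataset whose data
        # ID is already tagged in this collection overwrites the existing row
        # instead of raising a unique-constraint error.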
        self._db.replace(self._tags, *rows)

    def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot disassociate from collection '{collection}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        rows = [
            {
                "dataset_id": dataset.getCheckedId(),
                self._collections.getCollectionForeignKeyName(): collection.key
            }
            for dataset in datasets
        ]
        self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()],
                        *rows)

    def _buildCalibOverlapQuery(self, collection: CollectionRecord,
                                dataIds: Optional[DataCoordinateSet],
                                timespan: Timespan) -> SimpleQuery:
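        """Build a SELECT query over the calibs table for rows in the given
        collection that match the given data IDs and overlap the given
        timespan.

        This is a helper for `certify` and `decertify`.
        """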
        assert self._calibs is not None
        # Start by building a SELECT query for any rows that would overlap
        # this one.
        query = SimpleQuery()
        query.join(self._calibs)
        # Add a WHERE clause matching the dataset type and collection.
        query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id)
        query.where.append(
            self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key
        )
        # Add a WHERE clause matching any of the given data IDs.
        if dataIds is not None:
            dataIds.constrain(
                query,
                lambda name: self._calibs.columns[name],  # type: ignore
            )
        # Add WHERE clause for timespan overlaps.
        tsRepr = self._db.getTimespanRepresentation()
        query.where.append(tsRepr.fromSelectable(self._calibs).overlaps(timespan))
        return query

    def certify(self, collection: CollectionRecord, datasets: Iterable[DatasetRef],
                timespan: Timespan) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if self._calibs is None:
            raise TypeError(f"Cannot certify datasets of type {self.datasetType.name}, for which "
                            f"DatasetType.isCalibration() is False.")
        if collection.type is not CollectionType.CALIBRATION:
            raise TypeError(f"Cannot certify into collection '{collection}' "
                            f"of type {collection.type.name}; must be CALIBRATION.")
        tsRepr = self._db.getTimespanRepresentation()
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
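        # We accumulate data IDs only for the manual overlap check below,
        # which is needed only when the database cannot enforce an exclusion
        # constraint natively.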
        dataIds: Optional[Set[DataCoordinate]] = set() if not tsRepr.hasExclusionConstraint() else None
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            tsRepr.update(timespan, result=row)
            rows.append(row)
            if dataIds is not None:
                dataIds.add(dataset.dataId)
        if tsRepr.hasExclusionConstraint():
            # Rely on the database constraint to enforce invariants; we just
            # translate the exception into a ConflictingDefinitionError for
            # consistency across database engines.
            try:
                self._db.insert(self._calibs, *rows)
            except sqlalchemy.exc.IntegrityError as err:
                raise ConflictingDefinitionError(
                    f"Validity range conflict certifying datasets of type {self.datasetType.name} "
                    f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
                ) from err
        else:
            # We have to implement the exclusion constraint ourselves.
            # Start by building a SELECT query for any rows that would overlap
            # this one.
            query = self._buildCalibOverlapQuery(
                collection,
                DataCoordinateSet(dataIds, graph=self.datasetType.dimensions),  # type: ignore
                timespan
            )
            query.columns.append(sqlalchemy.sql.func.count())
            sql = query.combine()
            # Acquire a table lock to ensure there are no concurrent writes
            # that could invalidate our check before we finish the inserts.
            # We use a SAVEPOINT in case there is an outer transaction that a
            # failure here should not roll back.
            with self._db.transaction(lock=[self._calibs], savepoint=True):
                # Run the check SELECT query.
                conflicting = self._db.query(sql).scalar()
                if conflicting > 0:
                    raise ConflictingDefinitionError(
                        f"{conflicting} validity range conflicts certifying datasets of type "
                        f"{self.datasetType.name} into {collection.name} for range "
                        f"[{timespan.begin}, {timespan.end})."
                    )
                # Proceed with the insert.
                self._db.insert(self._calibs, *rows)

    def decertify(self, collection: CollectionRecord, timespan: Timespan, *,
                  dataIds: Optional[Iterable[DataCoordinate]] = None) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if self._calibs is None:
            raise TypeError(f"Cannot decertify datasets of type {self.datasetType.name}, for which "
                            f"DatasetType.isCalibration() is False.")
        if collection.type is not CollectionType.CALIBRATION:
            raise TypeError(f"Cannot decertify from collection '{collection}' "
                            f"of type {collection.type.name}; must be CALIBRATION.")
        tsRepr = self._db.getTimespanRepresentation()
        # Construct a SELECT query to find all rows that overlap our inputs.
        dataIdSet: Optional[DataCoordinateSet]
        if dataIds is not None:
            dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions)
        else:
            dataIdSet = None
        query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan)
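        # Select all columns back so we can both delete the overlapping rows
        # (by id) and reconstruct the non-overlapping remainders to re-insert.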
        query.columns.extend(self._calibs.columns)
        sql = query.combine()
        # Set up collections to populate with the rows we'll want to modify.
        # The insert rows will have the same values for collection and
        # dataset type.
        protoInsertRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rowsToDelete = []
        rowsToInsert = []
        # Acquire a table lock to ensure there are no concurrent writes
        # between the SELECT and the DELETE and INSERT queries based on it.
        with self._db.transaction(lock=[self._calibs], savepoint=True):
            for row in self._db.query(sql):
                rowsToDelete.append({"id": row["id"]})
                # Construct the insert row(s) by copying the prototype row,
                # then adding the dimension column values, then adding what's
                # left of the timespan from that row after we subtract the
                # given timespan.
                newInsertRow = protoInsertRow.copy()
                newInsertRow["dataset_id"] = row["dataset_id"]
                for name in self.datasetType.dimensions.required.names:
                    newInsertRow[name] = row[name]
                rowTimespan = tsRepr.extract(row)
                assert rowTimespan is not None, "Field should have a NOT NULL constraint."
                for diffTimespan in rowTimespan.difference(timespan):
                    rowsToInsert.append(tsRepr.update(diffTimespan, result=newInsertRow.copy()))
            # Run the DELETE and INSERT queries.
            self._db.delete(self._calibs, ["id"], *rowsToDelete)
            self._db.insert(self._calibs, *rowsToInsert)

    def select(self, collection: CollectionRecord,
               dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
               id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select,
               run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
               timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
               ) -> SimpleQuery:
        # Docstring inherited from DatasetRecordStorage.
        assert collection.type is not CollectionType.CHAINED
        query = SimpleQuery()
        # We always include the _static.dataset table, and we can always get
        # the id and run fields from that; passing them as kwargs here tells
        # SimpleQuery to handle them whether they're constraints or results.
        # We always constrain the dataset_type_id here as well.
        query.join(
            self._static.dataset,
            id=id,
            dataset_type_id=self._dataset_type_id,
            **{self._runKeyColumn: run}
        )
        # If and only if the collection is a RUN, we constrain it in the
        # static table (and also in the tags or calibs table below).
        if collection.type is CollectionType.RUN:
            query.where.append(self._static.dataset.columns[self._runKeyColumn]
                               == collection.key)
        # We get or constrain the data ID from the tags/calibs table, but
        # that's multiple columns, not one, so we need to transform the one
        # Select.Or argument into a dictionary of them.
        kwargs: Dict[str, Any]
        if dataId is SimpleQuery.Select:
            kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
        else:
            kwargs = dict(dataId.byName())
        # We always constrain (never retrieve) the collection from the tags
        # table.
        kwargs[self._collections.getCollectionForeignKeyName()] = collection.key
        # And now we finally join in the tags or calibs table.
        if collection.type is CollectionType.CALIBRATION:
            assert self._calibs is not None, \
                "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
            tsRepr = self._db.getTimespanRepresentation()
            # Add the timespan column(s) to the result columns, or constrain
            # the timespan via an overlap condition.
            if timespan is SimpleQuery.Select:
                kwargs.update({k: SimpleQuery.Select for k in tsRepr.getFieldNames()})
            elif timespan is not None:
                query.where.append(tsRepr.fromSelectable(self._calibs).overlaps(timespan))
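            # (If timespan is None, the validity range is neither constrained
            # nor returned.)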
            query.join(
                self._calibs,
                onclause=(self._static.dataset.columns.id == self._calibs.columns.dataset_id),
                **kwargs
            )
        else:
            query.join(
                self._tags,
                onclause=(self._static.dataset.columns.id == self._tags.columns.dataset_id),
                **kwargs
            )
        return query

    def getDataId(self, id: int) -> DataCoordinate:
        # Docstring inherited from DatasetRecordStorage.
        # This query could return multiple rows (one for each tagged
        # collection the dataset is in, plus one for its run collection),
        # and we don't care which of those we get.
        sql = self._tags.select().where(
            sqlalchemy.sql.and_(
                self._tags.columns.dataset_id == id,
                self._tags.columns.dataset_type_id == self._dataset_type_id
            )
        ).limit(1)
        row = self._db.query(sql).fetchone()
        assert row is not None, "Should be guaranteed by caller and foreign key constraints."
        return DataCoordinate.standardize(
            {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
            graph=self.datasetType.dimensions
        )