Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py : 89%

from __future__ import annotations

__all__ = ("ByDimensionsDatasetRecordStorage",)

from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Optional,
    Set,
    TYPE_CHECKING,
)

import sqlalchemy

from lsst.daf.butler import (
    CollectionType,
    DataCoordinate,
    DataCoordinateSet,
    DatasetRef,
    DatasetType,
    SimpleQuery,
    Timespan,
)
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.daf.butler.registry.interfaces import DatasetRecordStorage

if TYPE_CHECKING:
    from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
    from .tables import StaticDatasetTablesTuple

class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
    """Dataset record storage implementation paired with
    `ByDimensionsDatasetRecordStorageManager`; see that class for more
    information.

    Instances of this class should never be constructed directly; use
    `DatasetRecordStorageManager.register` instead.
    """
    def __init__(self, *, datasetType: DatasetType,
                 db: Database,
                 dataset_type_id: int,
                 collections: CollectionManager,
                 static: StaticDatasetTablesTuple,
                 tags: sqlalchemy.schema.Table,
                 calibs: Optional[sqlalchemy.schema.Table]):
        super().__init__(datasetType=datasetType)
        self._dataset_type_id = dataset_type_id
        self._db = db
        self._collections = collections
        self._static = static
        self._tags = tags
        self._calibs = calibs
        self._runKeyColumn = collections.getRunForeignKeyName()

    def insert(self, run: RunRecord, dataIds: Iterable[DataCoordinate]) -> Iterator[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
        staticRow = {
            "dataset_type_id": self._dataset_type_id,
            self._runKeyColumn: run.key,
        }
        dataIds = list(dataIds)
        # Insert into the static dataset table, generating autoincrement
        # dataset_id values.
        with self._db.transaction():
            datasetIds = self._db.insert(self._static.dataset, *([staticRow]*len(dataIds)),
                                         returnIds=True)
            assert datasetIds is not None
            # Combine the generated dataset_id values and data ID fields to
            # form rows to be inserted into the tags table.
            protoTagsRow = {
                "dataset_type_id": self._dataset_type_id,
                self._collections.getCollectionForeignKeyName(): run.key,
            }
            tagsRows = [
                dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
                for dataId, dataset_id in zip(dataIds, datasetIds)
            ]
            # Insert those rows into the tags table.  This is where we'll
            # get any unique constraint violations.
            self._db.insert(self._tags, *tagsRows)
        for dataId, datasetId in zip(dataIds, datasetIds):
            yield DatasetRef(
                datasetType=self.datasetType,
                dataId=dataId,
                id=datasetId,
                run=run.name,
            )
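
    # A minimal usage sketch (hypothetical names, not part of this module),
    # assuming ``storage`` is an instance of this class and ``runRecord`` is
    # the RunRecord for the destination RUN collection:
    #
    #     refs = list(storage.insert(runRecord, [dataId1, dataId2]))
    #     # Each returned DatasetRef carries an autoincrement integer ``id``
    #     # and has ``run == runRecord.name``.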

    def find(self, collection: CollectionRecord, dataId: DataCoordinate,
             timespan: Optional[Timespan] = None) -> Optional[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
        assert dataId.graph == self.datasetType.dimensions
        if collection.type is CollectionType.CALIBRATION and timespan is None:
            raise TypeError(f"Cannot search for dataset in CALIBRATION collection {collection.name} "
                            f"without an input timespan.")
        sql = self.select(collection=collection, dataId=dataId, id=SimpleQuery.Select,
                          run=SimpleQuery.Select, timespan=timespan).combine()
        results = self._db.query(sql)
        row = results.fetchone()
        if row is None:
            return None
        if collection.type is CollectionType.CALIBRATION:
            # For temporal calibration lookups (only!) our invariants do not
            # guarantee that the number of result rows is <= 1.
            # They would if `select` constrained the given timespan to be
            # _contained_ by the validity range in the self._calibs table,
            # instead of simply _overlapping_ it, because we do guarantee that
            # the validity ranges are disjoint for a particular dataset type,
            # collection, and data ID.  But using an overlap test and a check
            # for multiple result rows here allows us to provide a more useful
            # diagnostic, as well as allowing `select` to support more general
            # queries where multiple results are not an error.
            if results.fetchone() is not None:
                raise RuntimeError(
                    f"Multiple matches found for calibration lookup in {collection.name} for "
                    f"{self.datasetType.name} with {dataId} overlapping {timespan}. "
                )
        return DatasetRef(
            datasetType=self.datasetType,
            dataId=dataId,
            id=row["id"],
            run=self._collections[row[self._runKeyColumn]].name
        )
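
    # A minimal lookup sketch (hypothetical names), assuming ``calibRecord``
    # is a CALIBRATION CollectionRecord; the timespan argument is required
    # for calibration collections, per the check above:
    #
    #     ref = storage.find(calibRecord, dataId,
    #                        timespan=Timespan(begin=t0, end=t1))
    #     if ref is None:
    #         ...  # nothing certified for this data ID overlaps [t0, t1)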

    def delete(self, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        # Only delete from common dataset table; ON DELETE foreign key clauses
        # will handle the rest.
        self._db.delete(
            self._static.dataset,
            ["id"],
            *[{"id": dataset.getCheckedId()} for dataset in datasets],
        )

    def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot associate into collection '{collection}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            rows.append(row)
        self._db.replace(self._tags, *rows)

    def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot disassociate from collection '{collection}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        rows = [
            {
                "dataset_id": dataset.getCheckedId(),
                self._collections.getCollectionForeignKeyName(): collection.key
            }
            for dataset in datasets
        ]
        self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()],
                        *rows)
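
    # Tagging sketch (hypothetical names): ``associate`` adds pointers to
    # existing datasets into a TAGGED collection and ``disassociate`` removes
    # them again, without touching the dataset rows themselves:
    #
    #     storage.associate(taggedRecord, refs)
    #     storage.disassociate(taggedRecord, refs)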

    def _buildCalibOverlapQuery(self, collection: CollectionRecord,
                                dataIds: Optional[DataCoordinateSet],
                                timespan: Timespan) -> SimpleQuery:
        assert self._calibs is not None
        # Start by building a SELECT query for any rows that would overlap
        # this one.
        query = SimpleQuery()
        query.join(self._calibs)
        # Add a WHERE clause matching the dataset type and collection.
        query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id)
        query.where.append(
            self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key
        )
        # Add a WHERE clause matching any of the given data IDs.
        if dataIds is not None:
            dataIds.constrain(
                query,
                lambda name: self._calibs.columns[name],  # type: ignore
            )
        # Add a WHERE clause for timespan overlaps.
        tsRepr = self._db.getTimespanRepresentation()
        query.where.append(tsRepr.fromSelectable(self._calibs).overlaps(timespan))
        return query

    def certify(self, collection: CollectionRecord, datasets: Iterable[DatasetRef],
                timespan: Timespan) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if self._calibs is None:
            raise TypeError(f"Cannot certify datasets of type {self.datasetType.name}, for which "
                            f"DatasetType.isCalibration() is False.")
        if collection.type is not CollectionType.CALIBRATION:
            raise TypeError(f"Cannot certify into collection '{collection}' "
                            f"of type {collection.type.name}; must be CALIBRATION.")
        tsRepr = self._db.getTimespanRepresentation()
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
        dataIds: Optional[Set[DataCoordinate]] = set() if not tsRepr.hasExclusionConstraint() else None
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            tsRepr.update(timespan, result=row)
            rows.append(row)
            if dataIds is not None:
                dataIds.add(dataset.dataId)
        if tsRepr.hasExclusionConstraint():
            # Rely on the database constraint to enforce invariants; we just
            # reraise the exception for consistency across DB engines.
            try:
                self._db.insert(self._calibs, *rows)
            except sqlalchemy.exc.IntegrityError as err:
                raise ConflictingDefinitionError(
                    f"Validity range conflict certifying datasets of type {self.datasetType.name} "
                    f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
                ) from err
        else:
            # Have to implement the exclusion constraint ourselves.
            # Start by building a SELECT query for any rows that would overlap
            # this one.
            query = self._buildCalibOverlapQuery(
                collection,
                DataCoordinateSet(dataIds, graph=self.datasetType.dimensions),  # type: ignore
                timespan
            )
            query.columns.append(sqlalchemy.sql.func.count())
            sql = query.combine()
            # Acquire a table lock to ensure there are no concurrent writes
            # that could invalidate our check before we finish the inserts.
            # We use a SAVEPOINT in case there is an outer transaction that a
            # failure here should not roll back.
            with self._db.transaction(lock=[self._calibs], savepoint=True):
                # Run the check SELECT query.
                conflicting = self._db.query(sql).scalar()
                if conflicting > 0:
                    raise ConflictingDefinitionError(
                        f"{conflicting} validity range conflicts certifying datasets of type "
                        f"{self.datasetType.name} into {collection.name} for range "
                        f"[{timespan.begin}, {timespan.end})."
                    )
                # Proceed with the insert.
                self._db.insert(self._calibs, *rows)
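
    # Certification sketch (hypothetical names): attach a validity range to
    # already-inserted refs in a CALIBRATION collection; a range that overlaps
    # an existing one for the same dataset type, collection, and data ID
    # raises ConflictingDefinitionError:
    #
    #     storage.certify(calibRecord, refs, Timespan(begin=t0, end=t1))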

    def decertify(self, collection: CollectionRecord, timespan: Timespan, *,
                  dataIds: Optional[Iterable[DataCoordinate]] = None) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if self._calibs is None:
            raise TypeError(f"Cannot decertify datasets of type {self.datasetType.name}, for which "
                            f"DatasetType.isCalibration() is False.")
        if collection.type is not CollectionType.CALIBRATION:
            raise TypeError(f"Cannot decertify from collection '{collection}' "
                            f"of type {collection.type.name}; must be CALIBRATION.")
        tsRepr = self._db.getTimespanRepresentation()
        # Construct a SELECT query to find all rows that overlap our inputs.
        dataIdSet: Optional[DataCoordinateSet]
        if dataIds is not None:
            dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions)
        else:
            dataIdSet = None
        query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan)
        query.columns.extend(self._calibs.columns)
        sql = query.combine()
        # Set up collections to populate with the rows we'll want to modify.
        # The insert rows will have the same values for collection and
        # dataset type.
        protoInsertRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rowsToDelete = []
        rowsToInsert = []
        # Acquire a table lock to ensure there are no concurrent writes
        # between the SELECT and the DELETE and INSERT queries based on it.
        with self._db.transaction(lock=[self._calibs], savepoint=True):
            for row in self._db.query(sql):
                rowsToDelete.append({"id": row["id"]})
                # Construct the insert row(s) by copying the prototype row,
                # then adding the dimension column values, then adding what's
                # left of the timespan from that row after we subtract the
                # given timespan.
                newInsertRow = protoInsertRow.copy()
                newInsertRow["dataset_id"] = row["dataset_id"]
                for name in self.datasetType.dimensions.required.names:
                    newInsertRow[name] = row[name]
                rowTimespan = tsRepr.extract(row)
                assert rowTimespan is not None, "Field should have a NOT NULL constraint."
                for diffTimespan in rowTimespan.difference(timespan):
                    rowsToInsert.append(tsRepr.update(diffTimespan, result=newInsertRow.copy()))
            # Run the DELETE and INSERT queries.
            self._db.delete(self._calibs, ["id"], *rowsToDelete)
            self._db.insert(self._calibs, *rowsToInsert)
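
    # Worked example of the timespan arithmetic above: decertifying [t1, t2)
    # from a row whose validity range is [t0, t3) (with t0 < t1 < t2 < t3)
    # deletes that row and inserts two replacements covering [t0, t1) and
    # [t2, t3), because
    # Timespan(begin=t0, end=t3).difference(Timespan(begin=t1, end=t2))
    # yields exactly those two pieces.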

    def select(self, collection: CollectionRecord,
               dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
               id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select,
               run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
               timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
               ) -> SimpleQuery:
        # Docstring inherited from DatasetRecordStorage.
        assert collection.type is not CollectionType.CHAINED
        query = SimpleQuery()
        # We always include the _static.dataset table, and we can always get
        # the id and run fields from that; passing them as kwargs here tells
        # SimpleQuery to handle them whether they're constraints or results.
        # We always constrain the dataset_type_id here as well.
        query.join(
            self._static.dataset,
            id=id,
            dataset_type_id=self._dataset_type_id,
            **{self._runKeyColumn: run}
        )
        # If and only if the collection is a RUN, we constrain it in the
        # static table (and also the tags or calibs table below).
        if collection.type is CollectionType.RUN:
            query.where.append(self._static.dataset.columns[self._runKeyColumn]
                               == collection.key)
        # We get or constrain the data ID from the tags/calibs table, but
        # that's multiple columns, not one, so we need to transform the one
        # Select.Or argument into a dictionary of them.
        kwargs: Dict[str, Any]
        if dataId is SimpleQuery.Select:
            kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
        else:
            kwargs = dict(dataId.byName())
        # We always constrain (never retrieve) the collection from the tags
        # table.
        kwargs[self._collections.getCollectionForeignKeyName()] = collection.key
        # And now we finally join in the tags or calibs table.
        if collection.type is CollectionType.CALIBRATION:
            assert self._calibs is not None, \
                "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
            tsRepr = self._db.getTimespanRepresentation()
            # Add the timespan column(s) to the result columns, or constrain
            # the timespan via an overlap condition.
            if timespan is SimpleQuery.Select:
                kwargs.update({k: SimpleQuery.Select for k in tsRepr.getFieldNames()})
            elif timespan is not None:
                query.where.append(tsRepr.fromSelectable(self._calibs).overlaps(timespan))
            query.join(
                self._calibs,
                onclause=(self._static.dataset.columns.id == self._calibs.columns.dataset_id),
                **kwargs
            )
        else:
            query.join(
                self._tags,
                onclause=(self._static.dataset.columns.id == self._tags.columns.dataset_id),
                **kwargs
            )
        return query

    def getDataId(self, id: int) -> DataCoordinate:
        # Docstring inherited from DatasetRecordStorage.
        # This query could return multiple rows (one for each tagged collection
        # the dataset is in, plus one for its run collection), and we don't
        # care which of those we get.
        sql = self._tags.select().where(
            sqlalchemy.sql.and_(
                self._tags.columns.dataset_id == id,
                self._tags.columns.dataset_type_id == self._dataset_type_id
            )
        ).limit(1)
        row = self._db.query(sql).fetchone()
        assert row is not None, "Should be guaranteed by caller and foreign key constraints."
        return DataCoordinate.standardize(
            {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
            graph=self.datasetType.dimensions
        )
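
# A rough end-to-end sketch (hypothetical variable names; the manager and the
# collection/run records would normally be supplied by Registry internals, so
# treat the exact manager call as an assumption rather than documented API):
#
#     storage = manager.find(datasetType.name)  # ByDimensionsDatasetRecordStorageManager
#     refs = list(storage.insert(runRecord, dataIds))
#     storage.associate(taggedRecord, refs)
#     found = storage.find(taggedRecord, dataIds[0])
#     assert found is not None and found.id == refs[0].id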