Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 88%
301 statements
coverage.py v6.4.4, created at 2022-08-30 02:26 -0700
1 from __future__ import annotations
3 __all__ = ("ByDimensionsDatasetRecordStorage",)
5 import uuid
6 from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set
8 import sqlalchemy
9 from lsst.daf.butler import (
10 CollectionType,
11 DataCoordinate,
12 DataCoordinateSet,
13 DatasetId,
14 DatasetRef,
15 DatasetType,
16 SimpleQuery,
17 Timespan,
18 ddl,
19)
20 from lsst.daf.butler.registry import (
21 CollectionTypeError,
22 ConflictingDefinitionError,
23 UnsupportedIdGeneratorError,
24)
25 from lsst.daf.butler.registry.interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage
27 from ...summaries import GovernorDimensionRestriction
28 from .tables import makeTagTableSpec
30 if TYPE_CHECKING:  # 30 ↛ 31: condition on line 30 was never true
31 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
32 from .summaries import CollectionSummaryManager
33 from .tables import StaticDatasetTablesTuple
36 class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
37 """Dataset record storage implementation paired with
38 `ByDimensionsDatasetRecordStorageManager`; see that class for more
39 information.
41 Instances of this class should never be constructed directly; use
42 `DatasetRecordStorageManager.register` instead.
43 """
45 def __init__(
46 self,
47 *,
48 datasetType: DatasetType,
49 db: Database,
50 dataset_type_id: int,
51 collections: CollectionManager,
52 static: StaticDatasetTablesTuple,
53 summaries: CollectionSummaryManager,
54 tags: sqlalchemy.schema.Table,
55 calibs: Optional[sqlalchemy.schema.Table],
56 ):
57 super().__init__(datasetType=datasetType)
58 self._dataset_type_id = dataset_type_id
59 self._db = db
60 self._collections = collections
61 self._static = static
62 self._summaries = summaries
63 self._tags = tags
64 self._calibs = calibs
65 self._runKeyColumn = collections.getRunForeignKeyName()
67 def find(
68 self, collection: CollectionRecord, dataId: DataCoordinate, timespan: Optional[Timespan] = None
69 ) -> Optional[DatasetRef]:
70 # Docstring inherited from DatasetRecordStorage.
71 assert dataId.graph == self.datasetType.dimensions
72 if collection.type is CollectionType.CALIBRATION and timespan is None:  # 72 ↛ 73: condition on line 72 was never true
73 raise TypeError(
74 f"Cannot search for dataset in CALIBRATION collection {collection.name} "
75 f"without an input timespan."
76 )
77 sql = self.select(
78 collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan
79 )
80 results = self._db.query(sql)
81 row = results.fetchone()
82 if row is None:
83 return None
84 if collection.type is CollectionType.CALIBRATION:
85 # For temporal calibration lookups (only!) our invariants do not
86 # guarantee that the number of result rows is <= 1.
87 # They would if `select` constrained the given timespan to be
88 # _contained_ by the validity range in the self._calibs table,
89 # instead of simply _overlapping_ it, because we do guarantee that
90 # the validity ranges are disjoint for a particular dataset type,
91 # collection, and data ID. But using an overlap test and a check
92 # for multiple result rows here allows us to provide a more useful
93 # diagnostic, as well as allowing `select` to support more general
94 # queries where multiple results are not an error.
95 if results.fetchone() is not None:
96 raise RuntimeError(
97 f"Multiple matches found for calibration lookup in {collection.name} for "
98 f"{self.datasetType.name} with {dataId} overlapping {timespan}. "
99 )
100 return DatasetRef(
101 datasetType=self.datasetType,
102 dataId=dataId,
103 id=row.id,
104 run=self._collections[row._mapping[self._runKeyColumn]].name,
105 )
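# --- Editor's illustrative sketch (not part of the original module) ----------
# The comment in `find` above explains that calibration lookups use an
# *overlap* test against validity ranges rather than a *containment* test, so
# a lookup timespan that straddles two disjoint validity ranges can legally
# match more than one row.  A minimal, self-contained illustration with plain
# half-open integer intervals (hypothetical helpers, not the real `Timespan`):

def _overlaps(a, b):
    """Half-open intervals [a0, a1) and [b0, b1) overlap."""
    return a[0] < b[1] and b[0] < a[1]

def _contains(outer, inner):
    """`inner` lies entirely within `outer`."""
    return outer[0] <= inner[0] and inner[1] <= outer[1]

validity_ranges = [(0, 10), (10, 20)]  # disjoint, as the invariants guarantee
lookup = (5, 15)                       # straddles the boundary between them
assert sum(_overlaps(v, lookup) for v in validity_ranges) == 2  # two matches
assert sum(_contains(v, lookup) for v in validity_ranges) == 0  # none would match
# ------------------------------------------------------------------------------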
107 def delete(self, datasets: Iterable[DatasetRef]) -> None:
108 # Docstring inherited from DatasetRecordStorage.
109 # Only delete from common dataset table; ON DELETE foreign key clauses
110 # will handle the rest.
111 self._db.delete(
112 self._static.dataset,
113 ["id"],
114 *[{"id": dataset.getCheckedId()} for dataset in datasets],
115 )
117 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
118 # Docstring inherited from DatasetRecordStorage.
119 if collection.type is not CollectionType.TAGGED:  # 119 ↛ 120: condition on line 119 was never true
120 raise TypeError(
121 f"Cannot associate into collection '{collection.name}' "
122 f"of type {collection.type.name}; must be TAGGED."
123 )
124 protoRow = {
125 self._collections.getCollectionForeignKeyName(): collection.key,
126 "dataset_type_id": self._dataset_type_id,
127 }
128 rows = []
129 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
130 for dataset in datasets:
131 row = dict(protoRow, dataset_id=dataset.getCheckedId())
132 for dimension, value in dataset.dataId.items():
133 row[dimension.name] = value
134 governorValues.update_extract(dataset.dataId)
135 rows.append(row)
136 # Update the summary tables for this collection in case this is the
137 # first time this dataset type or these governor values will be
138 # inserted there.
139 self._summaries.update(collection, self.datasetType, self._dataset_type_id, governorValues)
140 # Update the tag table itself.
141 self._db.replace(self._tags, *rows)
143 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
144 # Docstring inherited from DatasetRecordStorage.
145 if collection.type is not CollectionType.TAGGED:  # 145 ↛ 146: condition on line 145 was never true
146 raise TypeError(
147 f"Cannot disassociate from collection '{collection.name}' "
148 f"of type {collection.type.name}; must be TAGGED."
149 )
150 rows = [
151 {
152 "dataset_id": dataset.getCheckedId(),
153 self._collections.getCollectionForeignKeyName(): collection.key,
154 }
155 for dataset in datasets
156 ]
157 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
159 def _buildCalibOverlapQuery(
160 self, collection: CollectionRecord, dataIds: Optional[DataCoordinateSet], timespan: Timespan
161 ) -> SimpleQuery:
162 assert self._calibs is not None
163 # Start by building a SELECT query for any rows that would overlap
164 # this one.
165 query = SimpleQuery()
166 query.join(self._calibs)
167 # Add a WHERE clause matching the dataset type and collection.
168 query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id)
169 query.where.append(
170 self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key
171 )
172 # Add a WHERE clause matching any of the given data IDs.
173 if dataIds is not None:
174 dataIds.constrain(
175 query,
176 lambda name: self._calibs.columns[name], # type: ignore
177 )
178 # Add WHERE clause for timespan overlaps.
179 TimespanReprClass = self._db.getTimespanRepresentation()
180 query.where.append(
181 TimespanReprClass.fromSelectable(self._calibs).overlaps(TimespanReprClass.fromLiteral(timespan))
182 )
183 return query
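# --- Editor's illustrative sketch (not part of the original module) ----------
# `_buildCalibOverlapQuery` delegates the overlap predicate to the database's
# timespan representation.  For a representation that stores a half-open
# [begin, end) pair in two columns, the generated condition is equivalent to
# the generic SQLAlchemy pattern below (made-up table and column names):

import sqlalchemy as sa

_meta = sa.MetaData()
_calibs_demo = sa.Table(
    "calibs_demo",
    _meta,
    sa.Column("dataset_id", sa.Integer),
    sa.Column("ts_begin", sa.Integer),
    sa.Column("ts_end", sa.Integer),
)

def _overlap_clause(table, begin, end):
    # Two half-open intervals [a, b) and [c, d) overlap iff a < d and c < b.
    return sa.and_(table.c.ts_begin < end, table.c.ts_end > begin)

_stmt = (
    sa.select(sa.func.count())
    .select_from(_calibs_demo)
    .where(_overlap_clause(_calibs_demo, 5, 15))
)
print(_stmt)  # renders roughly: SELECT count(*) FROM calibs_demo WHERE ts_begin < ? AND ts_end > ?
# ------------------------------------------------------------------------------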
185 def certify(
186 self, collection: CollectionRecord, datasets: Iterable[DatasetRef], timespan: Timespan
187 ) -> None:
188 # Docstring inherited from DatasetRecordStorage.
189 if self._calibs is None:  # 189 ↛ 190: condition on line 189 was never true
190 raise CollectionTypeError(
191 f"Cannot certify datasets of type {self.datasetType.name}, for which "
192 f"DatasetType.isCalibration() is False."
193 )
194 if collection.type is not CollectionType.CALIBRATION:  # 194 ↛ 195: condition on line 194 was never true
195 raise CollectionTypeError(
196 f"Cannot certify into collection '{collection.name}' "
197 f"of type {collection.type.name}; must be CALIBRATION."
198 )
199 TimespanReprClass = self._db.getTimespanRepresentation()
200 protoRow = {
201 self._collections.getCollectionForeignKeyName(): collection.key,
202 "dataset_type_id": self._dataset_type_id,
203 }
204 rows = []
205 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
206 dataIds: Optional[Set[DataCoordinate]] = (
207 set() if not TimespanReprClass.hasExclusionConstraint() else None
208 )
209 for dataset in datasets:
210 row = dict(protoRow, dataset_id=dataset.getCheckedId())
211 for dimension, value in dataset.dataId.items():
212 row[dimension.name] = value
213 TimespanReprClass.update(timespan, result=row)
214 governorValues.update_extract(dataset.dataId)
215 rows.append(row)
216 if dataIds is not None:  # 216 ↛ 209: condition on line 216 was never false
217 dataIds.add(dataset.dataId)
218 # Update the summary tables for this collection in case this is the
219 # first time this dataset type or these governor values will be
220 # inserted there.
221 self._summaries.update(collection, self.datasetType, self._dataset_type_id, governorValues)
222 # Update the association table itself.
223 if TimespanReprClass.hasExclusionConstraint():  # 223 ↛ 226: condition on line 223 was never true
224 # Rely on database constraint to enforce invariants; we just
225 # reraise the exception for consistency across DB engines.
226 try:
227 self._db.insert(self._calibs, *rows)
228 except sqlalchemy.exc.IntegrityError as err:
229 raise ConflictingDefinitionError(
230 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
231 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
232 ) from err
233 else:
234 # Have to implement exclusion constraint ourselves.
235 # Start by building a SELECT query for any rows that would overlap
236 # this one.
237 query = self._buildCalibOverlapQuery(
238 collection,
239 DataCoordinateSet(dataIds, graph=self.datasetType.dimensions), # type: ignore
240 timespan,
241 )
242 query.columns.append(sqlalchemy.sql.func.count())
243 sql = query.combine()
244 # Acquire a table lock to ensure there are no concurrent writes that
245 # could invalidate our check before we finish the inserts. We
246 # use a SAVEPOINT in case there is an outer transaction that a
247 # failure here should not roll back.
248 with self._db.transaction(lock=[self._calibs], savepoint=True):
249 # Run the check SELECT query.
250 conflicting = self._db.query(sql).scalar()
251 if conflicting > 0:
252 raise ConflictingDefinitionError(
253 f"{conflicting} validity range conflicts certifying datasets of type "
254 f"{self.datasetType.name} into {collection.name} for range "
255 f"[{timespan.begin}, {timespan.end})."
256 )
257 # Proceed with the insert.
258 self._db.insert(self._calibs, *rows)
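# --- Editor's illustrative sketch (not part of the original module) ----------
# When the database cannot enforce a native exclusion constraint, `certify`
# emulates one: lock the table, count overlapping rows, and insert only if the
# count is zero.  A minimal in-memory version of that check-then-insert
# pattern (SQLite, simplified schema; not the real `_calibs` table):

import sqlalchemy as sa

_engine = sa.create_engine("sqlite+pysqlite:///:memory:")
_meta = sa.MetaData()
_calibs = sa.Table(
    "calibs",
    _meta,
    sa.Column("dataset_id", sa.Integer),
    sa.Column("ts_begin", sa.Integer),
    sa.Column("ts_end", sa.Integer),
)
_meta.create_all(_engine)

def certify_one(conn, dataset_id, begin, end):
    overlap = sa.and_(_calibs.c.ts_begin < end, _calibs.c.ts_end > begin)
    conflicting = conn.execute(sa.select(sa.func.count()).where(overlap)).scalar()
    if conflicting:
        raise RuntimeError(f"{conflicting} validity range conflict(s) for [{begin}, {end})")
    conn.execute(sa.insert(_calibs).values(dataset_id=dataset_id, ts_begin=begin, ts_end=end))

with _engine.begin() as conn:      # the real code also takes a table lock and a SAVEPOINT
    certify_one(conn, 1, 0, 10)
    certify_one(conn, 1, 10, 20)   # adjacent but not overlapping: accepted
# ------------------------------------------------------------------------------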
260 def decertify(
261 self,
262 collection: CollectionRecord,
263 timespan: Timespan,
264 *,
265 dataIds: Optional[Iterable[DataCoordinate]] = None,
266 ) -> None:
267 # Docstring inherited from DatasetRecordStorage.
268 if self._calibs is None:  # 268 ↛ 269: condition on line 268 was never true
269 raise CollectionTypeError(
270 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
271 f"DatasetType.isCalibration() is False."
272 )
273 if collection.type is not CollectionType.CALIBRATION:  # 273 ↛ 274: condition on line 273 was never true
274 raise CollectionTypeError(
275 f"Cannot decertify from collection '{collection.name}' "
276 f"of type {collection.type.name}; must be CALIBRATION."
277 )
278 TimespanReprClass = self._db.getTimespanRepresentation()
279 # Construct a SELECT query to find all rows that overlap our inputs.
280 dataIdSet: Optional[DataCoordinateSet]
281 if dataIds is not None:
282 dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions)
283 else:
284 dataIdSet = None
285 query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan)
286 query.columns.extend(self._calibs.columns)
287 sql = query.combine()
288 # Set up collections to populate with the rows we'll want to modify.
289 # The insert rows will have the same values for collection and
290 # dataset type.
291 protoInsertRow = {
292 self._collections.getCollectionForeignKeyName(): collection.key,
293 "dataset_type_id": self._dataset_type_id,
294 }
295 rowsToDelete = []
296 rowsToInsert = []
297 # Acquire a table lock to ensure there are no concurrent writes
298 # between the SELECT and the DELETE and INSERT queries based on it.
299 with self._db.transaction(lock=[self._calibs], savepoint=True):
300 for row in self._db.query(sql).mappings():
301 rowsToDelete.append({"id": row["id"]})
302 # Construct the insert row(s) by copying the prototype row,
303 # then adding the dimension column values, then adding what's
304 # left of the timespan from that row after we subtract the
305 # given timespan.
306 newInsertRow = protoInsertRow.copy()
307 newInsertRow["dataset_id"] = row["dataset_id"]
308 for name in self.datasetType.dimensions.required.names:
309 newInsertRow[name] = row[name]
310 rowTimespan = TimespanReprClass.extract(row)
311 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
312 for diffTimespan in rowTimespan.difference(timespan):
313 rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()))
314 # Run the DELETE and INSERT queries.
315 self._db.delete(self._calibs, ["id"], *rowsToDelete)
316 self._db.insert(self._calibs, *rowsToInsert)
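# --- Editor's illustrative sketch (not part of the original module) ----------
# `decertify` replaces each overlapping row with whatever is left of its
# validity range after subtracting the decertified timespan.  With half-open
# integer intervals, that "difference" step looks like the pure-Python sketch
# below (the real code uses `Timespan.difference` on the extracted row):

def interval_difference(existing, removed):
    """Yield the 0, 1, or 2 pieces of `existing` not covered by `removed`."""
    (a, b), (c, d) = existing, removed
    if c > a:
        yield (a, min(b, c))   # piece before the removed window
    if d < b:
        yield (max(a, d), b)   # piece after the removed window

assert list(interval_difference((0, 20), (5, 10))) == [(0, 5), (10, 20)]
assert list(interval_difference((0, 20), (15, 30))) == [(0, 15)]
assert list(interval_difference((5, 10), (0, 20))) == []
# ------------------------------------------------------------------------------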
318 def select(
319 self,
320 *collections: CollectionRecord,
321 dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
322 id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select,
323 run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
324 timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
325 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None,
326 ) -> sqlalchemy.sql.Selectable:
327 # Docstring inherited from DatasetRecordStorage.
328 collection_types = {collection.type for collection in collections}
329 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
330 #
331 # There are two kinds of table in play here:
332 #
333 # - the static dataset table (with the dataset ID, dataset type ID,
334 # run ID/name, and ingest date);
335 #
336 # - the dynamic tags/calibs table (with the dataset ID, dataset
337 # type ID, collection ID/name, data ID, and possibly validity
338 # range).
339 #
340 # That means that we might want to return a query against either table
341 # or a JOIN of both, depending on which quantities the caller wants.
342 # But this method is documented/typed such that ``dataId`` is never
343 # `None` - i.e. we always constrain or retrieve the data ID. That
344 # means we'll always include the tags/calibs table and join in the
345 # static dataset table only if we need things from it that we can't get
346 # from the tags/calibs table.
347 #
348 # Note that it's important that we include a WHERE constraint on both
349 # tables for any column (e.g. dataset_type_id) that is in both when
350 # it's given explicitly; not doing so can prevent the query planner from
351 # using very important indexes. At present, we don't include those
352 # redundant columns in the JOIN ON expression, however, because the
353 # FOREIGN KEY (and its index) are defined only on dataset_id.
354 #
355 # We'll start by accumulating kwargs to pass to SimpleQuery.join when
356 # we bring in the tags/calibs table. We get the data ID or constrain
357 # it in the tags/calibs table(s), but that's multiple columns, not one,
358 # so we need to transform the one Select.Or argument into a dictionary
359 # of them.
360 kwargs: Dict[str, Any]
361 if dataId is SimpleQuery.Select:
362 kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
363 else:
364 kwargs = dict(dataId.byName())
365 # We always constrain (never retrieve) the dataset type in at least the
366 # tags/calibs table.
367 kwargs["dataset_type_id"] = self._dataset_type_id
368 # Join in the tags and/or calibs tables, turning those 'kwargs' entries
369 # into WHERE constraints or SELECT columns as appropriate.
370 if collection_types != {CollectionType.CALIBRATION}:
371 # We'll need a subquery for the tags table if any of the given
372 # collections are not a CALIBRATION collection. This intentionally
373 # also fires when the list of collections is empty as a way to
374 # create a dummy subquery that we know will fail.
375 tags_query = SimpleQuery()
376 tags_query.join(self._tags, **kwargs)
377 self._finish_single_select(
378 tags_query, self._tags, collections, id=id, run=run, ingestDate=ingestDate
379 )
380 else:
381 tags_query = None
382 if CollectionType.CALIBRATION in collection_types:
383 # If at least one collection is a CALIBRATION collection, we'll
384 # need a subquery for the calibs table, and could include the
385 # timespan as a result or constraint.
386 calibs_query = SimpleQuery()
387 assert (
388 self._calibs is not None
389 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
390 TimespanReprClass = self._db.getTimespanRepresentation()
391 # Add the timespan column(s) to the result columns, or constrain
392 # the timespan via an overlap condition.
393 if timespan is SimpleQuery.Select:
394 kwargs.update({k: SimpleQuery.Select for k in TimespanReprClass.getFieldNames()})
395 elif timespan is not None:
396 calibs_query.where.append(
397 TimespanReprClass.fromSelectable(self._calibs).overlaps(
398 TimespanReprClass.fromLiteral(timespan)
399 )
400 )
401 calibs_query.join(self._calibs, **kwargs)
402 self._finish_single_select(
403 calibs_query, self._calibs, collections, id=id, run=run, ingestDate=ingestDate
404 )
405 else:
406 calibs_query = None
407 if calibs_query is not None:
408 if tags_query is not None:
409 if timespan is not None:  # 409 ↛ 410: condition on line 409 was never true
410 raise TypeError(
411 "Cannot query for timespan when the collections include both calibration and "
412 "non-calibration collections."
413 )
414 return tags_query.combine().union(calibs_query.combine())
415 else:
416 return calibs_query.combine()
417 else:
418 assert tags_query is not None, "Earlier logic should guarantee at least one is not None."
419 return tags_query.combine()
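# --- Editor's illustrative sketch (not part of the original module) ----------
# `select` returns the tags subquery, the calibs subquery, or a UNION of both,
# depending on the collection types.  The UNION only works because both
# subqueries are built to expose the same column set.  A generic SQLAlchemy
# equivalent with toy tables (not the real schema):

import sqlalchemy as sa

_meta = sa.MetaData()
_tags_demo = sa.Table("tags_demo", _meta, sa.Column("dataset_id", sa.Integer), sa.Column("visit", sa.Integer))
_calibs_demo = sa.Table("calibs_demo", _meta, sa.Column("dataset_id", sa.Integer), sa.Column("visit", sa.Integer))

_tags_sel = sa.select(_tags_demo.c.dataset_id.label("id"), _tags_demo.c.visit)
_calibs_sel = sa.select(_calibs_demo.c.dataset_id.label("id"), _calibs_demo.c.visit)
_combined = _tags_sel.union(_calibs_sel)   # both branches expose ("id", "visit")
print(_combined)
# ------------------------------------------------------------------------------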
421 def _finish_single_select(
422 self,
423 query: SimpleQuery,
424 table: sqlalchemy.schema.Table,
425 collections: Sequence[CollectionRecord],
426 id: SimpleQuery.Select.Or[Optional[int]],
427 run: SimpleQuery.Select.Or[None],
428 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]],
429 ) -> None:
430 dataset_id_col = table.columns.dataset_id
431 collection_col = table.columns[self._collections.getCollectionForeignKeyName()]
432 # We always constrain (never retrieve) the collection(s) in the
433 # tags/calibs table.
434 if len(collections) == 1:
435 query.where.append(collection_col == collections[0].key)
436 elif len(collections) == 0:
437 # We support the case where there are no collections as a way to
438 # generate a valid SQL query that can't yield results. This should
439 # never get executed, but lots of downstream code will still try
440 # to access the SQLAlchemy objects representing the columns in the
441 # subquery. That's not ideal, but it'd take a lot of refactoring
442 # to fix it (DM-31725).
443 query.where.append(sqlalchemy.sql.literal(False))
444 else:
445 query.where.append(collection_col.in_([collection.key for collection in collections]))
446 # We can always get the dataset_id from the tags/calibs table or
447 # constrain it there. Can't use kwargs for that because we need to
448 # alias it to 'id'.
449 if id is SimpleQuery.Select:
450 query.columns.append(dataset_id_col.label("id"))
451 elif id is not None:  # 451 ↛ 452: condition on line 451 was never true
452 query.where.append(dataset_id_col == id)
453 # It's possible we now have everything we need, from just the
454 # tags/calibs table. The things we might need to get from the static
455 # dataset table are the run key and the ingest date.
456 need_static_table = False
457 static_kwargs: Dict[str, Any] = {}
458 if run is not None:
459 assert run is SimpleQuery.Select, "To constrain the run name, pass a RunRecord as a collection."
460 if len(collections) == 1 and collections[0].type is CollectionType.RUN:
461 # If we are searching exactly one RUN collection, we
462 # know that if we find the dataset in that collection,
463 # then that's the dataset's run; we don't need to
464 # query for it.
465 query.columns.append(sqlalchemy.sql.literal(collections[0].key).label(self._runKeyColumn))
466 else:
467 static_kwargs[self._runKeyColumn] = SimpleQuery.Select
468 need_static_table = True
469 # Ingest date can only come from the static table.
470 if ingestDate is not None:
471 need_static_table = True
472 if ingestDate is SimpleQuery.Select:  # 472 ↛ 475: condition on line 472 was never false
473 static_kwargs["ingest_date"] = SimpleQuery.Select
474 else:
475 assert isinstance(ingestDate, Timespan)
476 # Timespan bounds are astropy Time (usually in TAI) while ingest_date
477 # is a TIMESTAMP, so convert the values to Python datetime for sqlalchemy.
478 if ingestDate.isEmpty():
479 raise RuntimeError("Empty timespan constraint provided for ingest_date.")
480 if ingestDate.begin is not None:
481 begin = ingestDate.begin.utc.datetime # type: ignore
482 query.where.append(self._static.dataset.columns.ingest_date >= begin)
483 if ingestDate.end is not None:
484 end = ingestDate.end.utc.datetime # type: ignore
485 query.where.append(self._static.dataset.columns.ingest_date < end)
486 # If we need the static table, join it in via dataset_id and
487 # dataset_type_id
488 if need_static_table:
489 query.join(
490 self._static.dataset,
491 onclause=(dataset_id_col == self._static.dataset.columns.id),
492 **static_kwargs,
493 )
494 # Also constrain dataset_type_id in static table in case that helps
495 # generate a better plan.
496 # We could also include this in the JOIN ON clause, but my guess is
497 # that that's a good idea IFF it's in the foreign key, and right
498 # now it isn't.
499 query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
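# --- Editor's illustrative sketch (not part of the original module) ----------
# Two tricks used above, shown in isolation: a WHERE clause that is constant
# false (the "no collections" case), and selecting a literal value under a
# column label (the single-RUN-collection shortcut that avoids joining the
# static dataset table).  Toy table and values, not the real schema:

import sqlalchemy as sa

_meta = sa.MetaData()
_tags_demo = sa.Table("tags_demo", _meta, sa.Column("dataset_id", sa.Integer))

never_matches = sa.select(_tags_demo.c.dataset_id).where(sa.sql.literal(False))
with_run_constant = sa.select(
    _tags_demo.c.dataset_id.label("id"),
    sa.sql.literal(42).label("run_id"),   # run key known up front, so no join is needed
)
print(never_matches)
print(with_run_constant)
# ------------------------------------------------------------------------------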
501 def getDataId(self, id: DatasetId) -> DataCoordinate:
502 """Return DataId for a dataset.
504 Parameters
505 ----------
506 id : `DatasetId`
507 Unique dataset identifier.
509 Returns
510 -------
511 dataId : `DataCoordinate`
512 DataId for the dataset.
513 """
514 # This query could return multiple rows (one for each tagged collection
515 # the dataset is in, plus one for its run collection), and we don't
516 # care which of those we get.
517 sql = (
518 self._tags.select()
519 .where(
520 sqlalchemy.sql.and_(
521 self._tags.columns.dataset_id == id,
522 self._tags.columns.dataset_type_id == self._dataset_type_id,
523 )
524 )
525 .limit(1)
526 )
527 row = self._db.query(sql).mappings().fetchone()
528 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
529 return DataCoordinate.standardize(
530 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
531 graph=self.datasetType.dimensions,
532 )
535 class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage):
536 """Implementation of ByDimensionsDatasetRecordStorage which uses integer
537 auto-incremented column for dataset IDs.
538 """
540 def insert(
541 self,
542 run: RunRecord,
543 dataIds: Iterable[DataCoordinate],
544 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
545 ) -> Iterator[DatasetRef]:
546 # Docstring inherited from DatasetRecordStorage.
548 # We only support UNIQUE mode for integer dataset IDs
549 if idMode != DatasetIdGenEnum.UNIQUE:  # 549 ↛ 550: condition on line 549 was never true
550 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
552 # Transform a possibly-single-pass iterable into a list.
553 dataIdList = list(dataIds)
554 yield from self._insert(run, dataIdList)
556 def import_(
557 self,
558 run: RunRecord,
559 datasets: Iterable[DatasetRef],
560 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
561 reuseIds: bool = False,
562 ) -> Iterator[DatasetRef]:
563 # Docstring inherited from DatasetRecordStorage.
565 # We only support UNIQUE mode for integer dataset IDs
566 if idGenerationMode != DatasetIdGenEnum.UNIQUE:  # 566 ↛ 567: condition on line 566 was never true
567 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
569 # Make a list of dataIds and optionally dataset IDs.
570 dataIdList: List[DataCoordinate] = []
571 datasetIdList: List[int] = []
572 for dataset in datasets:
573 dataIdList.append(dataset.dataId)
575 # We only accept integer dataset IDs, but also allow None.
576 datasetId = dataset.id
577 if datasetId is None:  # 577 ↛ 579: condition on line 577 was never true
578 # if reuseIds is set then all IDs must be known
579 if reuseIds:
580 raise TypeError("All dataset IDs must be known if `reuseIds` is set")
581 elif isinstance(datasetId, int):  # 581 ↛ 585: condition on line 581 was never false
582 if reuseIds:
583 datasetIdList.append(datasetId)
584 else:
585 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}")
587 yield from self._insert(run, dataIdList, datasetIdList)
589 def _insert(
590 self, run: RunRecord, dataIdList: List[DataCoordinate], datasetIdList: Optional[List[int]] = None
591 ) -> Iterator[DatasetRef]:
592 """Common part of implementation of `insert` and `import_` methods."""
594 # Remember any governor dimension values we see.
595 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
596 for dataId in dataIdList:
597 governorValues.update_extract(dataId)
599 staticRow = {
600 "dataset_type_id": self._dataset_type_id,
601 self._runKeyColumn: run.key,
602 }
603 with self._db.transaction():
604 # Insert into the static dataset table, generating autoincrement
605 # dataset_id values.
606 if datasetIdList:
607 # reuse existing IDs
608 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList]
609 self._db.insert(self._static.dataset, *rows)
610 else:
611 # use auto-incremented IDs
612 datasetIdList = self._db.insert(
613 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True
614 )
615 assert datasetIdList is not None
616 # Update the summary tables for this collection in case this is the
617 # first time this dataset type or these governor values will be
618 # inserted there.
619 self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues)
620 # Combine the generated dataset_id values and data ID fields to
621 # form rows to be inserted into the tags table.
622 protoTagsRow = {
623 "dataset_type_id": self._dataset_type_id,
624 self._collections.getCollectionForeignKeyName(): run.key,
625 }
626 tagsRows = [
627 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
628 for dataId, dataset_id in zip(dataIdList, datasetIdList)
629 ]
630 # Insert those rows into the tags table. This is where we'll
631 # get any unique constraint violations.
632 self._db.insert(self._tags, *tagsRows)
634 for dataId, datasetId in zip(dataIdList, datasetIdList):
635 yield DatasetRef(
636 datasetType=self.datasetType,
637 dataId=dataId,
638 id=datasetId,
639 run=run.name,
640 )
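# --- Editor's illustrative sketch (not part of the original module) ----------
# The integer-ID path inserts one row per dataset into the static table to
# obtain autoincrement IDs, then reuses those IDs in the tags rows, all within
# one transaction.  A simplified, self-contained version of that flow (SQLite,
# toy schema, one row at a time instead of the bulk `returnIds` path):

import sqlalchemy as sa

_engine = sa.create_engine("sqlite+pysqlite:///:memory:")
_meta = sa.MetaData()
_dataset = sa.Table(
    "dataset", _meta, sa.Column("id", sa.Integer, primary_key=True), sa.Column("run", sa.Text)
)
_tags = sa.Table("tags", _meta, sa.Column("dataset_id", sa.Integer), sa.Column("visit", sa.Integer))
_meta.create_all(_engine)

data_ids = [{"visit": 1}, {"visit": 2}]
with _engine.begin() as conn:                 # one transaction covers both tables
    new_ids = []
    for _ in data_ids:
        result = conn.execute(sa.insert(_dataset).values(run="run_1"))
        new_ids.append(result.inserted_primary_key[0])   # autoincrement ID just assigned
    conn.execute(
        sa.insert(_tags),
        [dict(data_id, dataset_id=ds_id) for data_id, ds_id in zip(data_ids, new_ids)],
    )
# ------------------------------------------------------------------------------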
643 class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
644 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
645 dataset IDs.
646 """
648 idMaker = DatasetIdFactory()
649 """Factory for dataset IDs. In the future this factory may be shared with
650 other classes (e.g. Registry)."""
652 def insert(
653 self,
654 run: RunRecord,
655 dataIds: Iterable[DataCoordinate],
656 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
657 ) -> Iterator[DatasetRef]:
658 # Docstring inherited from DatasetRecordStorage.
660 # Remember any governor dimension values we see.
661 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
663 # Iterate over data IDs, transforming a possibly-single-pass iterable
664 # into a list.
665 dataIdList = []
666 rows = []
667 for dataId in dataIds:
668 dataIdList.append(dataId)
669 rows.append(
670 {
671 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
672 "dataset_type_id": self._dataset_type_id,
673 self._runKeyColumn: run.key,
674 }
675 )
676 governorValues.update_extract(dataId)
678 with self._db.transaction():
679 # Insert into the static dataset table.
680 self._db.insert(self._static.dataset, *rows)
681 # Update the summary tables for this collection in case this is the
682 # first time this dataset type or these governor values will be
683 # inserted there.
684 self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues)
685 # Combine the generated dataset_id values and data ID fields to
686 # form rows to be inserted into the tags table.
687 protoTagsRow = {
688 "dataset_type_id": self._dataset_type_id,
689 self._collections.getCollectionForeignKeyName(): run.key,
690 }
691 tagsRows = [
692 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
693 for dataId, row in zip(dataIdList, rows)
694 ]
695 # Insert those rows into the tags table.
696 self._db.insert(self._tags, *tagsRows)
698 for dataId, row in zip(dataIdList, rows):
699 yield DatasetRef(
700 datasetType=self.datasetType,
701 dataId=dataId,
702 id=row["id"],
703 run=run.name,
704 )
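# --- Editor's illustrative sketch (not part of the original module) ----------
# `DatasetIdFactory.makeDatasetId` is not shown in this file.  The idea it
# supports is the difference between UNIQUE (fresh, random) IDs and
# deterministic IDs that can be reproduced from run + dataset type + data ID.
# A hypothetical standard-library illustration only; the real factory may use
# a different namespace, encoding, and set of modes:

import uuid

_DEMO_NAMESPACE = uuid.UUID("00000000-0000-0000-0000-000000000000")  # made-up namespace

def unique_id():
    return uuid.uuid4()                        # different on every call

def deterministic_id(run, dataset_type_name, data_id):
    key = f"{run}:{dataset_type_name}:{sorted(data_id.items())}"
    return uuid.uuid5(_DEMO_NAMESPACE, key)    # same inputs always give the same UUID

assert deterministic_id("run/1", "flat", {"detector": 42}) == deterministic_id(
    "run/1", "flat", {"detector": 42}
)
assert unique_id() != unique_id()
# ------------------------------------------------------------------------------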
706 def import_(
707 self,
708 run: RunRecord,
709 datasets: Iterable[DatasetRef],
710 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
711 reuseIds: bool = False,
712 ) -> Iterator[DatasetRef]:
713 # Docstring inherited from DatasetRecordStorage.
715 # Remember any governor dimension values we see.
716 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe)
718 # Iterate over data IDs, transforming a possibly-single-pass iterable
719 # into a list.
720 dataIds = {}
721 for dataset in datasets:
722 # Ignore unknown ID types; normally all IDs have the same type, but
723 # this code supports mixed types or missing IDs.
724 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
725 if datasetId is None:
726 datasetId = self.idMaker.makeDatasetId(
727 run.name, self.datasetType, dataset.dataId, idGenerationMode
728 )
729 dataIds[datasetId] = dataset.dataId
730 governorValues.update_extract(dataset.dataId)
732 with self._db.session() as session:
734 # insert all new rows into a temporary table
735 tableSpec = makeTagTableSpec(
736 self.datasetType, type(self._collections), ddl.GUID, constraints=False
737 )
738 tmp_tags = session.makeTemporaryTable(tableSpec)
740 collFkName = self._collections.getCollectionForeignKeyName()
741 protoTagsRow = {
742 "dataset_type_id": self._dataset_type_id,
743 collFkName: run.key,
744 }
745 tmpRows = [
746 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
747 for dataset_id, dataId in dataIds.items()
748 ]
750 with self._db.transaction():
752 # store all incoming data in a temporary table
753 self._db.insert(tmp_tags, *tmpRows)
755 # There are some checks that we want to make for consistency
756 # of the new datasets with existing ones.
757 self._validateImport(tmp_tags, run)
759 # Before we merge temporary table into dataset/tags we need to
760 # drop datasets which are already there (and do not conflict).
761 self._db.deleteWhere(
762 tmp_tags,
763 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
764 )
766 # Copy it into dataset table, need to re-label some columns.
767 self._db.insert(
768 self._static.dataset,
769 select=sqlalchemy.sql.select(
770 tmp_tags.columns.dataset_id.label("id"),
771 tmp_tags.columns.dataset_type_id,
772 tmp_tags.columns[collFkName].label(self._runKeyColumn),
773 ),
774 )
776 # Update the summary tables for this collection in case this
777 # is the first time this dataset type or these governor values
778 # will be inserted there.
779 self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues)
781 # Copy it into tags table.
782 self._db.insert(self._tags, select=tmp_tags.select())
784 # Return refs in the same order as in the input list.
785 for dataset_id, dataId in dataIds.items():
786 yield DatasetRef(
787 datasetType=self.datasetType,
788 id=dataset_id,
789 dataId=dataId,
790 run=run.name,
791 )
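# --- Editor's illustrative sketch (not part of the original module) ----------
# The import path stages all incoming rows in a temporary table, drops the
# ones whose dataset IDs already exist, and copies the remainder into the real
# tables with INSERT ... SELECT.  A compact generic version of that merge
# (SQLite; toy tables standing in for the staging and static tables):

import sqlalchemy as sa

_engine = sa.create_engine("sqlite+pysqlite:///:memory:")
_meta = sa.MetaData()
_dataset = sa.Table("dataset", _meta, sa.Column("id", sa.Integer, primary_key=True))
_staging = sa.Table("staging", _meta, sa.Column("dataset_id", sa.Integer))
_meta.create_all(_engine)

with _engine.begin() as conn:
    conn.execute(sa.insert(_dataset), [{"id": 1}])                      # already present
    conn.execute(sa.insert(_staging), [{"dataset_id": 1}, {"dataset_id": 2}])
    # Drop staged rows whose IDs are already in the destination table.
    conn.execute(
        sa.delete(_staging).where(_staging.c.dataset_id.in_(sa.select(_dataset.c.id)))
    )
    # Copy what is left, relabelling the column to match the destination.
    conn.execute(
        sa.insert(_dataset).from_select(["id"], sa.select(_staging.c.dataset_id.label("id")))
    )
    rows = conn.execute(sa.select(_dataset.c.id).order_by(_dataset.c.id)).fetchall()
    assert [r.id for r in rows] == [1, 2]
# ------------------------------------------------------------------------------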
793 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
794 """Validate imported refs against existing datasets.
796 Parameters
797 ----------
798 tmp_tags : `sqlalchemy.schema.Table`
799 Temporary table with new datasets and the same schema as tags
800 table.
801 run : `RunRecord`
802 The record object describing the `~CollectionType.RUN` collection.
804 Raises
805 ------
806 ConflictingDefinitionError
807 Raised if new datasets conflict with existing ones.
808 """
809 dataset = self._static.dataset
810 tags = self._tags
811 collFkName = self._collections.getCollectionForeignKeyName()
813 # Check that existing datasets have the same dataset type and
814 # run.
815 query = (
816 sqlalchemy.sql.select(
817 dataset.columns.id.label("dataset_id"),
818 dataset.columns.dataset_type_id.label("dataset_type_id"),
819 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
820 dataset.columns[self._runKeyColumn].label("run"),
821 tmp_tags.columns[collFkName].label("new run"),
822 )
823 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
824 .where(
825 sqlalchemy.sql.or_(
826 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
827 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
828 )
829 )
830 )
831 result = self._db.query(query)
832 if (row := result.first()) is not None:
833 # Only include the first one in the exception message
834 raise ConflictingDefinitionError(
835 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
836 )
838 # Check that matching dataset in tags table has the same DataId.
839 query = (
840 sqlalchemy.sql.select(
841 tags.columns.dataset_id,
842 tags.columns.dataset_type_id.label("type_id"),
843 tmp_tags.columns.dataset_type_id.label("new type_id"),
844 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
845 *[
846 tmp_tags.columns[dim].label(f"new {dim}")
847 for dim in self.datasetType.dimensions.required.names
848 ],
849 )
850 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
851 .where(
852 sqlalchemy.sql.or_(
853 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
854 *[
855 tags.columns[dim] != tmp_tags.columns[dim]
856 for dim in self.datasetType.dimensions.required.names
857 ],
858 )
859 )
860 )
861 result = self._db.query(query)
862 if (row := result.first()) is not None:
863 # Only include the first one in the exception message
864 raise ConflictingDefinitionError(
865 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
866 )
868 # Check that matching run+dataId have the same dataset ID.
869 query = (
870 sqlalchemy.sql.select(
871 tags.columns.dataset_type_id.label("dataset_type_id"),
872 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
873 tags.columns.dataset_id,
874 tmp_tags.columns.dataset_id.label("new dataset_id"),
875 tags.columns[collFkName],
876 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
877 )
878 .select_from(
879 tags.join(
880 tmp_tags,
881 sqlalchemy.sql.and_(
882 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
883 tags.columns[collFkName] == tmp_tags.columns[collFkName],
884 *[
885 tags.columns[dim] == tmp_tags.columns[dim]
886 for dim in self.datasetType.dimensions.required.names
887 ],
888 ),
889 )
890 )
891 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
892 )
893 result = self._db.query(query)
894 if (row := result.first()) is not None:
895 # only include the first one in the exception message
896 raise ConflictingDefinitionError(
897 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
898 )
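# --- Editor's illustrative sketch (not part of the original module) ----------
# Each of the three checks in `_validateImport` has the same shape: join the
# staged rows to the existing rows on some key, select any pair whose other
# columns disagree, and treat a surviving row as a conflict.  Stripped down to
# a single column with generic SQLAlchemy (toy tables, not the real schema):

import sqlalchemy as sa

_meta = sa.MetaData()
_existing = sa.Table("existing", _meta, sa.Column("dataset_id", sa.Integer), sa.Column("run", sa.Text))
_staged = sa.Table("staged", _meta, sa.Column("dataset_id", sa.Integer), sa.Column("run", sa.Text))

conflicts = (
    sa.select(
        _existing.c.dataset_id,
        _existing.c.run,
        _staged.c.run.label("new run"),
    )
    .select_from(_existing.join(_staged, _existing.c.dataset_id == _staged.c.dataset_id))
    .where(_existing.c.run != _staged.c.run)
)
# Executing `conflicts` and getting a first() row back would mean the import
# disagrees with what is already stored, mirroring the checks above.
print(conflicts)
# ------------------------------------------------------------------------------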