Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%
232 statements
coverage.py v6.5.0, created at 2023-04-01 02:05 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
23from __future__ import annotations
25__all__ = ("ByDimensionsDatasetRecordStorage",)
27import uuid
28from collections.abc import Iterable, Iterator, Sequence, Set
29from typing import TYPE_CHECKING
31import sqlalchemy
32from lsst.daf.relation import Relation, sql
34from ....core import (
35 DataCoordinate,
36 DatasetColumnTag,
37 DatasetId,
38 DatasetRef,
39 DatasetType,
40 DimensionKeyColumnTag,
41 LogicalColumn,
42 Timespan,
43 ddl,
44)
45from ..._collection_summary import CollectionSummary
46from ..._collectionType import CollectionType
47from ..._exceptions import CollectionTypeError, ConflictingDefinitionError
48from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage
49from ...queries import SqlQueryContext
50from .tables import makeTagTableSpec
52if TYPE_CHECKING:
53 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
54 from .summaries import CollectionSummaryManager
55 from .tables import StaticDatasetTablesTuple
58class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
59 """Dataset record storage implementation paired with
60 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more
61 information.
63 Instances of this class should never be constructed directly; use
64 `DatasetRecordStorageManager.register` instead.
65 """
67 def __init__(
68 self,
69 *,
70 datasetType: DatasetType,
71 db: Database,
72 dataset_type_id: int,
73 collections: CollectionManager,
74 static: StaticDatasetTablesTuple,
75 summaries: CollectionSummaryManager,
76 tags: sqlalchemy.schema.Table,
77 calibs: sqlalchemy.schema.Table | None,
78 ):
79 super().__init__(datasetType=datasetType)
80 self._dataset_type_id = dataset_type_id
81 self._db = db
82 self._collections = collections
83 self._static = static
84 self._summaries = summaries
85 self._tags = tags
86 self._calibs = calibs
87 self._runKeyColumn = collections.getRunForeignKeyName()
89 def delete(self, datasets: Iterable[DatasetRef]) -> None:
90 # Docstring inherited from DatasetRecordStorage.
91 # Only delete from common dataset table; ON DELETE foreign key clauses
92 # will handle the rest.
93 self._db.delete(
94 self._static.dataset,
95 ["id"],
96 *[{"id": dataset.getCheckedId()} for dataset in datasets],
97 )
99 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
100 # Docstring inherited from DatasetRecordStorage.
101 if collection.type is not CollectionType.TAGGED: 101 ↛ 102 (line 101 didn't jump to line 102, because the condition on line 101 was never true)
102 raise TypeError(
103 f"Cannot associate into collection '{collection.name}' "
104 f"of type {collection.type.name}; must be TAGGED."
105 )
106 protoRow = {
107 self._collections.getCollectionForeignKeyName(): collection.key,
108 "dataset_type_id": self._dataset_type_id,
109 }
110 rows = []
111 summary = CollectionSummary()
112 for dataset in summary.add_datasets_generator(datasets):
113 row = dict(protoRow, dataset_id=dataset.getCheckedId())
114 for dimension, value in dataset.dataId.items():
115 row[dimension.name] = value
116 rows.append(row)
117 # Update the summary tables for this collection in case this is the
118 # first time this dataset type or these governor values will be
119 # inserted there.
120 self._summaries.update(collection, [self._dataset_type_id], summary)
121 # Update the tag table itself.
122 self._db.replace(self._tags, *rows)
124 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
125 # Docstring inherited from DatasetRecordStorage.
126 if collection.type is not CollectionType.TAGGED: 126 ↛ 127 (line 126 didn't jump to line 127, because the condition on line 126 was never true)
127 raise TypeError(
128 f"Cannot disassociate from collection '{collection.name}' "
129 f"of type {collection.type.name}; must be TAGGED."
130 )
131 rows = [
132 {
133 "dataset_id": dataset.getCheckedId(),
134 self._collections.getCollectionForeignKeyName(): collection.key,
135 }
136 for dataset in datasets
137 ]
138 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
140 def _buildCalibOverlapQuery(
141 self,
142 collection: CollectionRecord,
143 data_ids: set[DataCoordinate] | None,
144 timespan: Timespan,
145 context: SqlQueryContext,
146 ) -> Relation:
147 relation = self.make_relation(
148 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context
149 ).with_rows_satisfying(
150 context.make_timespan_overlap_predicate(
151 DatasetColumnTag(self.datasetType.name, "timespan"), timespan
152 ),
153 )
154 if data_ids is not None:
155 relation = relation.join(
156 context.make_data_id_relation(
157 data_ids, self.datasetType.dimensions.required.names
158 ).transferred_to(context.sql_engine),
159 )
160 return relation
162 def certify(
163 self,
164 collection: CollectionRecord,
165 datasets: Iterable[DatasetRef],
166 timespan: Timespan,
167 context: SqlQueryContext,
168 ) -> None:
169 # Docstring inherited from DatasetRecordStorage.
170 if self._calibs is None: 170 ↛ 171 (line 170 didn't jump to line 171, because the condition on line 170 was never true)
171 raise CollectionTypeError(
172 f"Cannot certify datasets of type {self.datasetType.name}, for which "
173 "DatasetType.isCalibration() is False."
174 )
175 if collection.type is not CollectionType.CALIBRATION: 175 ↛ 176 (line 175 didn't jump to line 176, because the condition on line 175 was never true)
176 raise CollectionTypeError(
177 f"Cannot certify into collection '{collection.name}' "
178 f"of type {collection.type.name}; must be CALIBRATION."
179 )
180 TimespanReprClass = self._db.getTimespanRepresentation()
181 protoRow = {
182 self._collections.getCollectionForeignKeyName(): collection.key,
183 "dataset_type_id": self._dataset_type_id,
184 }
185 rows = []
186 dataIds: set[DataCoordinate] | None = (
187 set() if not TimespanReprClass.hasExclusionConstraint() else None
188 )
189 summary = CollectionSummary()
190 for dataset in summary.add_datasets_generator(datasets):
191 row = dict(protoRow, dataset_id=dataset.getCheckedId())
192 for dimension, value in dataset.dataId.items():
193 row[dimension.name] = value
194 TimespanReprClass.update(timespan, result=row)
195 rows.append(row)
196 if dataIds is not None: 196 ↛ 190 (line 196 didn't jump to line 190, because the condition on line 196 was never false)
197 dataIds.add(dataset.dataId)
198 # Update the summary tables for this collection in case this is the
199 # first time this dataset type or these governor values will be
200 # inserted there.
201 self._summaries.update(collection, [self._dataset_type_id], summary)
202 # Update the association table itself.
203 if TimespanReprClass.hasExclusionConstraint(): 203 ↛ 206 (line 203 didn't jump to line 206, because the condition on line 203 was never true)
204 # Rely on database constraint to enforce invariants; we just
205 # reraise the exception for consistency across DB engines.
206 try:
207 self._db.insert(self._calibs, *rows)
208 except sqlalchemy.exc.IntegrityError as err:
209 raise ConflictingDefinitionError(
210 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
211 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
212 ) from err
213 else:
214 # Have to implement exclusion constraint ourselves.
215 # Start by building a SELECT query for any rows that would overlap
216 # this one.
217 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context)
218 # Acquire a table lock to ensure there are no concurrent writes
219 # that could invalidate our checking before we finish the inserts. We
220 # use a SAVEPOINT in case there is an outer transaction that a
221 # failure here should not roll back.
222 with self._db.transaction(lock=[self._calibs], savepoint=True):
223 # Enter SqlQueryContext in case we need to use a temporary
224 # table to include the given data IDs in the query. Note that
225 # by doing this inside the transaction, we make sure it doesn't
226 # attempt to close the session when it's done, since it just
227 # sees an already-open session that it knows it shouldn't
228 # manage.
229 with context:
230 # Run the check SELECT query.
231 conflicting = context.count(context.process(relation))
232 if conflicting > 0:
233 raise ConflictingDefinitionError(
234 f"{conflicting} validity range conflicts certifying datasets of type "
235 f"{self.datasetType.name} into {collection.name} for range "
236 f"[{timespan.begin}, {timespan.end})."
237 )
238 # Proceed with the insert.
239 self._db.insert(self._calibs, *rows)
241 def decertify(
242 self,
243 collection: CollectionRecord,
244 timespan: Timespan,
245 *,
246 dataIds: Iterable[DataCoordinate] | None = None,
247 context: SqlQueryContext,
248 ) -> None:
249 # Docstring inherited from DatasetRecordStorage.
250 if self._calibs is None: 250 ↛ 251 (line 250 didn't jump to line 251, because the condition on line 250 was never true)
251 raise CollectionTypeError(
252 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
253 "DatasetType.isCalibration() is False."
254 )
255 if collection.type is not CollectionType.CALIBRATION: 255 ↛ 256 (line 255 didn't jump to line 256, because the condition on line 255 was never true)
256 raise CollectionTypeError(
257 f"Cannot decertify from collection '{collection.name}' "
258 f"of type {collection.type.name}; must be CALIBRATION."
259 )
260 TimespanReprClass = self._db.getTimespanRepresentation()
261 # Construct a SELECT query to find all rows that overlap our inputs.
262 dataIdSet: set[DataCoordinate] | None
263 if dataIds is not None:
264 dataIdSet = set(dataIds)
265 else:
266 dataIdSet = None
267 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context)
268 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
269 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
270 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
271 data_id_tags = [
272 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
273 ]
274 # Set up collections to populate with the rows we'll want to modify.
275 # The insert rows will have the same values for collection and
276 # dataset type.
277 protoInsertRow = {
278 self._collections.getCollectionForeignKeyName(): collection.key,
279 "dataset_type_id": self._dataset_type_id,
280 }
281 rowsToDelete = []
282 rowsToInsert = []
283 # Acquire a table lock to ensure there are no concurrent writes
284 # between the SELECT and the DELETE and INSERT queries based on it.
285 with self._db.transaction(lock=[self._calibs], savepoint=True):
286 # Enter SqlQueryContext in case we need to use a temporary table to
287 # include the given data IDs in the query (see similar block in
288 # certify for details).
289 with context:
290 for row in context.fetch_iterable(relation):
291 rowsToDelete.append({"id": row[calib_pkey_tag]})
292 # Construct the insert row(s) by copying the prototype row,
293 # then adding the dimension column values, then adding
294 # what's left of the timespan from that row after we
295 # subtract the given timespan.
296 newInsertRow = protoInsertRow.copy()
297 newInsertRow["dataset_id"] = row[dataset_id_tag]
298 for name, tag in data_id_tags:
299 newInsertRow[name] = row[tag]
300 rowTimespan = row[timespan_tag]
301 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
302 for diffTimespan in rowTimespan.difference(timespan):
303 rowsToInsert.append(
304 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())
305 )
306 # Run the DELETE and INSERT queries.
307 self._db.delete(self._calibs, ["id"], *rowsToDelete)
308 self._db.insert(self._calibs, *rowsToInsert)
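# Worked example of the timespan arithmetic in decertify above, with
# illustrative values only: if a dataset is certified over [t1, t4) and we
# decertify [t2, t3), the existing calib row is deleted and
# rowTimespan.difference(timespan) yields the remainders [t1, t2) and
# [t3, t4), which are re-inserted with otherwise identical column values.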
310 def make_relation(
311 self,
312 *collections: CollectionRecord,
313 columns: Set[str],
314 context: SqlQueryContext,
315 ) -> Relation:
316 # Docstring inherited from DatasetRecordStorage.
317 collection_types = {collection.type for collection in collections}
318 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
319 TimespanReprClass = self._db.getTimespanRepresentation()
320 #
321 # There are two kinds of table in play here:
322 #
323 # - the static dataset table (with the dataset ID, dataset type ID,
324 # run ID/name, and ingest date);
325 #
326 # - the dynamic tags/calibs table (with the dataset ID, dataset
327 # type ID, collection ID/name, data ID, and possibly validity
328 # range).
329 #
330 # That means that we might want to return a query against either table
331 # or a JOIN of both, depending on which quantities the caller wants.
332 # But the data ID is always included, which means we'll always include
333 # the tags/calibs table and join in the static dataset table only if we
334 # need things from it that we can't get from the tags/calibs table.
335 #
336 # Note that it's important that we include a WHERE constraint on both
337 # tables for any column (e.g. dataset_type_id) that is in both when
338 # it's given explicitly; not doing so can prevent the query planner from
339 # using very important indexes. At present, we don't include those
340 # redundant columns in the JOIN ON expression, however, because the
341 # FOREIGN KEY (and its index) is defined only on dataset_id.
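# As a rough sketch only (schematic table/column names, not the exact SQL
# this relation machinery emits), a request for the "dataset_id" and "run"
# columns against a single TAGGED collection boils down to something like:
#
#   SELECT tags.dataset_id, tags.<data ID columns>, dataset.<run key>
#   FROM <tags table> AS tags
#   JOIN <static dataset table> AS dataset ON dataset.id = tags.dataset_id
#   WHERE tags.dataset_type_id = :dataset_type_id
#     AND tags.<collection FK> = :collection_key
#     AND dataset.dataset_type_id = :dataset_type_id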
342 tag_relation: Relation | None = None
343 calib_relation: Relation | None = None
344 if collection_types != {CollectionType.CALIBRATION}:
345 # We'll need a subquery for the tags table if any of the given
346 # collections are not a CALIBRATION collection. This intentionally
347 # also fires when the list of collections is empty as a way to
348 # create a dummy subquery that we know will fail.
349 # We give the table an alias because it might appear multiple times
350 # in the same query, for different dataset types.
351 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags"))
352 if "timespan" in columns:
353 tags_parts.columns_available[
354 DatasetColumnTag(self.datasetType.name, "timespan")
355 ] = TimespanReprClass.fromLiteral(Timespan(None, None))
356 tag_relation = self._finish_single_relation(
357 tags_parts,
358 columns,
359 [
360 (record, rank)
361 for rank, record in enumerate(collections)
362 if record.type is not CollectionType.CALIBRATION
363 ],
364 context,
365 )
366 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries."
367 if CollectionType.CALIBRATION in collection_types:
368 # If at least one collection is a CALIBRATION collection, we'll
369 # need a subquery for the calibs table, and could include the
370 # timespan as a result or constraint.
371 assert (
372 self._calibs is not None
373 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
374 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs"))
375 if "timespan" in columns:
376 calibs_parts.columns_available[
377 DatasetColumnTag(self.datasetType.name, "timespan")
378 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns)
379 if "calib_pkey" in columns:
380 # This is a private extension not included in the base class
381 # interface, for internal use only in _buildCalibOverlapQuery,
382 # which needs access to the autoincrement primary key for the
383 # calib association table.
384 calibs_parts.columns_available[
385 DatasetColumnTag(self.datasetType.name, "calib_pkey")
386 ] = calibs_parts.from_clause.columns.id
387 calib_relation = self._finish_single_relation(
388 calibs_parts,
389 columns,
390 [
391 (record, rank)
392 for rank, record in enumerate(collections)
393 if record.type is CollectionType.CALIBRATION
394 ],
395 context,
396 )
397 if tag_relation is not None:
398 if calib_relation is not None:
399 # daf_relation's chain operation does not automatically
400 # deduplicate; it's more like SQL's UNION ALL. To get UNION
401 # in SQL here, we add an explicit deduplication.
402 return tag_relation.chain(calib_relation).without_duplicates()
403 else:
404 return tag_relation
405 elif calib_relation is not None:
406 return calib_relation
407 else:
408 raise AssertionError("Branch should be unreachable.")
410 def _finish_single_relation(
411 self,
412 payload: sql.Payload[LogicalColumn],
413 requested_columns: Set[str],
414 collections: Sequence[tuple[CollectionRecord, int]],
415 context: SqlQueryContext,
416 ) -> Relation:
417 """Helper method for `make_relation`.
419 This handles adding columns and WHERE terms that are not specific to
420 either the tags or calibs tables.
422 Parameters
423 ----------
424 payload : `lsst.daf.relation.sql.Payload`
425 SQL query parts under construction, to be modified in-place and
426 used to construct the new relation.
427 requested_columns : `~collections.abc.Set` [ `str` ]
428 Columns the relation should include.
429 collections : `Sequence` [ `tuple` [ `CollectionRecord`, `int` ] ]
430 Collections to search for the dataset and their ranks.
431 context : `SqlQueryContext`
432 Context that manages engines and state for the query.
434 Returns
435 -------
436 relation : `lsst.daf.relation.Relation`
437 New dataset query relation.
438 """
439 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id)
440 dataset_id_col = payload.from_clause.columns.dataset_id
441 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()]
442 # We always constrain and optionally retrieve the collection(s) via the
443 # tags/calibs table.
444 if len(collections) == 1:
445 payload.where.append(collection_col == collections[0][0].key)
446 if "collection" in requested_columns:
447 payload.columns_available[
448 DatasetColumnTag(self.datasetType.name, "collection")
449 ] = sqlalchemy.sql.literal(collections[0][0].key)
450 else:
451 assert collections, "The no-collections case should be handled in calling code for better diagnostics."
452 payload.where.append(collection_col.in_([collection.key for collection, _ in collections]))
453 if "collection" in requested_columns:
454 payload.columns_available[
455 DatasetColumnTag(self.datasetType.name, "collection")
456 ] = collection_col
457 # Add rank, if requested, as a CASE-based calculation on the
458 # collection column.
459 if "rank" in requested_columns:
460 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case(
461 {record.key: rank for record, rank in collections},
462 value=collection_col,
463 )
464 # Add more column definitions, starting with the data ID.
465 for dimension_name in self.datasetType.dimensions.required.names:
466 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
467 dimension_name
468 ]
469 # We can always get the dataset_id from the tags/calibs table.
470 if "dataset_id" in requested_columns:
471 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col
472 # It's possible we now have everything we need, from just the
473 # tags/calibs table. The things we might need to get from the static
474 # dataset table are the run key and the ingest date.
475 need_static_table = False
476 if "run" in requested_columns:
477 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN:
478 # If we are searching exactly one RUN collection, we
479 # know that if we find the dataset in that collection,
480 # then that's the dataset's run; we don't need to
481 # query for it.
482 payload.columns_available[
483 DatasetColumnTag(self.datasetType.name, "run")
484 ] = sqlalchemy.sql.literal(collections[0][0].key)
485 else:
486 payload.columns_available[
487 DatasetColumnTag(self.datasetType.name, "run")
488 ] = self._static.dataset.columns[self._runKeyColumn]
489 need_static_table = True
490 # Ingest date can only come from the static table.
491 if "ingest_date" in requested_columns:
492 need_static_table = True
493 payload.columns_available[
494 DatasetColumnTag(self.datasetType.name, "ingest_date")
495 ] = self._static.dataset.columns.ingest_date
496 # If we need the static table, join it in via dataset_id and
497 # dataset_type_id.
498 if need_static_table:
499 payload.from_clause = payload.from_clause.join(
500 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id)
501 )
502 # Also constrain dataset_type_id in static table in case that helps
503 # generate a better plan.
504 # We could also include this in the JOIN ON clause, but my guess is
505 # that that's a good idea IFF it's in the foreign key, and right
506 # now it isn't.
507 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
508 leaf = context.sql_engine.make_leaf(
509 payload.columns_available.keys(),
510 payload=payload,
511 name=self.datasetType.name,
512 parameters={record.name: rank for record, rank in collections},
513 )
514 return leaf
516 def getDataId(self, id: DatasetId) -> DataCoordinate:
517 """Return DataId for a dataset.
519 Parameters
520 ----------
521 id : `DatasetId`
522 Unique dataset identifier.
524 Returns
525 -------
526 dataId : `DataCoordinate`
527 DataId for the dataset.
528 """
529 # This query could return multiple rows (one for each tagged collection
530 # the dataset is in, plus one for its run collection), and we don't
531 # care which of those we get.
532 sql = (
533 self._tags.select()
534 .where(
535 sqlalchemy.sql.and_(
536 self._tags.columns.dataset_id == id,
537 self._tags.columns.dataset_type_id == self._dataset_type_id,
538 )
539 )
540 .limit(1)
541 )
542 with self._db.query(sql) as sql_result:
543 row = sql_result.mappings().fetchone()
544 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
545 return DataCoordinate.standardize(
546 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
547 graph=self.datasetType.dimensions,
548 )
551class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
552 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
553 dataset IDs.
554 """
556 idMaker = DatasetIdFactory()
557 """Factory for dataset IDs. In the future this factory may be shared with
558 other classes (e.g. Registry)."""
560 def insert(
561 self,
562 run: RunRecord,
563 dataIds: Iterable[DataCoordinate],
564 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
565 ) -> Iterator[DatasetRef]:
566 # Docstring inherited from DatasetRecordStorage.
568 # Iterate over data IDs, transforming a possibly-single-pass iterable
569 # into a list.
570 dataIdList = []
571 rows = []
572 summary = CollectionSummary()
573 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
574 dataIdList.append(dataId)
575 rows.append(
576 {
577 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
578 "dataset_type_id": self._dataset_type_id,
579 self._runKeyColumn: run.key,
580 }
581 )
583 with self._db.transaction():
584 # Insert into the static dataset table.
585 self._db.insert(self._static.dataset, *rows)
586 # Update the summary tables for this collection in case this is the
587 # first time this dataset type or these governor values will be
588 # inserted there.
589 self._summaries.update(run, [self._dataset_type_id], summary)
590 # Combine the generated dataset_id values and data ID fields to
591 # form rows to be inserted into the tags table.
592 protoTagsRow = {
593 "dataset_type_id": self._dataset_type_id,
594 self._collections.getCollectionForeignKeyName(): run.key,
595 }
596 tagsRows = [
597 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
598 for dataId, row in zip(dataIdList, rows)
599 ]
600 # Insert those rows into the tags table.
601 self._db.insert(self._tags, *tagsRows)
603 for dataId, row in zip(dataIdList, rows):
604 yield DatasetRef(
605 datasetType=self.datasetType,
606 dataId=dataId,
607 id=row["id"],
608 run=run.name,
609 )
611 def import_(
612 self,
613 run: RunRecord,
614 datasets: Iterable[DatasetRef],
615 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
616 reuseIds: bool = False,
617 ) -> Iterator[DatasetRef]:
618 # Docstring inherited from DatasetRecordStorage.
620 # Iterate over data IDs, transforming a possibly-single-pass iterable
621 # into a list.
622 dataIds = {}
623 summary = CollectionSummary()
624 for dataset in summary.add_datasets_generator(datasets):
625 # Ignore unknown ID types; normally all IDs have the same type, but
626 # this code supports mixed types or missing IDs.
627 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
628 if datasetId is None:
629 datasetId = self.idMaker.makeDatasetId(
630 run.name, self.datasetType, dataset.dataId, idGenerationMode
631 )
632 dataIds[datasetId] = dataset.dataId
634 # We'll insert all new rows into a temporary table
635 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False)
636 collFkName = self._collections.getCollectionForeignKeyName()
637 protoTagsRow = {
638 "dataset_type_id": self._dataset_type_id,
639 collFkName: run.key,
640 }
641 tmpRows = [
642 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
643 for dataset_id, dataId in dataIds.items()
644 ]
645 with self._db.transaction(for_temp_tables=True):
646 with self._db.temporary_table(tableSpec) as tmp_tags:
647 # store all incoming data in a temporary table
648 self._db.insert(tmp_tags, *tmpRows)
650 # There are some checks that we want to make for consistency
651 # of the new datasets with existing ones.
652 self._validateImport(tmp_tags, run)
654 # Before we merge the temporary table into dataset/tags we need to
655 # drop datasets that are already there (and do not conflict).
656 self._db.deleteWhere(
657 tmp_tags,
658 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
659 )
661 # Copy it into the dataset table; we need to re-label some columns.
662 self._db.insert(
663 self._static.dataset,
664 select=sqlalchemy.sql.select(
665 tmp_tags.columns.dataset_id.label("id"),
666 tmp_tags.columns.dataset_type_id,
667 tmp_tags.columns[collFkName].label(self._runKeyColumn),
668 ),
669 )
671 # Update the summary tables for this collection in case this
672 # is the first time this dataset type or these governor values
673 # will be inserted there.
674 self._summaries.update(run, [self._dataset_type_id], summary)
676 # Copy it into the tags table.
677 self._db.insert(self._tags, select=tmp_tags.select())
679 # Return refs in the same order as in the input list.
680 for dataset_id, dataId in dataIds.items():
681 yield DatasetRef(
682 datasetType=self.datasetType,
683 id=dataset_id,
684 dataId=dataId,
685 run=run.name,
686 )
688 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
689 """Validate imported refs against existing datasets.
691 Parameters
692 ----------
693 tmp_tags : `sqlalchemy.schema.Table`
694 Temporary table with new datasets and the same schema as tags
695 table.
696 run : `RunRecord`
697 The record object describing the `~CollectionType.RUN` collection.
699 Raises
700 ------
701 ConflictingDefinitionError
702 Raised if new datasets conflict with existing ones.
703 """
704 dataset = self._static.dataset
705 tags = self._tags
706 collFkName = self._collections.getCollectionForeignKeyName()
708 # Check that existing datasets have the same dataset type and
709 # run.
710 query = (
711 sqlalchemy.sql.select(
712 dataset.columns.id.label("dataset_id"),
713 dataset.columns.dataset_type_id.label("dataset_type_id"),
714 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
715 dataset.columns[self._runKeyColumn].label("run"),
716 tmp_tags.columns[collFkName].label("new run"),
717 )
718 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
719 .where(
720 sqlalchemy.sql.or_(
721 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
722 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
723 )
724 )
725 .limit(1)
726 )
727 with self._db.query(query) as result:
728 if (row := result.first()) is not None:
729 # Only include the first one in the exception message
730 raise ConflictingDefinitionError(
731 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
732 )
734 # Check that matching dataset in tags table has the same DataId.
735 query = (
736 sqlalchemy.sql.select(
737 tags.columns.dataset_id,
738 tags.columns.dataset_type_id.label("type_id"),
739 tmp_tags.columns.dataset_type_id.label("new type_id"),
740 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
741 *[
742 tmp_tags.columns[dim].label(f"new {dim}")
743 for dim in self.datasetType.dimensions.required.names
744 ],
745 )
746 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
747 .where(
748 sqlalchemy.sql.or_(
749 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
750 *[
751 tags.columns[dim] != tmp_tags.columns[dim]
752 for dim in self.datasetType.dimensions.required.names
753 ],
754 )
755 )
756 .limit(1)
757 )
759 with self._db.query(query) as result:
760 if (row := result.first()) is not None:
761 # Only include the first one in the exception message
762 raise ConflictingDefinitionError(
763 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
764 )
766 # Check that matching run+dataId have the same dataset ID.
767 query = (
768 sqlalchemy.sql.select(
769 tags.columns.dataset_type_id.label("dataset_type_id"),
770 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
771 tags.columns.dataset_id,
772 tmp_tags.columns.dataset_id.label("new dataset_id"),
773 tags.columns[collFkName],
774 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
775 )
776 .select_from(
777 tags.join(
778 tmp_tags,
779 sqlalchemy.sql.and_(
780 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
781 tags.columns[collFkName] == tmp_tags.columns[collFkName],
782 *[
783 tags.columns[dim] == tmp_tags.columns[dim]
784 for dim in self.datasetType.dimensions.required.names
785 ],
786 ),
787 )
788 )
789 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
790 .limit(1)
791 )
792 with self._db.query(query) as result:
793 if (row := result.first()) is not None:
794 # Only include the first one in the exception message
795 raise ConflictingDefinitionError(
796 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
797 )