Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 94%
237 statements
coverage.py v6.5.0, created at 2023-03-28 04:40 -0700
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
23from __future__ import annotations
25__all__ = ("ByDimensionsDatasetRecordStorage",)
27import uuid
28from collections.abc import Iterable, Iterator, Sequence, Set
29from typing import TYPE_CHECKING
31import sqlalchemy
32from lsst.daf.relation import Relation, sql
34from ....core import (
35 DataCoordinate,
36 DatasetColumnTag,
37 DatasetId,
38 DatasetRef,
39 DatasetType,
40 DimensionKeyColumnTag,
41 LogicalColumn,
42 Timespan,
43 ddl,
44)
45from ..._collection_summary import CollectionSummary
46from ..._collectionType import CollectionType
47from ..._exceptions import CollectionTypeError, ConflictingDefinitionError
48from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage
49from ...queries import SqlQueryContext
50from .tables import makeTagTableSpec
52if TYPE_CHECKING:  [branch 52 ↛ 53: condition never true]
53 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
54 from .summaries import CollectionSummaryManager
55 from .tables import StaticDatasetTablesTuple
58class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
59 """Dataset record storage implementation paired with
60 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more
61 information.
63 Instances of this class should never be constructed directly; use
64 `DatasetRecordStorageManager.register` instead.
65 """
67 def __init__(
68 self,
69 *,
70 datasetType: DatasetType,
71 db: Database,
72 dataset_type_id: int,
73 collections: CollectionManager,
74 static: StaticDatasetTablesTuple,
75 summaries: CollectionSummaryManager,
76 tags: sqlalchemy.schema.Table,
77 calibs: sqlalchemy.schema.Table | None,
78 ):
79 super().__init__(datasetType=datasetType)
80 self._dataset_type_id = dataset_type_id
81 self._db = db
82 self._collections = collections
83 self._static = static
84 self._summaries = summaries
85 self._tags = tags
86 self._calibs = calibs
87 self._runKeyColumn = collections.getRunForeignKeyName()
89 def delete(self, datasets: Iterable[DatasetRef]) -> None:
90 # Docstring inherited from DatasetRecordStorage.
91 # Only delete from common dataset table; ON DELETE foreign key clauses
92 # will handle the rest.
93 self._db.delete(
94 self._static.dataset,
95 ["id"],
96 *[{"id": dataset.getCheckedId()} for dataset in datasets],
97 )
99 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
100 # Docstring inherited from DatasetRecordStorage.
101 if collection.type is not CollectionType.TAGGED:  [branch 101 ↛ 102: condition never true]
102 raise TypeError(
103 f"Cannot associate into collection '{collection.name}' "
104 f"of type {collection.type.name}; must be TAGGED."
105 )
106 protoRow = {
107 self._collections.getCollectionForeignKeyName(): collection.key,
108 "dataset_type_id": self._dataset_type_id,
109 }
110 rows = []
111 summary = CollectionSummary()
112 for dataset in summary.add_datasets_generator(datasets):
113 row = dict(protoRow, dataset_id=dataset.getCheckedId())
114 for dimension, value in dataset.dataId.items():
115 row[dimension.name] = value
116 rows.append(row)
117 # Update the summary tables for this collection in case this is the
118 # first time this dataset type or these governor values will be
119 # inserted there.
120 self._summaries.update(collection, [self._dataset_type_id], summary)
121 # Update the tag table itself.
122 self._db.replace(self._tags, *rows)
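# --- Illustrative sketch (not part of this module) ---
# Each tag row above is the per-collection "prototype" row plus the dataset
# ID and the dataset's data-ID values.  A plain-dict picture of that row
# assembly, with hypothetical keys and values:

proto_row = {"collection_id": 42, "dataset_type_id": 7}
data_id = {"instrument": "HSC", "detector": 10}
row = dict(proto_row, dataset_id="00000000-0000-0000-0000-000000000000", **data_id)
assert row["collection_id"] == 42 and row["detector"] == 10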
124 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
125 # Docstring inherited from DatasetRecordStorage.
126 if collection.type is not CollectionType.TAGGED:  [branch 126 ↛ 127: condition never true]
127 raise TypeError(
128 f"Cannot disassociate from collection '{collection.name}' "
129 f"of type {collection.type.name}; must be TAGGED."
130 )
131 rows = [
132 {
133 "dataset_id": dataset.getCheckedId(),
134 self._collections.getCollectionForeignKeyName(): collection.key,
135 }
136 for dataset in datasets
137 ]
138 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
140 def _buildCalibOverlapQuery(
141 self,
142 collection: CollectionRecord,
143 data_ids: set[DataCoordinate] | None,
144 timespan: Timespan,
145 context: SqlQueryContext,
146 ) -> Relation:
147 relation = self.make_relation(
148 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context
149 ).with_rows_satisfying(
150 context.make_timespan_overlap_predicate(
151 DatasetColumnTag(self.datasetType.name, "timespan"), timespan
152 ),
153 )
154 if data_ids is not None:
155 relation = relation.join(
156 context.make_data_id_relation(
157 data_ids, self.datasetType.dimensions.required.names
158 ).transferred_to(context.sql_engine),
159 )
160 return relation
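# --- Illustrative sketch (not part of this module) ---
# The overlap predicate built above is, conceptually, the standard test for
# half-open intervals [begin, end): two ranges overlap iff each one starts
# before the other ends.  A minimal pure-Python analogue using integer
# endpoints instead of the real `Timespan` class:

def _overlaps(a_begin: int, a_end: int, b_begin: int, b_end: int) -> bool:
    """Return True if [a_begin, a_end) and [b_begin, b_end) intersect."""
    return a_begin < b_end and b_begin < a_end

assert _overlaps(0, 10, 5, 15)       # partial overlap
assert not _overlaps(0, 10, 10, 20)  # half-open ranges that only touch do not overlap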
162 def certify(
163 self,
164 collection: CollectionRecord,
165 datasets: Iterable[DatasetRef],
166 timespan: Timespan,
167 context: SqlQueryContext,
168 ) -> None:
169 # Docstring inherited from DatasetRecordStorage.
170 if self._calibs is None:  [branch 170 ↛ 171: condition never true]
171 raise CollectionTypeError(
172 f"Cannot certify datasets of type {self.datasetType.name}, for which "
173 "DatasetType.isCalibration() is False."
174 )
175 if collection.type is not CollectionType.CALIBRATION:  [branch 175 ↛ 176: condition never true]
176 raise CollectionTypeError(
177 f"Cannot certify into collection '{collection.name}' "
178 f"of type {collection.type.name}; must be CALIBRATION."
179 )
180 TimespanReprClass = self._db.getTimespanRepresentation()
181 protoRow = {
182 self._collections.getCollectionForeignKeyName(): collection.key,
183 "dataset_type_id": self._dataset_type_id,
184 }
185 rows = []
186 dataIds: set[DataCoordinate] | None = (
187 set() if not TimespanReprClass.hasExclusionConstraint() else None
188 )
189 summary = CollectionSummary()
190 for dataset in summary.add_datasets_generator(datasets):
191 row = dict(protoRow, dataset_id=dataset.getCheckedId())
192 for dimension, value in dataset.dataId.items():
193 row[dimension.name] = value
194 TimespanReprClass.update(timespan, result=row)
195 rows.append(row)
196 if dataIds is not None:  [branch 196 ↛ 190: condition never false]
197 dataIds.add(dataset.dataId)
198 # Update the summary tables for this collection in case this is the
199 # first time this dataset type or these governor values will be
200 # inserted there.
201 self._summaries.update(collection, [self._dataset_type_id], summary)
202 # Update the association table itself.
203 if TimespanReprClass.hasExclusionConstraint():  [branch 203 ↛ 206: condition never true]
204 # Rely on database constraint to enforce invariants; we just
205 # reraise the exception for consistency across DB engines.
206 try:
207 self._db.insert(self._calibs, *rows)
208 except sqlalchemy.exc.IntegrityError as err:
209 raise ConflictingDefinitionError(
210 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
211 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
212 ) from err
213 else:
214 # Have to implement exclusion constraint ourselves.
215 # Start by building a SELECT query for any rows that would overlap
216 # this one.
217 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context)
218 # Acquire a table lock to ensure there are no concurrent writes that
219 # could invalidate our checking before we finish the inserts. We
220 # use a SAVEPOINT in case there is an outer transaction that a
221 # failure here should not roll back.
222 with self._db.transaction(lock=[self._calibs], savepoint=True):
223 # Enter SqlQueryContext in case we need to use a temporary
224 # table to include the given data IDs in the query. Note that
225 # by doing this inside the transaction, we make sure it doesn't
226 # attempt to close the session when it's done, since it just
227 # sees an already-open session that it knows it shouldn't
228 # manage.
229 with context:
230 # Run the check SELECT query.
231 conflicting = context.count(context.process(relation))
232 if conflicting > 0:
233 raise ConflictingDefinitionError(
234 f"{conflicting} validity range conflicts certifying datasets of type "
235 f"{self.datasetType.name} into {collection.name} for range "
236 f"[{timespan.begin}, {timespan.end})."
237 )
238 # Proceed with the insert.
239 self._db.insert(self._calibs, *rows)
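# --- Illustrative sketch (not part of this module) ---
# When the database cannot enforce an exclusion constraint itself, `certify`
# emulates one: lock the table, count any existing rows whose validity range
# overlaps the new one, and only insert when that count is zero.  A
# simplified in-memory analogue of that check-then-insert step, with
# hypothetical names and integer endpoints instead of `Timespan`:

calib_rows: list[dict] = []  # stands in for the calibs association table

def certify_range(dataset_id: int, begin: int, end: int) -> None:
    conflicting = sum(
        1
        for row in calib_rows
        if row["dataset_id"] == dataset_id and row["begin"] < end and begin < row["end"]
    )
    if conflicting:
        raise RuntimeError(f"{conflicting} validity range conflicts for [{begin}, {end}).")
    calib_rows.append({"dataset_id": dataset_id, "begin": begin, "end": end})

certify_range(1, 0, 10)
certify_range(1, 10, 20)    # fine: half-open ranges that only touch do not conflict
# certify_range(1, 5, 15)   # would raise: overlaps [0, 10)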
241 def decertify(
242 self,
243 collection: CollectionRecord,
244 timespan: Timespan,
245 *,
246 dataIds: Iterable[DataCoordinate] | None = None,
247 context: SqlQueryContext,
248 ) -> None:
249 # Docstring inherited from DatasetRecordStorage.
250 if self._calibs is None:  [branch 250 ↛ 251: condition never true]
251 raise CollectionTypeError(
252 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
253 "DatasetType.isCalibration() is False."
254 )
255 if collection.type is not CollectionType.CALIBRATION:  [branch 255 ↛ 256: condition never true]
256 raise CollectionTypeError(
257 f"Cannot decertify from collection '{collection.name}' "
258 f"of type {collection.type.name}; must be CALIBRATION."
259 )
260 TimespanReprClass = self._db.getTimespanRepresentation()
261 # Construct a SELECT query to find all rows that overlap our inputs.
262 dataIdSet: set[DataCoordinate] | None
263 if dataIds is not None:
264 dataIdSet = set(dataIds)
265 else:
266 dataIdSet = None
267 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context)
268 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
269 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
270 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
271 data_id_tags = [
272 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
273 ]
274 # Set up collections to populate with the rows we'll want to modify.
275 # The insert rows will have the same values for collection and
276 # dataset type.
277 protoInsertRow = {
278 self._collections.getCollectionForeignKeyName(): collection.key,
279 "dataset_type_id": self._dataset_type_id,
280 }
281 rowsToDelete = []
282 rowsToInsert = []
283 # Acquire a table lock to ensure there are no concurrent writes
284 # between the SELECT and the DELETE and INSERT queries based on it.
285 with self._db.transaction(lock=[self._calibs], savepoint=True):
286 # Enter SqlQueryContext in case we need to use a temporary table to
287 # include the given data IDs in the query (see similar block in
288 # certify for details).
289 with context:
290 for row in context.fetch_iterable(relation):
291 rowsToDelete.append({"id": row[calib_pkey_tag]})
292 # Construct the insert row(s) by copying the prototype row,
293 # then adding the dimension column values, then adding
294 # what's left of the timespan from that row after we
295 # subtract the given timespan.
296 newInsertRow = protoInsertRow.copy()
297 newInsertRow["dataset_id"] = row[dataset_id_tag]
298 for name, tag in data_id_tags:
299 newInsertRow[name] = row[tag]
300 rowTimespan = row[timespan_tag]
301 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
302 for diffTimespan in rowTimespan.difference(timespan):
303 rowsToInsert.append(
304 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())
305 )
306 # Run the DELETE and INSERT queries.
307 self._db.delete(self._calibs, ["id"], *rowsToDelete)
308 self._db.insert(self._calibs, *rowsToInsert)
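# --- Illustrative sketch (not part of this module) ---
# Decertification replaces each overlapping row with whatever is left of its
# validity range after subtracting the decertified span, so one existing row
# can become zero, one, or two replacement rows.  A pure-Python analogue of
# that difference for half-open integer intervals (the real code uses
# `Timespan.difference`):

def interval_difference(begin: int, end: int, cut_begin: int, cut_end: int) -> list[tuple[int, int]]:
    """Return the parts of [begin, end) not covered by [cut_begin, cut_end)."""
    pieces = []
    if begin < cut_begin:
        pieces.append((begin, min(end, cut_begin)))
    if cut_end < end:
        pieces.append((max(begin, cut_end), end))
    return pieces

assert interval_difference(0, 10, 3, 7) == [(0, 3), (7, 10)]   # split into two rows
assert interval_difference(0, 10, 0, 10) == []                 # row fully decertified
assert interval_difference(0, 10, 5, 20) == [(0, 5)]           # row trimmed on one side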
310 def make_relation(
311 self,
312 *collections: CollectionRecord,
313 columns: Set[str],
314 context: SqlQueryContext,
315 ) -> Relation:
316 # Docstring inherited from DatasetRecordStorage.
317 collection_types = {collection.type for collection in collections}
318 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
319 TimespanReprClass = self._db.getTimespanRepresentation()
320 #
321 # There are two kinds of table in play here:
322 #
323 # - the static dataset table (with the dataset ID, dataset type ID,
324 # run ID/name, and ingest date);
325 #
326 # - the dynamic tags/calibs table (with the dataset ID, dataset
327 # type ID, collection ID/name, data ID, and possibly validity
328 # range).
329 #
330 # That means that we might want to return a query against either table
331 # or a JOIN of both, depending on which quantities the caller wants.
332 # But the data ID is always included, which means we'll always include
333 # the tags/calibs table and join in the static dataset table only if we
334 # need things from it that we can't get from the tags/calibs table.
335 #
336 # Note that it's important that we include a WHERE constraint on both
337 # tables for any column (e.g. dataset_type_id) that is in both when
338 # it's given explicitly; not doing so can prevent the query planner from
339 # using very important indexes. At present, we don't include those
340 # redundant columns in the JOIN ON expression, however, because the
341 # FOREIGN KEY (and its index) are defined only on dataset_id.
342 tag_relation: Relation | None = None
343 calib_relation: Relation | None = None
344 if collection_types != {CollectionType.CALIBRATION}:
345 # We'll need a subquery for the tags table if any of the given
346 # collections are not a CALIBRATION collection. This intentionally
347 # also fires when the list of collections is empty as a way to
348 # create a dummy subquery that we know will fail.
349 # We give the table an alias because it might appear multiple times
350 # in the same query, for different dataset types.
351 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags"))
352 if "timespan" in columns:
353 tags_parts.columns_available[
354 DatasetColumnTag(self.datasetType.name, "timespan")
355 ] = TimespanReprClass.fromLiteral(Timespan(None, None))
356 tag_relation = self._finish_single_relation(
357 tags_parts,
358 columns,
359 [
360 (record, rank)
361 for rank, record in enumerate(collections)
362 if record.type is not CollectionType.CALIBRATION
363 ],
364 context,
365 )
366 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries."
367 if CollectionType.CALIBRATION in collection_types:
368 # If at least one collection is a CALIBRATION collection, we'll
369 # need a subquery for the calibs table, and could include the
370 # timespan as a result or constraint.
371 assert (
372 self._calibs is not None
373 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
374 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs"))
375 if "timespan" in columns:
376 calibs_parts.columns_available[
377 DatasetColumnTag(self.datasetType.name, "timespan")
378 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns)
379 if "calib_pkey" in columns:
380 # This is a private extension not included in the base class
381 # interface, for internal use only in _buildCalibOverlapQuery,
382 # which needs access to the autoincrement primary key for the
383 # calib association table.
384 calibs_parts.columns_available[
385 DatasetColumnTag(self.datasetType.name, "calib_pkey")
386 ] = calibs_parts.from_clause.columns.id
387 calib_relation = self._finish_single_relation(
388 calibs_parts,
389 columns,
390 [
391 (record, rank)
392 for rank, record in enumerate(collections)
393 if record.type is CollectionType.CALIBRATION
394 ],
395 context,
396 )
397 if tag_relation is not None:
398 if calib_relation is not None:
399 # daf_relation's chain operation does not automatically
400 # deduplicate; it's more like SQL's UNION ALL. To get UNION
401 # in SQL here, we add an explicit deduplication.
402 return tag_relation.chain(calib_relation).without_duplicates()
403 else:
404 return tag_relation
405 elif calib_relation is not None:  [branch 405 ↛ 408: condition never false]
406 return calib_relation
407 else:
408 raise AssertionError("Branch should be unreachable.")
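# --- Illustrative sketch (not part of this module) ---
# The tag and calib branches above are combined with `chain`, which behaves
# like SQL UNION ALL; the explicit `without_duplicates()` is what turns it
# into a true UNION.  A plain-Python picture of that distinction:

tag_rows = [("bias", 1), ("flat", 2)]
calib_rows = [("flat", 2), ("dark", 3)]

union_all = tag_rows + calib_rows       # duplicates kept, like chain()
union = list(dict.fromkeys(union_all))  # order-preserving deduplication

assert union_all.count(("flat", 2)) == 2
assert union == [("bias", 1), ("flat", 2), ("dark", 3)]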
410 def _finish_single_relation(
411 self,
412 payload: sql.Payload[LogicalColumn],
413 requested_columns: Set[str],
414 collections: Sequence[tuple[CollectionRecord, int]],
415 context: SqlQueryContext,
416 ) -> Relation:
417 """Helper method for `make_relation`.
419 This handles adding columns and WHERE terms that are not specific to
420 either the tags or calibs tables.
422 Parameters
423 ----------
424 payload : `lsst.daf.relation.sql.Payload`
425 SQL query parts under construction, to be modified in-place and
426 used to construct the new relation.
427 requested_columns : `~collections.abc.Set` [ `str` ]
428 Columns the relation should include.
429 collections : `Sequence` [ `tuple` [ `CollectionRecord`, `int` ] ]
430 Collections to search for the dataset and their ranks.
431 context : `SqlQueryContext`
432 Context that manages engines and state for the query.
434 Returns
435 -------
436 relation : `lsst.daf.relation.Relation`
437 New dataset query relation.
438 """
439 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id)
440 dataset_id_col = payload.from_clause.columns.dataset_id
441 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()]
442 # We always constrain and optionally retrieve the collection(s) via the
443 # tags/calibs table.
444 if len(collections) == 1:
445 payload.where.append(collection_col == collections[0][0].key)
446 if "collection" in requested_columns:
447 payload.columns_available[
448 DatasetColumnTag(self.datasetType.name, "collection")
449 ] = sqlalchemy.sql.literal(collections[0][0].key)
450 else:
451 assert collections, "The no-collections case should be handled by calling code for better diagnostics."
452 payload.where.append(collection_col.in_([collection.key for collection, _ in collections]))
453 if "collection" in requested_columns:
454 payload.columns_available[
455 DatasetColumnTag(self.datasetType.name, "collection")
456 ] = collection_col
457 # Add rank if requested, as a CASE-based calculation on the collection
458 # column.
459 if "rank" in requested_columns:
460 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case(
461 {record.key: rank for record, rank in collections},
462 value=collection_col,
463 )
464 # Add more column definitions, starting with the data ID.
465 for dimension_name in self.datasetType.dimensions.required.names:
466 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
467 dimension_name
468 ]
469 # We can always get the dataset_id from the tags/calibs table.
470 if "dataset_id" in requested_columns:
471 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col
472 # It's possible we now have everything we need, from just the
473 # tags/calibs table. The things we might need to get from the static
474 # dataset table are the run key and the ingest date.
475 need_static_table = False
476 if "run" in requested_columns:
477 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN:
478 # If we are searching exactly one RUN collection, we
479 # know that if we find the dataset in that collection,
480 # then that's the dataset's run; we don't need to
481 # query for it.
482 payload.columns_available[
483 DatasetColumnTag(self.datasetType.name, "run")
484 ] = sqlalchemy.sql.literal(collections[0][0].key)
485 else:
486 payload.columns_available[
487 DatasetColumnTag(self.datasetType.name, "run")
488 ] = self._static.dataset.columns[self._runKeyColumn]
489 need_static_table = True
490 # Ingest date can only come from the static table.
491 if "ingest_date" in requested_columns:
492 need_static_table = True
493 payload.columns_available[
494 DatasetColumnTag(self.datasetType.name, "ingest_date")
495 ] = self._static.dataset.columns.ingest_date
496 # If we need the static table, join it in via dataset_id and
497 # dataset_type_id.
498 if need_static_table:
499 payload.from_clause = payload.from_clause.join(
500 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id)
501 )
502 # Also constrain dataset_type_id in static table in case that helps
503 # generate a better plan.
504 # We could also include this in the JOIN ON clause, but my guess is
505 # that that's a good idea IFF it's in the foreign key, and right
506 # now it isn't.
507 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
508 leaf = context.sql_engine.make_leaf(
509 payload.columns_available.keys(),
510 payload=payload,
511 name=self.datasetType.name,
512 parameters={record.name: rank for record, rank in collections},
513 )
514 return leaf
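# --- Illustrative sketch (not part of this module) ---
# The "rank" column above is a CASE expression keyed on the collection
# column, mapping each collection key to its position in the search order.
# A standalone SQLAlchemy sketch of the same construct, with hypothetical
# collection keys:

import sqlalchemy

collection_col = sqlalchemy.column("collection_id")
rank_expr = sqlalchemy.case(
    {"calib/run1": 0, "calib/run2": 1},
    value=collection_col,
)
# Renders roughly as: CASE collection_id WHEN :param_1 THEN :param_2 ... END
print(rank_expr)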
516 def getDataId(self, id: DatasetId) -> DataCoordinate:
517 """Return DataId for a dataset.
519 Parameters
520 ----------
521 id : `DatasetId`
522 Unique dataset identifier.
524 Returns
525 -------
526 dataId : `DataCoordinate`
527 DataId for the dataset.
528 """
529 # This query could return multiple rows (one for each tagged collection
530 # the dataset is in, plus one for its run collection), and we don't
531 # care which of those we get.
532 sql = (
533 self._tags.select()
534 .where(
535 sqlalchemy.sql.and_(
536 self._tags.columns.dataset_id == id,
537 self._tags.columns.dataset_type_id == self._dataset_type_id,
538 )
539 )
540 .limit(1)
541 )
542 with self._db.query(sql) as sql_result:
543 row = sql_result.mappings().fetchone()
544 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
545 return DataCoordinate.standardize(
546 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
547 graph=self.datasetType.dimensions,
548 )
551class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
552 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
553 dataset IDs.
554 """
556 idMaker = DatasetIdFactory()
557 """Factory for dataset IDs. In the future this factory may be shared with
558 other classes (e.g. Registry)."""
560 def insert(
561 self,
562 run: RunRecord,
563 dataIds: Iterable[DataCoordinate],
564 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
565 ) -> Iterator[DatasetRef]:
566 # Docstring inherited from DatasetRecordStorage.
568 # Iterate over data IDs, transforming a possibly-single-pass iterable
569 # into a list.
570 dataIdList = []
571 rows = []
572 summary = CollectionSummary()
573 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
574 dataIdList.append(dataId)
575 rows.append(
576 {
577 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
578 "dataset_type_id": self._dataset_type_id,
579 self._runKeyColumn: run.key,
580 }
581 )
583 with self._db.transaction():
584 # Insert into the static dataset table.
585 self._db.insert(self._static.dataset, *rows)
586 # Update the summary tables for this collection in case this is the
587 # first time this dataset type or these governor values will be
588 # inserted there.
589 self._summaries.update(run, [self._dataset_type_id], summary)
590 # Combine the generated dataset_id values and data ID fields to
591 # form rows to be inserted into the tags table.
592 protoTagsRow = {
593 "dataset_type_id": self._dataset_type_id,
594 self._collections.getCollectionForeignKeyName(): run.key,
595 }
596 tagsRows = [
597 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
598 for dataId, row in zip(dataIdList, rows)
599 ]
600 # Insert those rows into the tags table.
601 self._db.insert(self._tags, *tagsRows)
603 for dataId, row in zip(dataIdList, rows):
604 yield DatasetRef(
605 datasetType=self.datasetType,
606 dataId=dataId,
607 id=row["id"],
608 run=run.name,
609 )
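# --- Illustrative sketch (not part of this module) ---
# The id factory used above can hand out either random ("UNIQUE") or
# reproducible dataset IDs.  A generic illustration of those two strategies
# with the standard uuid module (this is not the actual DatasetIdFactory
# implementation, just the underlying idea):

import uuid

def random_id() -> uuid.UUID:
    return uuid.uuid4()  # a new value on every call

def deterministic_id(run: str, dataset_type: str, data_id: str) -> uuid.UUID:
    # uuid5 hashes a namespace plus a name, so the same inputs always map
    # to the same UUID.
    return uuid.uuid5(uuid.NAMESPACE_URL, f"{run}/{dataset_type}/{data_id}")

assert deterministic_id("run1", "flat", "detector=1") == deterministic_id("run1", "flat", "detector=1")
assert random_id() != random_id()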
611 def import_(
612 self,
613 run: RunRecord,
614 datasets: Iterable[DatasetRef],
615 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
616 reuseIds: bool = False,
617 ) -> Iterator[DatasetRef]:
618 # Docstring inherited from DatasetRecordStorage.
620 # Iterate over data IDs, transforming a possibly-single-pass iterable
621 # into a list.
622 dataIds = {}
623 summary = CollectionSummary()
624 for dataset in summary.add_datasets_generator(datasets):
625 # Ignore unknown ID types; normally all IDs have the same type, but
626 # this code supports mixed types or missing IDs.
627 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
628 if datasetId is None:
629 datasetId = self.idMaker.makeDatasetId(
630 run.name, self.datasetType, dataset.dataId, idGenerationMode
631 )
632 dataIds[datasetId] = dataset.dataId
634 # We'll insert all new rows into a temporary table.
635 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False)
636 collFkName = self._collections.getCollectionForeignKeyName()
637 protoTagsRow = {
638 "dataset_type_id": self._dataset_type_id,
639 collFkName: run.key,
640 }
641 tmpRows = [
642 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
643 for dataset_id, dataId in dataIds.items()
644 ]
645 with self._db.transaction(for_temp_tables=True):
646 with self._db.temporary_table(tableSpec) as tmp_tags:
647 # store all incoming data in a temporary table
648 self._db.insert(tmp_tags, *tmpRows)
650 # There are some checks that we want to make for consistency
651 # of the new datasets with existing ones.
652 self._validateImport(tmp_tags, run)
654 # Before we merge the temporary table into dataset/tags we need to
655 # drop datasets that are already there (and do not conflict).
656 self._db.deleteWhere(
657 tmp_tags,
658 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
659 )
661 # Copy it into the dataset table; we need to re-label some columns.
662 self._db.insert(
663 self._static.dataset,
664 select=sqlalchemy.sql.select(
665 tmp_tags.columns.dataset_id.label("id"),
666 tmp_tags.columns.dataset_type_id,
667 tmp_tags.columns[collFkName].label(self._runKeyColumn),
668 ),
669 )
671 # Update the summary tables for this collection in case this
672 # is the first time this dataset type or these governor values
673 # will be inserted there.
674 self._summaries.update(run, [self._dataset_type_id], summary)
676 # Copy it into tags table.
677 self._db.insert(self._tags, select=tmp_tags.select())
679 # Return refs in the same order as in the input list.
680 for dataset_id, dataId in dataIds.items():
681 yield DatasetRef(
682 datasetType=self.datasetType,
683 id=dataset_id,
684 dataId=dataId,
685 run=run.name,
686 )
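# --- Illustrative sketch (not part of this module) ---
# The import path stages everything in a temporary table, deletes staged rows
# whose dataset_id already exists, and then copies the remainder into the
# real tables with INSERT ... SELECT.  A self-contained SQLAlchemy Core
# sketch of that pattern against in-memory SQLite, with hypothetical table
# and column names:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
md = sa.MetaData()
dataset = sa.Table("dataset", md, sa.Column("id", sa.Integer, primary_key=True))
staging = sa.Table("staging", md, sa.Column("dataset_id", sa.Integer, primary_key=True))
md.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(dataset), [{"id": 1}])  # already present
    conn.execute(sa.insert(staging), [{"dataset_id": 1}, {"dataset_id": 2}])
    # Drop staged rows that already exist in the main table.
    conn.execute(sa.delete(staging).where(staging.c.dataset_id.in_(sa.select(dataset.c.id))))
    # Copy what is left, re-labelling the column to match the target.
    conn.execute(sa.insert(dataset).from_select(["id"], sa.select(staging.c.dataset_id)))
    assert {row.id for row in conn.execute(sa.select(dataset.c.id))} == {1, 2}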
688 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
689 """Validate imported refs against existing datasets.
691 Parameters
692 ----------
693 tmp_tags : `sqlalchemy.schema.Table`
694 Temporary table with new datasets and the same schema as tags
695 table.
696 run : `RunRecord`
697 The record object describing the `~CollectionType.RUN` collection.
699 Raises
700 ------
701 ConflictingDefinitionError
702 Raise if new datasets conflict with existing ones.
703 """
704 dataset = self._static.dataset
705 tags = self._tags
706 collFkName = self._collections.getCollectionForeignKeyName()
708 # Check that existing datasets have the same dataset type and
709 # run.
710 query = (
711 sqlalchemy.sql.select(
712 dataset.columns.id.label("dataset_id"),
713 dataset.columns.dataset_type_id.label("dataset_type_id"),
714 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
715 dataset.columns[self._runKeyColumn].label("run"),
716 tmp_tags.columns[collFkName].label("new run"),
717 )
718 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
719 .where(
720 sqlalchemy.sql.or_(
721 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
722 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
723 )
724 )
725 .limit(1)
726 )
727 with self._db.query(query) as result:
728 if (row := result.first()) is not None:
729 # Only include the first one in the exception message
730 raise ConflictingDefinitionError(
731 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
732 )
734 # Check that matching dataset in tags table has the same DataId.
735 query = (
736 sqlalchemy.sql.select(
737 tags.columns.dataset_id,
738 tags.columns.dataset_type_id.label("type_id"),
739 tmp_tags.columns.dataset_type_id.label("new type_id"),
740 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
741 *[
742 tmp_tags.columns[dim].label(f"new {dim}")
743 for dim in self.datasetType.dimensions.required.names
744 ],
745 )
746 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
747 .where(
748 sqlalchemy.sql.or_(
749 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
750 *[
751 tags.columns[dim] != tmp_tags.columns[dim]
752 for dim in self.datasetType.dimensions.required.names
753 ],
754 )
755 )
756 .limit(1)
757 )
759 with self._db.query(query) as result:
760 if (row := result.first()) is not None:
761 # Only include the first one in the exception message
762 raise ConflictingDefinitionError(
763 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
764 )
766 # Check that matching run+dataId have the same dataset ID.
767 query = (
768 sqlalchemy.sql.select(
769 tags.columns.dataset_type_id.label("dataset_type_id"),
770 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
771 tags.columns.dataset_id,
772 tmp_tags.columns.dataset_id.label("new dataset_id"),
773 tags.columns[collFkName],
774 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
775 )
776 .select_from(
777 tags.join(
778 tmp_tags,
779 sqlalchemy.sql.and_(
780 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
781 tags.columns[collFkName] == tmp_tags.columns[collFkName],
782 *[
783 tags.columns[dim] == tmp_tags.columns[dim]
784 for dim in self.datasetType.dimensions.required.names
785 ],
786 ),
787 )
788 )
789 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
790 .limit(1)
791 )
792 with self._db.query(query) as result:
793 if (row := result.first()) is not None:
794 # only include the first one in the exception message
795 raise ConflictingDefinitionError(
796 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
797 )
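# --- Illustrative sketch (not part of this module) ---
# Each of the three checks above has the same shape: join the staged rows to
# the existing rows on an identity column, keep only rows where some other
# column disagrees, LIMIT 1, and raise on the first hit.  A minimal
# SQLAlchemy Core sketch of that shape, with hypothetical tables and columns
# (the commented-out block mirrors how the result would be consumed):

import sqlalchemy as sa

md = sa.MetaData()
existing = sa.Table("existing", md, sa.Column("dataset_id", sa.Integer), sa.Column("run", sa.String))
staged = sa.Table("staged", md, sa.Column("dataset_id", sa.Integer), sa.Column("run", sa.String))

conflict_query = (
    sa.select(existing.c.dataset_id, existing.c.run, staged.c.run.label("new run"))
    .select_from(existing.join(staged, existing.c.dataset_id == staged.c.dataset_id))
    .where(existing.c.run != staged.c.run)
    .limit(1)
)
# with self._db.query(conflict_query) as result:
#     if (row := result.first()) is not None:
#         raise ConflictingDefinitionError(f"Existing run does not match new dataset: {row._asdict()}")
print(conflict_query)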