Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 81%
277 statements
coverage.py v6.5.0, created at 2023-02-05 10:06 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
23from __future__ import annotations
25__all__ = ("ByDimensionsDatasetRecordStorage",)
27import uuid
28from collections.abc import Iterable, Iterator, Sequence, Set
29from typing import TYPE_CHECKING
31import sqlalchemy
32from deprecated.sphinx import deprecated
33from lsst.daf.relation import Relation, sql
35from ....core import (
36 DataCoordinate,
37 DatasetColumnTag,
38 DatasetId,
39 DatasetRef,
40 DatasetType,
41 DimensionKeyColumnTag,
42 LogicalColumn,
43 Timespan,
44 ddl,
45)
46from ..._collection_summary import CollectionSummary
47from ..._collectionType import CollectionType
48from ..._exceptions import CollectionTypeError, ConflictingDefinitionError, UnsupportedIdGeneratorError
49from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage
50from ...queries import SqlQueryContext
51from .tables import makeTagTableSpec
53if TYPE_CHECKING:  # coverage: 53 ↛ 54, condition was never true
54 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
55 from .summaries import CollectionSummaryManager
56 from .tables import StaticDatasetTablesTuple
59class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
60 """Dataset record storage implementation paired with
61 `ByDimensionsDatasetRecordStorageManager`; see that class for more
62 information.
64 Instances of this class should never be constructed directly; use
65 `DatasetRecordStorageManager.register` instead.
66 """
68 def __init__(
69 self,
70 *,
71 datasetType: DatasetType,
72 db: Database,
73 dataset_type_id: int,
74 collections: CollectionManager,
75 static: StaticDatasetTablesTuple,
76 summaries: CollectionSummaryManager,
77 tags: sqlalchemy.schema.Table,
78 calibs: sqlalchemy.schema.Table | None,
79 ):
80 super().__init__(datasetType=datasetType)
81 self._dataset_type_id = dataset_type_id
82 self._db = db
83 self._collections = collections
84 self._static = static
85 self._summaries = summaries
86 self._tags = tags
87 self._calibs = calibs
88 self._runKeyColumn = collections.getRunForeignKeyName()
90 def delete(self, datasets: Iterable[DatasetRef]) -> None:
91 # Docstring inherited from DatasetRecordStorage.
92 # Only delete from common dataset table; ON DELETE foreign key clauses
93 # will handle the rest.
94 self._db.delete(
95 self._static.dataset,
96 ["id"],
97 *[{"id": dataset.getCheckedId()} for dataset in datasets],
98 )
100 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
101 # Docstring inherited from DatasetRecordStorage.
102 if collection.type is not CollectionType.TAGGED:  # coverage: 102 ↛ 103, condition was never true
103 raise TypeError(
104 f"Cannot associate into collection '{collection.name}' "
105 f"of type {collection.type.name}; must be TAGGED."
106 )
107 protoRow = {
108 self._collections.getCollectionForeignKeyName(): collection.key,
109 "dataset_type_id": self._dataset_type_id,
110 }
111 rows = []
112 summary = CollectionSummary()
113 for dataset in summary.add_datasets_generator(datasets):
114 row = dict(protoRow, dataset_id=dataset.getCheckedId())
115 for dimension, value in dataset.dataId.items():
116 row[dimension.name] = value
117 rows.append(row)
118 # Update the summary tables for this collection in case this is the
119 # first time this dataset type or these governor values will be
120 # inserted there.
121 self._summaries.update(collection, [self._dataset_type_id], summary)
122 # Update the tag table itself.
123 self._db.replace(self._tags, *rows)
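    # A minimal, hedged usage sketch: `associate` is normally reached through the
    # public Registry API rather than called directly. The `butler` object, the
    # collection names, and the query below are hypothetical placeholders.
    #
    #     from lsst.daf.butler import Butler, CollectionType
    #
    #     butler = Butler("repo", writeable=True)
    #     butler.registry.registerCollection("u/user/tagged", CollectionType.TAGGED)
    #     refs = butler.registry.queryDatasets("flat", collections="u/user/run")
    #     butler.registry.associate("u/user/tagged", refs)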
125 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
126 # Docstring inherited from DatasetRecordStorage.
127 if collection.type is not CollectionType.TAGGED:  # coverage: 127 ↛ 128, condition was never true
128 raise TypeError(
129 f"Cannot disassociate from collection '{collection.name}' "
130 f"of type {collection.type.name}; must be TAGGED."
131 )
132 rows = [
133 {
134 "dataset_id": dataset.getCheckedId(),
135 self._collections.getCollectionForeignKeyName(): collection.key,
136 }
137 for dataset in datasets
138 ]
139 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
141 def _buildCalibOverlapQuery(
142 self,
143 collection: CollectionRecord,
144 data_ids: set[DataCoordinate] | None,
145 timespan: Timespan,
146 context: SqlQueryContext,
147 ) -> Relation:
148 relation = self.make_relation(
149 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context
150 ).with_rows_satisfying(
151 context.make_timespan_overlap_predicate(
152 DatasetColumnTag(self.datasetType.name, "timespan"), timespan
153 ),
154 )
155 if data_ids is not None:
156 relation = relation.join(
157 context.make_data_id_relation(
158 data_ids, self.datasetType.dimensions.required.names
159 ).transferred_to(context.sql_engine),
160 )
161 return relation
163 def certify(
164 self,
165 collection: CollectionRecord,
166 datasets: Iterable[DatasetRef],
167 timespan: Timespan,
168 context: SqlQueryContext,
169 ) -> None:
170 # Docstring inherited from DatasetRecordStorage.
171 if self._calibs is None:  # coverage: 171 ↛ 172, condition was never true
172 raise CollectionTypeError(
173 f"Cannot certify datasets of type {self.datasetType.name}, for which "
174 "DatasetType.isCalibration() is False."
175 )
176 if collection.type is not CollectionType.CALIBRATION:  # coverage: 176 ↛ 177, condition was never true
177 raise CollectionTypeError(
178 f"Cannot certify into collection '{collection.name}' "
179 f"of type {collection.type.name}; must be CALIBRATION."
180 )
181 TimespanReprClass = self._db.getTimespanRepresentation()
182 protoRow = {
183 self._collections.getCollectionForeignKeyName(): collection.key,
184 "dataset_type_id": self._dataset_type_id,
185 }
186 rows = []
187 dataIds: set[DataCoordinate] | None = (
188 set() if not TimespanReprClass.hasExclusionConstraint() else None
189 )
190 summary = CollectionSummary()
191 for dataset in summary.add_datasets_generator(datasets):
192 row = dict(protoRow, dataset_id=dataset.getCheckedId())
193 for dimension, value in dataset.dataId.items():
194 row[dimension.name] = value
195 TimespanReprClass.update(timespan, result=row)
196 rows.append(row)
197 if dataIds is not None:  # coverage: 197 ↛ 191, condition was never false
198 dataIds.add(dataset.dataId)
199 # Update the summary tables for this collection in case this is the
200 # first time this dataset type or these governor values will be
201 # inserted there.
202 self._summaries.update(collection, [self._dataset_type_id], summary)
203 # Update the association table itself.
204 if TimespanReprClass.hasExclusionConstraint():  # coverage: 204 ↛ 207, condition was never true
205 # Rely on database constraint to enforce invariants; we just
206 # reraise the exception for consistency across DB engines.
207 try:
208 self._db.insert(self._calibs, *rows)
209 except sqlalchemy.exc.IntegrityError as err:
210 raise ConflictingDefinitionError(
211 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
212 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
213 ) from err
214 else:
215 # Have to implement exclusion constraint ourselves.
216 # Start by building a SELECT query for any rows that would overlap
217 # this one.
218 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context)
219 # Acquire a table lock to ensure there are no concurrent writes that
220 # could invalidate our checking before we finish the inserts. We
221 # use a SAVEPOINT in case there is an outer transaction that a
222 # failure here should not roll back.
223 with self._db.transaction(lock=[self._calibs], savepoint=True):
224 # Enter SqlQueryContext in case we need to use a temporary
225 # table to include the given data IDs in the query. Note that
226 # by doing this inside the transaction, we make sure it doesn't
227 # attempt to close the session when it's done, since it just
228 # sees an already-open session that it knows it shouldn't
229 # manage.
230 with context:
231 # Run the check SELECT query.
232 conflicting = context.count(context.process(relation))
233 if conflicting > 0:
234 raise ConflictingDefinitionError(
235 f"{conflicting} validity range conflicts certifying datasets of type "
236 f"{self.datasetType.name} into {collection.name} for range "
237 f"[{timespan.begin}, {timespan.end})."
238 )
239 # Proceed with the insert.
240 self._db.insert(self._calibs, *rows)
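    # A hedged sketch of the public path into `certify`: Registry.certify attaches
    # a validity range to calibration refs in a CALIBRATION collection. `butler`,
    # `bias_refs`, and the timespan bounds `t_begin`/`t_end` are hypothetical.
    #
    #     from lsst.daf.butler import CollectionType, Timespan
    #
    #     butler.registry.registerCollection("calib/bias", CollectionType.CALIBRATION)
    #     butler.registry.certify("calib/bias", bias_refs, Timespan(t_begin, t_end))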
242 def decertify(
243 self,
244 collection: CollectionRecord,
245 timespan: Timespan,
246 *,
247 dataIds: Iterable[DataCoordinate] | None = None,
248 context: SqlQueryContext,
249 ) -> None:
250 # Docstring inherited from DatasetRecordStorage.
251 if self._calibs is None:  # coverage: 251 ↛ 252, condition was never true
252 raise CollectionTypeError(
253 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
254 "DatasetType.isCalibration() is False."
255 )
256 if collection.type is not CollectionType.CALIBRATION:  # coverage: 256 ↛ 257, condition was never true
257 raise CollectionTypeError(
258 f"Cannot decertify from collection '{collection.name}' "
259 f"of type {collection.type.name}; must be CALIBRATION."
260 )
261 TimespanReprClass = self._db.getTimespanRepresentation()
262 # Construct a SELECT query to find all rows that overlap our inputs.
263 dataIdSet: set[DataCoordinate] | None
264 if dataIds is not None:
265 dataIdSet = set(dataIds)
266 else:
267 dataIdSet = None
268 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context)
269 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
270 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
271 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
272 data_id_tags = [
273 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
274 ]
275 # Set up collections to populate with the rows we'll want to modify.
276 # The insert rows will have the same values for collection and
277 # dataset type.
278 protoInsertRow = {
279 self._collections.getCollectionForeignKeyName(): collection.key,
280 "dataset_type_id": self._dataset_type_id,
281 }
282 rowsToDelete = []
283 rowsToInsert = []
284 # Acquire a table lock to ensure there are no concurrent writes
285 # between the SELECT and the DELETE and INSERT queries based on it.
286 with self._db.transaction(lock=[self._calibs], savepoint=True):
287 # Enter SqlQueryContext in case we need to use a temporary table to
288 # include the given data IDs in the query (see similar block in
289 # certify for details).
290 with context:
291 for row in context.fetch_iterable(relation):
292 rowsToDelete.append({"id": row[calib_pkey_tag]})
293 # Construct the insert row(s) by copying the prototype row,
294 # then adding the dimension column values, then adding
295 # what's left of the timespan from that row after we
296 # subtract the given timespan.
297 newInsertRow = protoInsertRow.copy()
298 newInsertRow["dataset_id"] = row[dataset_id_tag]
299 for name, tag in data_id_tags:
300 newInsertRow[name] = row[tag]
301 rowTimespan = row[timespan_tag]
302 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
303 for diffTimespan in rowTimespan.difference(timespan):
304 rowsToInsert.append(
305 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())
306 )
307 # Run the DELETE and INSERT queries.
308 self._db.delete(self._calibs, ["id"], *rowsToDelete)
309 self._db.insert(self._calibs, *rowsToInsert)
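    # Hedged illustration of the timespan arithmetic used above: decertifying the
    # middle of an existing validity range splits that row into the two leftover
    # pieces. The astropy times t1 < t2 < t3 < t4 are hypothetical.
    #
    #     existing = Timespan(t1, t4)
    #     removed = Timespan(t2, t3)
    #     list(existing.difference(removed))  # -> [Timespan(t1, t2), Timespan(t3, t4)]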
311 def make_relation(
312 self,
313 *collections: CollectionRecord,
314 columns: Set[str],
315 context: SqlQueryContext,
316 ) -> Relation:
317 # Docstring inherited from DatasetRecordStorage.
318 collection_types = {collection.type for collection in collections}
319 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
320 TimespanReprClass = self._db.getTimespanRepresentation()
321 #
322 # There are two kinds of table in play here:
323 #
324 # - the static dataset table (with the dataset ID, dataset type ID,
325 # run ID/name, and ingest date);
326 #
327 # - the dynamic tags/calibs table (with the dataset ID, dataset type
328 # ID, collection ID/name, data ID, and possibly validity
329 # range).
330 #
331 # That means that we might want to return a query against either table
332 # or a JOIN of both, depending on which quantities the caller wants.
333 # But the data ID is always included, which means we'll always include
334 # the tags/calibs table and join in the static dataset table only if we
335 # need things from it that we can't get from the tags/calibs table.
336 #
337 # Note that it's important that we include a WHERE constraint on both
338 # tables for any column (e.g. dataset_type_id) that is in both when
339 # it's given explicitly; not doing so can prevent the query planner from
340 # using very important indexes. At present, we don't include those
341 # redundant columns in the JOIN ON expression, however, because the
342 # FOREIGN KEY (and its index) are defined only on dataset_id.
343 tag_relation: Relation | None = None
344 calib_relation: Relation | None = None
345 if collection_types != {CollectionType.CALIBRATION}:
346 # We'll need a subquery for the tags table if any of the given
347 # collections are not a CALIBRATION collection. This intentionally
348 # also fires when the list of collections is empty as a way to
349 # create a dummy subquery that we know will fail.
350 # We give the table an alias because it might appear multiple times
351 # in the same query, for different dataset types.
352 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags"))
353 if "timespan" in columns:
354 tags_parts.columns_available[
355 DatasetColumnTag(self.datasetType.name, "timespan")
356 ] = TimespanReprClass.fromLiteral(Timespan(None, None))
357 tag_relation = self._finish_single_relation(
358 tags_parts,
359 columns,
360 [
361 (record, rank)
362 for rank, record in enumerate(collections)
363 if record.type is not CollectionType.CALIBRATION
364 ],
365 context,
366 )
367 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries."
368 if CollectionType.CALIBRATION in collection_types:
369 # If at least one collection is a CALIBRATION collection, we'll
370 # need a subquery for the calibs table, and could include the
371 # timespan as a result or constraint.
372 assert (
373 self._calibs is not None
374 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
375 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs"))
376 if "timespan" in columns:
377 calibs_parts.columns_available[
378 DatasetColumnTag(self.datasetType.name, "timespan")
379 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns)
380 if "calib_pkey" in columns:
381 # This is a private extension not included in the base class
382 # interface, for internal use only in _buildCalibOverlapQuery,
383 # which needs access to the autoincrement primary key for the
384 # calib association table.
385 calibs_parts.columns_available[
386 DatasetColumnTag(self.datasetType.name, "calib_pkey")
387 ] = calibs_parts.from_clause.columns.id
388 calib_relation = self._finish_single_relation(
389 calibs_parts,
390 columns,
391 [
392 (record, rank)
393 for rank, record in enumerate(collections)
394 if record.type is CollectionType.CALIBRATION
395 ],
396 context,
397 )
398 if tag_relation is not None:
399 if calib_relation is not None:
400 # daf_relation's chain operation does not automatically
401 # deduplicate; it's more like SQL's UNION ALL. To get UNION
402 # in SQL here, we add an explicit deduplication.
403 return tag_relation.chain(calib_relation).without_duplicates()
404 else:
405 return tag_relation
406 elif calib_relation is not None:  # coverage: 406 ↛ 409, condition was never false
407 return calib_relation
408 else:
409 raise AssertionError("Branch should be unreachable.")
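    # A hedged sketch of how a caller might consume this relation; `storage`, the
    # collection records, and `ctx` (a SqlQueryContext) are hypothetical, and any
    # CHAINED collections must already have been flattened by the caller.
    #
    #     rel = storage.make_relation(
    #         run_record, tagged_record, columns={"dataset_id", "run"}, context=ctx
    #     )
    #     for row in ctx.fetch_iterable(rel):
    #         ...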
411 def _finish_single_relation(
412 self,
413 payload: sql.Payload[LogicalColumn],
414 requested_columns: Set[str],
415 collections: Sequence[tuple[CollectionRecord, int]],
416 context: SqlQueryContext,
417 ) -> Relation:
418 """Helper method for `make_relation`.
420 This handles adding columns and WHERE terms that are not specific to
421 either the tags or calibs tables.
423 Parameters
424 ----------
425 payload : `lsst.daf.relation.sql.Payload`
426 SQL query parts under construction, to be modified in-place and
427 used to construct the new relation.
428 requested_columns : `~collections.abc.Set` [ `str` ]
429 Columns the relation should include.
430 collections : `Sequence` [ `tuple` [ `CollectionRecord`, `int` ] ]
431 Collections to search for the dataset and their ranks.
432 context : `SqlQueryContext`
433 Context that manages engines and state for the query.
435 Returns
436 -------
437 relation : `lsst.daf.relation.Relation`
438 New dataset query relation.
439 """
440 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id)
441 dataset_id_col = payload.from_clause.columns.dataset_id
442 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()]
443 # We always constrain and optionally retrieve the collection(s) via the
444 # tags/calibs table.
445 if len(collections) == 1:
446 payload.where.append(collection_col == collections[0][0].key)
447 if "collection" in requested_columns:
448 payload.columns_available[
449 DatasetColumnTag(self.datasetType.name, "collection")
450 ] = sqlalchemy.sql.literal(collections[0][0].key)
451 else:
452 assert collections, "The no-collections case should be in calling code for better diagnostics."
453 payload.where.append(collection_col.in_([collection.key for collection, _ in collections]))
454 if "collection" in requested_columns:
455 payload.columns_available[
456 DatasetColumnTag(self.datasetType.name, "collection")
457 ] = collection_col
458 # Add rank if requested as a CASE-based calculation on the collection
459 # column.
460 if "rank" in requested_columns:
461 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case(
462 {record.key: rank for record, rank in collections},
463 value=collection_col,
464 )
465 # Add more column definitions, starting with the data ID.
466 for dimension_name in self.datasetType.dimensions.required.names:
467 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
468 dimension_name
469 ]
470 # We can always get the dataset_id from the tags/calibs table.
471 if "dataset_id" in requested_columns:
472 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col
473 # It's possible we now have everything we need, from just the
474 # tags/calibs table. The things we might need to get from the static
475 # dataset table are the run key and the ingest date.
476 need_static_table = False
477 if "run" in requested_columns:
478 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN:
479 # If we are searching exactly one RUN collection, we
480 # know that if we find the dataset in that collection,
481 # then that's the dataset's run; we don't need to
482 # query for it.
483 payload.columns_available[
484 DatasetColumnTag(self.datasetType.name, "run")
485 ] = sqlalchemy.sql.literal(collections[0][0].key)
486 else:
487 payload.columns_available[
488 DatasetColumnTag(self.datasetType.name, "run")
489 ] = self._static.dataset.columns[self._runKeyColumn]
490 need_static_table = True
491 # Ingest date can only come from the static table.
492 if "ingest_date" in requested_columns:
493 need_static_table = True
494 payload.columns_available[
495 DatasetColumnTag(self.datasetType.name, "ingest_date")
496 ] = self._static.dataset.columns.ingest_date
497 # If we need the static table, join it in via dataset_id and
498 # dataset_type_id
499 if need_static_table:
500 payload.from_clause = payload.from_clause.join(
501 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id)
502 )
503 # Also constrain dataset_type_id in static table in case that helps
504 # generate a better plan.
505 # We could also include this in the JOIN ON clause, but my guess is
506 # that that's a good idea IFF it's in the foreign key, and right
507 # now it isn't.
508 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
509 leaf = context.sql_engine.make_leaf(
510 payload.columns_available.keys(),
511 payload=payload,
512 name=self.datasetType.name,
513 parameters={record.name: rank for record, rank in collections},
514 )
515 return leaf
517 def getDataId(self, id: DatasetId) -> DataCoordinate:
518 """Return DataId for a dataset.
520 Parameters
521 ----------
522 id : `DatasetId`
523 Unique dataset identifier.
525 Returns
526 -------
527 dataId : `DataCoordinate`
528 DataId for the dataset.
529 """
530 # This query could return multiple rows (one for each tagged collection
531 # the dataset is in, plus one for its run collection), and we don't
532 # care which of those we get.
533 sql = (
534 self._tags.select()
535 .where(
536 sqlalchemy.sql.and_(
537 self._tags.columns.dataset_id == id,
538 self._tags.columns.dataset_type_id == self._dataset_type_id,
539 )
540 )
541 .limit(1)
542 )
543 with self._db.query(sql) as sql_result:
544 row = sql_result.mappings().fetchone()
545 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
546 return DataCoordinate.standardize(
547 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
548 graph=self.datasetType.dimensions,
549 )
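    # Hedged usage sketch: recover the data ID for a dataset whose ID is already
    # known. `storage` and `ref` are hypothetical.
    #
    #     data_id = storage.getDataId(ref.id)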
552@deprecated(
553 "Integer dataset IDs are deprecated in favor of UUIDs; support will be removed after v26. "
554 "Please migrate or re-create this data repository.",
555 version="v25.0",
556 category=FutureWarning,
557)
558class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage):
559 """Implementation of ByDimensionsDatasetRecordStorage which uses integer
560 auto-incremented column for dataset IDs.
561 """
563 def insert(
564 self,
565 run: RunRecord,
566 dataIds: Iterable[DataCoordinate],
567 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
568 ) -> Iterator[DatasetRef]:
569 # Docstring inherited from DatasetRecordStorage.
571 # We only support UNIQUE mode for integer dataset IDs
572 if idMode != DatasetIdGenEnum.UNIQUE:
573 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
575 # Transform a possibly-single-pass iterable into a list.
576 dataIdList = list(dataIds)
577 yield from self._insert(run, dataIdList)
579 def import_(
580 self,
581 run: RunRecord,
582 datasets: Iterable[DatasetRef],
583 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
584 reuseIds: bool = False,
585 ) -> Iterator[DatasetRef]:
586 # Docstring inherited from DatasetRecordStorage.
588 # We only support UNIQUE mode for integer dataset IDs
589 if idGenerationMode != DatasetIdGenEnum.UNIQUE:
590 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
592 # Make a list of dataIds and optionally dataset IDs.
593 dataIdList: list[DataCoordinate] = []
594 datasetIdList: list[int] = []
595 for dataset in datasets:
596 dataIdList.append(dataset.dataId)
598 # We only accept integer dataset IDs, but also allow None.
599 datasetId = dataset.id
600 if datasetId is None:
601 # if reuseIds is set then all IDs must be known
602 if reuseIds:
603 raise TypeError("All dataset IDs must be known if `reuseIds` is set")
604 elif isinstance(datasetId, int):
605 if reuseIds:
606 datasetIdList.append(datasetId)
607 else:
608 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}")
610 yield from self._insert(run, dataIdList, datasetIdList)
612 def _insert(
613 self, run: RunRecord, dataIdList: list[DataCoordinate], datasetIdList: list[int] | None = None
614 ) -> Iterator[DatasetRef]:
615 """Common part of implementation of `insert` and `import_` methods."""
617 # Remember any governor dimension values we see.
618 summary = CollectionSummary()
619 summary.add_data_ids(self.datasetType, dataIdList)
621 staticRow = {
622 "dataset_type_id": self._dataset_type_id,
623 self._runKeyColumn: run.key,
624 }
625 with self._db.transaction():
626 # Insert into the static dataset table, generating autoincrement
627 # dataset_id values.
628 if datasetIdList:
629 # reuse existing IDs
630 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList]
631 self._db.insert(self._static.dataset, *rows)
632 else:
633 # use auto-incremented IDs
634 datasetIdList = self._db.insert(
635 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True
636 )
637 assert datasetIdList is not None
638 # Update the summary tables for this collection in case this is the
639 # first time this dataset type or these governor values will be
640 # inserted there.
641 self._summaries.update(run, [self._dataset_type_id], summary)
642 # Combine the generated dataset_id values and data ID fields to
643 # form rows to be inserted into the tags table.
644 protoTagsRow = {
645 "dataset_type_id": self._dataset_type_id,
646 self._collections.getCollectionForeignKeyName(): run.key,
647 }
648 tagsRows = [
649 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
650 for dataId, dataset_id in zip(dataIdList, datasetIdList)
651 ]
652 # Insert those rows into the tags table. This is where we'll
653 # get any unique constraint violations.
654 self._db.insert(self._tags, *tagsRows)
656 for dataId, datasetId in zip(dataIdList, datasetIdList):
657 yield DatasetRef(
658 datasetType=self.datasetType,
659 dataId=dataId,
660 id=datasetId,
661 run=run.name,
662 )
665class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
666 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
667 dataset IDs.
668 """
670 idMaker = DatasetIdFactory()
671 """Factory for dataset IDs. In the future this factory may be shared with
672 other classes (e.g. Registry)."""
674 def insert(
675 self,
676 run: RunRecord,
677 dataIds: Iterable[DataCoordinate],
678 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
679 ) -> Iterator[DatasetRef]:
680 # Docstring inherited from DatasetRecordStorage.
682 # Iterate over data IDs, transforming a possibly-single-pass iterable
683 # into a list.
684 dataIdList = []
685 rows = []
686 summary = CollectionSummary()
687 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
688 dataIdList.append(dataId)
689 rows.append(
690 {
691 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
692 "dataset_type_id": self._dataset_type_id,
693 self._runKeyColumn: run.key,
694 }
695 )
697 with self._db.transaction():
698 # Insert into the static dataset table.
699 self._db.insert(self._static.dataset, *rows)
700 # Update the summary tables for this collection in case this is the
701 # first time this dataset type or these governor values will be
702 # inserted there.
703 self._summaries.update(run, [self._dataset_type_id], summary)
704 # Combine the generated dataset_id values and data ID fields to
705 # form rows to be inserted into the tags table.
706 protoTagsRow = {
707 "dataset_type_id": self._dataset_type_id,
708 self._collections.getCollectionForeignKeyName(): run.key,
709 }
710 tagsRows = [
711 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
712 for dataId, row in zip(dataIdList, rows)
713 ]
714 # Insert those rows into the tags table.
715 self._db.insert(self._tags, *tagsRows)
717 for dataId, row in zip(dataIdList, rows):
718 yield DatasetRef(
719 datasetType=self.datasetType,
720 dataId=dataId,
721 id=row["id"],
722 run=run.name,
723 )
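    # Hedged sketch of the usual entry point: Registry.insertDatasets resolves the
    # run and dataset type and then calls this method; the default UNIQUE mode
    # produces random UUIDs. `butler` and the example data ID are hypothetical.
    #
    #     refs = butler.registry.insertDatasets(
    #         "flat",
    #         dataIds=[{"instrument": "HSC", "detector": 0, "physical_filter": "HSC-R"}],
    #         run="u/user/run",
    #     )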
725 def import_(
726 self,
727 run: RunRecord,
728 datasets: Iterable[DatasetRef],
729 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
730 reuseIds: bool = False,
731 ) -> Iterator[DatasetRef]:
732 # Docstring inherited from DatasetRecordStorage.
734 # Iterate over data IDs, transforming a possibly-single-pass iterable
735 # into a list.
736 dataIds = {}
737 summary = CollectionSummary()
738 for dataset in summary.add_datasets_generator(datasets):
739 # Ignore unknown ID types; normally all IDs have the same type, but
740 # this code supports mixed types or missing IDs.
741 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
742 if datasetId is None:
743 datasetId = self.idMaker.makeDatasetId(
744 run.name, self.datasetType, dataset.dataId, idGenerationMode
745 )
746 dataIds[datasetId] = dataset.dataId
748 # We'll insert all new rows into a temporary table
749 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False)
750 collFkName = self._collections.getCollectionForeignKeyName()
751 protoTagsRow = {
752 "dataset_type_id": self._dataset_type_id,
753 collFkName: run.key,
754 }
755 tmpRows = [
756 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
757 for dataset_id, dataId in dataIds.items()
758 ]
759 with self._db.transaction(for_temp_tables=True):
760 with self._db.temporary_table(tableSpec) as tmp_tags:
761 # store all incoming data in a temporary table
762 self._db.insert(tmp_tags, *tmpRows)
764 # There are some checks that we want to make for consistency
765 # of the new datasets with existing ones.
766 self._validateImport(tmp_tags, run)
768 # Before we merge the temporary table into dataset/tags we need to
769 # drop datasets which are already there (and do not conflict).
770 self._db.deleteWhere(
771 tmp_tags,
772 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
773 )
775 # Copy it into the dataset table; some columns need to be re-labeled.
776 self._db.insert(
777 self._static.dataset,
778 select=sqlalchemy.sql.select(
779 tmp_tags.columns.dataset_id.label("id"),
780 tmp_tags.columns.dataset_type_id,
781 tmp_tags.columns[collFkName].label(self._runKeyColumn),
782 ),
783 )
785 # Update the summary tables for this collection in case this
786 # is the first time this dataset type or these governor values
787 # will be inserted there.
788 self._summaries.update(run, [self._dataset_type_id], summary)
790 # Copy it into tags table.
791 self._db.insert(self._tags, select=tmp_tags.select())
793 # Return refs in the same order as in the input list.
794 for dataset_id, dataId in dataIds.items():
795 yield DatasetRef(
796 datasetType=self.datasetType,
797 id=dataset_id,
798 dataId=dataId,
799 run=run.name,
800 )
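    # Hedged sketch: `import_` preserves UUIDs already attached to the incoming
    # refs, which is what keeps dataset IDs stable across repository transfers.
    # `dataset_type`, `data_id`, `run_record`, and `storage` are hypothetical.
    #
    #     ref = DatasetRef(
    #         datasetType=dataset_type, dataId=data_id, id=uuid.uuid4(), run=run_record.name
    #     )
    #     (imported_ref,) = storage.import_(run_record, [ref])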
802 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
803 """Validate imported refs against existing datasets.
805 Parameters
806 ----------
807 tmp_tags : `sqlalchemy.schema.Table`
808 Temporary table with new datasets and the same schema as the tags
809 table.
810 run : `RunRecord`
811 The record object describing the `~CollectionType.RUN` collection.
813 Raises
814 ------
815 ConflictingDefinitionError
816 Raised if new datasets conflict with existing ones.
817 """
818 dataset = self._static.dataset
819 tags = self._tags
820 collFkName = self._collections.getCollectionForeignKeyName()
822 # Check that existing datasets have the same dataset type and
823 # run.
824 query = (
825 sqlalchemy.sql.select(
826 dataset.columns.id.label("dataset_id"),
827 dataset.columns.dataset_type_id.label("dataset_type_id"),
828 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
829 dataset.columns[self._runKeyColumn].label("run"),
830 tmp_tags.columns[collFkName].label("new run"),
831 )
832 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
833 .where(
834 sqlalchemy.sql.or_(
835 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
836 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
837 )
838 )
839 .limit(1)
840 )
841 with self._db.query(query) as result:
842 if (row := result.first()) is not None:
843 # Only include the first one in the exception message
844 raise ConflictingDefinitionError(
845 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
846 )
848 # Check that matching dataset in tags table has the same DataId.
849 query = (
850 sqlalchemy.sql.select(
851 tags.columns.dataset_id,
852 tags.columns.dataset_type_id.label("type_id"),
853 tmp_tags.columns.dataset_type_id.label("new type_id"),
854 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
855 *[
856 tmp_tags.columns[dim].label(f"new {dim}")
857 for dim in self.datasetType.dimensions.required.names
858 ],
859 )
860 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
861 .where(
862 sqlalchemy.sql.or_(
863 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
864 *[
865 tags.columns[dim] != tmp_tags.columns[dim]
866 for dim in self.datasetType.dimensions.required.names
867 ],
868 )
869 )
870 .limit(1)
871 )
873 with self._db.query(query) as result:
874 if (row := result.first()) is not None:
875 # Only include the first one in the exception message
876 raise ConflictingDefinitionError(
877 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
878 )
880 # Check that matching run+dataId have the same dataset ID.
881 query = (
882 sqlalchemy.sql.select(
883 tags.columns.dataset_type_id.label("dataset_type_id"),
884 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
885 tags.columns.dataset_id,
886 tmp_tags.columns.dataset_id.label("new dataset_id"),
887 tags.columns[collFkName],
888 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
889 )
890 .select_from(
891 tags.join(
892 tmp_tags,
893 sqlalchemy.sql.and_(
894 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
895 tags.columns[collFkName] == tmp_tags.columns[collFkName],
896 *[
897 tags.columns[dim] == tmp_tags.columns[dim]
898 for dim in self.datasetType.dimensions.required.names
899 ],
900 ),
901 )
902 )
903 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
904 .limit(1)
905 )
906 with self._db.query(query) as result:
907 if (row := result.first()) is not None:
908 # only include the first one in the exception message
909 raise ConflictingDefinitionError(
910 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
911 )
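    # Hedged sketch of what a failed validation looks like to a caller: any of the
    # three checks above surfaces as ConflictingDefinitionError. `storage`,
    # `run_record`, and `refs` are hypothetical.
    #
    #     try:
    #         list(storage.import_(run_record, refs))
    #     except ConflictingDefinitionError:
    #         # an existing dataset disagrees on run, dataset type, data ID, or UUID
    #         raise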