Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95% (237 statements)
coverage.py v7.2.7, created at 2023-06-23 09:29 +0000
1 # This file is part of daf_butler.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (http://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
23 from __future__ import annotations
25 __all__ = ("ByDimensionsDatasetRecordStorage",)
27 from collections.abc import Iterable, Iterator, Sequence, Set
28 from datetime import datetime
29 from typing import TYPE_CHECKING
31 import astropy.time
32 import sqlalchemy
33 from lsst.daf.relation import Relation, sql
35 from ....core import (
36 DataCoordinate,
37 DatasetColumnTag,
38 DatasetId,
39 DatasetIdFactory,
40 DatasetIdGenEnum,
41 DatasetRef,
42 DatasetType,
43 DimensionKeyColumnTag,
44 LogicalColumn,
45 Timespan,
46 ddl,
47 )
48 from ..._collection_summary import CollectionSummary
49 from ..._collectionType import CollectionType
50 from ..._exceptions import CollectionTypeError, ConflictingDefinitionError
51 from ...interfaces import DatasetRecordStorage
52 from ...queries import SqlQueryContext
53 from .tables import makeTagTableSpec
55 if TYPE_CHECKING:
56 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
57 from .summaries import CollectionSummaryManager
58 from .tables import StaticDatasetTablesTuple
61 class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
62 """Dataset record storage implementation paired with
63 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more
64 information.
66 Instances of this class should never be constructed directly; use
67 `DatasetRecordStorageManager.register` instead.
68 """
70 def __init__(
71 self,
72 *,
73 datasetType: DatasetType,
74 db: Database,
75 dataset_type_id: int,
76 collections: CollectionManager,
77 static: StaticDatasetTablesTuple,
78 summaries: CollectionSummaryManager,
79 tags: sqlalchemy.schema.Table,
80 use_astropy_ingest_date: bool,
81 calibs: sqlalchemy.schema.Table | None,
82 ):
83 super().__init__(datasetType=datasetType)
84 self._dataset_type_id = dataset_type_id
85 self._db = db
86 self._collections = collections
87 self._static = static
88 self._summaries = summaries
89 self._tags = tags
90 self._calibs = calibs
91 self._runKeyColumn = collections.getRunForeignKeyName()
92 self._use_astropy = use_astropy_ingest_date
94 def delete(self, datasets: Iterable[DatasetRef]) -> None:
95 # Docstring inherited from DatasetRecordStorage.
96 # Only delete from common dataset table; ON DELETE foreign key clauses
97 # will handle the rest.
98 self._db.delete(
99 self._static.dataset,
100 ["id"],
101 *[{"id": dataset.id} for dataset in datasets],
102 )
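# The single-table delete above is enough because the dynamic tags/calibs
# tables reference the static dataset table with ON DELETE foreign key
# clauses.  A minimal, self-contained sketch of that cascading-delete pattern
# in plain SQLAlchemy Core (the table and column names below are hypothetical,
# not the actual daf_butler schema):
import sqlalchemy

sketch_metadata = sqlalchemy.MetaData()
sketch_dataset = sqlalchemy.Table(
    "dataset",
    sketch_metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
)
sketch_tags = sqlalchemy.Table(
    "dataset_tags",
    sketch_metadata,
    # Because of ondelete="CASCADE", deleting a row from "dataset" also
    # removes the matching rows here, so callers only delete from the parent.
    sqlalchemy.Column(
        "dataset_id",
        sqlalchemy.ForeignKey("dataset.id", ondelete="CASCADE"),
        primary_key=True,
    ),
)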
104 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
105 # Docstring inherited from DatasetRecordStorage.
106 if collection.type is not CollectionType.TAGGED:  # coverage: 106 ↛ 107 never taken (condition never true)
107 raise TypeError(
108 f"Cannot associate into collection '{collection.name}' "
109 f"of type {collection.type.name}; must be TAGGED."
110 )
111 protoRow = {
112 self._collections.getCollectionForeignKeyName(): collection.key,
113 "dataset_type_id": self._dataset_type_id,
114 }
115 rows = []
116 summary = CollectionSummary()
117 for dataset in summary.add_datasets_generator(datasets):
118 row = dict(protoRow, dataset_id=dataset.id)
119 for dimension, value in dataset.dataId.items():
120 row[dimension.name] = value
121 rows.append(row)
122 # Update the summary tables for this collection in case this is the
123 # first time this dataset type or these governor values will be
124 # inserted there.
125 self._summaries.update(collection, [self._dataset_type_id], summary)
126 # Update the tag table itself.
127 self._db.replace(self._tags, *rows)
129 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
130 # Docstring inherited from DatasetRecordStorage.
131 if collection.type is not CollectionType.TAGGED:  # coverage: 131 ↛ 132 never taken (condition never true)
132 raise TypeError(
133 f"Cannot disassociate from collection '{collection.name}' "
134 f"of type {collection.type.name}; must be TAGGED."
135 )
136 rows = [
137 {
138 "dataset_id": dataset.id,
139 self._collections.getCollectionForeignKeyName(): collection.key,
140 }
141 for dataset in datasets
142 ]
143 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
145 def _buildCalibOverlapQuery(
146 self,
147 collection: CollectionRecord,
148 data_ids: set[DataCoordinate] | None,
149 timespan: Timespan,
150 context: SqlQueryContext,
151 ) -> Relation:
        """Build a query relation for rows of the calibs table in the given
        collection whose validity ranges overlap ``timespan``, optionally
        restricted to the given data IDs.
        """
152 relation = self.make_relation(
153 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context
154 ).with_rows_satisfying(
155 context.make_timespan_overlap_predicate(
156 DatasetColumnTag(self.datasetType.name, "timespan"), timespan
157 ),
158 )
159 if data_ids is not None:
160 relation = relation.join(
161 context.make_data_id_relation(
162 data_ids, self.datasetType.dimensions.required.names
163 ).transferred_to(context.sql_engine),
164 )
165 return relation
167 def certify(
168 self,
169 collection: CollectionRecord,
170 datasets: Iterable[DatasetRef],
171 timespan: Timespan,
172 context: SqlQueryContext,
173 ) -> None:
174 # Docstring inherited from DatasetRecordStorage.
175 if self._calibs is None:  # coverage: 175 ↛ 176 never taken (condition never true)
176 raise CollectionTypeError(
177 f"Cannot certify datasets of type {self.datasetType.name}, for which "
178 "DatasetType.isCalibration() is False."
179 )
180 if collection.type is not CollectionType.CALIBRATION:  # coverage: 180 ↛ 181 never taken (condition never true)
181 raise CollectionTypeError(
182 f"Cannot certify into collection '{collection.name}' "
183 f"of type {collection.type.name}; must be CALIBRATION."
184 )
185 TimespanReprClass = self._db.getTimespanRepresentation()
186 protoRow = {
187 self._collections.getCollectionForeignKeyName(): collection.key,
188 "dataset_type_id": self._dataset_type_id,
189 }
190 rows = []
191 dataIds: set[DataCoordinate] | None = (
192 set() if not TimespanReprClass.hasExclusionConstraint() else None
193 )
194 summary = CollectionSummary()
195 for dataset in summary.add_datasets_generator(datasets):
196 row = dict(protoRow, dataset_id=dataset.id)
197 for dimension, value in dataset.dataId.items():
198 row[dimension.name] = value
199 TimespanReprClass.update(timespan, result=row)
200 rows.append(row)
201 if dataIds is not None:  # coverage: 201 ↛ 195 never taken (condition never false)
202 dataIds.add(dataset.dataId)
203 # Update the summary tables for this collection in case this is the
204 # first time this dataset type or these governor values will be
205 # inserted there.
206 self._summaries.update(collection, [self._dataset_type_id], summary)
207 # Update the association table itself.
208 if TimespanReprClass.hasExclusionConstraint():  # coverage: 208 ↛ 211 never taken (condition never true)
209 # Rely on database constraint to enforce invariants; we just
210 # reraise the exception for consistency across DB engines.
211 try:
212 self._db.insert(self._calibs, *rows)
213 except sqlalchemy.exc.IntegrityError as err:
214 raise ConflictingDefinitionError(
215 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
216 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
217 ) from err
218 else:
219 # Have to implement exclusion constraint ourselves.
220 # Start by building a SELECT query for any rows that would overlap
221 # this one.
222 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context)
223 # Acquire a table lock to ensure there are no concurrent writes that
224 # could invalidate our checking before we finish the inserts. We
225 # use a SAVEPOINT in case there is an outer transaction that a
226 # failure here should not roll back.
227 with self._db.transaction(lock=[self._calibs], savepoint=True):
228 # Enter SqlQueryContext in case we need to use a temporary
229 # table to include the given data IDs in the query. Note that
230 # by doing this inside the transaction, we make sure it doesn't
231 # attempt to close the session when it's done, since it just
232 # sees an already-open session that it knows it shouldn't
233 # manage.
234 with context:
235 # Run the check SELECT query.
236 conflicting = context.count(context.process(relation))
237 if conflicting > 0:
238 raise ConflictingDefinitionError(
239 f"{conflicting} validity range conflicts certifying datasets of type "
240 f"{self.datasetType.name} into {collection.name} for range "
241 f"[{timespan.begin}, {timespan.end})."
242 )
243 # Proceed with the insert.
244 self._db.insert(self._calibs, *rows)
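# When the database cannot enforce a range EXCLUSION constraint natively, the
# method above performs the overlap check itself inside a locked transaction.
# A condensed sketch of that check-then-insert pattern, assuming plain
# SQLAlchemy Core and half-open [begin, end) integer validity ranges (every
# name below is hypothetical):
import sqlalchemy


def certify_sketch(
    connection: sqlalchemy.engine.Connection,
    calibs: sqlalchemy.Table,
    rows: list[dict],
    begin: int,
    end: int,
) -> None:
    # Two half-open ranges overlap iff each one starts before the other ends.
    overlap_count = (
        sqlalchemy.select(sqlalchemy.func.count())
        .select_from(calibs)
        .where(calibs.c.begin < end, calibs.c.end > begin)
    )
    # The real implementation also takes a table lock and uses a SAVEPOINT so
    # that a conflict does not roll back any enclosing transaction.
    with connection.begin():
        if connection.execute(overlap_count).scalar_one() > 0:
            raise RuntimeError("validity range conflict")
        connection.execute(sqlalchemy.insert(calibs), rows)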
246 def decertify(
247 self,
248 collection: CollectionRecord,
249 timespan: Timespan,
250 *,
251 dataIds: Iterable[DataCoordinate] | None = None,
252 context: SqlQueryContext,
253 ) -> None:
254 # Docstring inherited from DatasetRecordStorage.
255 if self._calibs is None:  # coverage: 255 ↛ 256 never taken (condition never true)
256 raise CollectionTypeError(
257 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
258 "DatasetType.isCalibration() is False."
259 )
260 if collection.type is not CollectionType.CALIBRATION:  # coverage: 260 ↛ 261 never taken (condition never true)
261 raise CollectionTypeError(
262 f"Cannot decertify from collection '{collection.name}' "
263 f"of type {collection.type.name}; must be CALIBRATION."
264 )
265 TimespanReprClass = self._db.getTimespanRepresentation()
266 # Construct a SELECT query to find all rows that overlap our inputs.
267 dataIdSet: set[DataCoordinate] | None
268 if dataIds is not None:
269 dataIdSet = set(dataIds)
270 else:
271 dataIdSet = None
272 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context)
273 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
274 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
275 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
276 data_id_tags = [
277 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
278 ]
279 # Set up collections to populate with the rows we'll want to modify.
280 # The insert rows will have the same values for collection and
281 # dataset type.
282 protoInsertRow = {
283 self._collections.getCollectionForeignKeyName(): collection.key,
284 "dataset_type_id": self._dataset_type_id,
285 }
286 rowsToDelete = []
287 rowsToInsert = []
288 # Acquire a table lock to ensure there are no concurrent writes
289 # between the SELECT and the DELETE and INSERT queries based on it.
290 with self._db.transaction(lock=[self._calibs], savepoint=True):
291 # Enter SqlQueryContext in case we need to use a temporary table to
292 # include the given data IDs in the query (see similar block in
293 # certify for details).
294 with context:
295 for row in context.fetch_iterable(relation):
296 rowsToDelete.append({"id": row[calib_pkey_tag]})
297 # Construct the insert row(s) by copying the prototype row,
298 # then adding the dimension column values, then adding
299 # what's left of the timespan from that row after we
300 # subtract the given timespan.
301 newInsertRow = protoInsertRow.copy()
302 newInsertRow["dataset_id"] = row[dataset_id_tag]
303 for name, tag in data_id_tags:
304 newInsertRow[name] = row[tag]
305 rowTimespan = row[timespan_tag]
306 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
307 for diffTimespan in rowTimespan.difference(timespan):
308 rowsToInsert.append(
309 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())
310 )
311 # Run the DELETE and INSERT queries.
312 self._db.delete(self._calibs, ["id"], *rowsToDelete)
313 self._db.insert(self._calibs, *rowsToInsert)
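# Decertification replaces each overlapping calib row with whatever pieces of
# its validity range survive subtracting the decertified timespan (that is
# what Timespan.difference supplies above).  A tiny worked example of the
# same idea with plain half-open integer intervals (interval_difference is a
# hypothetical stand-in, not the Timespan API itself):
def interval_difference(
    existing: tuple[int, int], removed: tuple[int, int]
) -> list[tuple[int, int]]:
    """Return the parts of ``existing`` not covered by ``removed``."""
    pieces = []
    if existing[0] < removed[0]:
        pieces.append((existing[0], min(existing[1], removed[0])))
    if existing[1] > removed[1]:
        pieces.append((max(existing[0], removed[1]), existing[1]))
    return pieces


# Decertifying [20, 30) from a row valid over [10, 40) deletes that row and
# inserts two replacements, valid over [10, 20) and [30, 40).
assert interval_difference((10, 40), (20, 30)) == [(10, 20), (30, 40)]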
315 def make_relation(
316 self,
317 *collections: CollectionRecord,
318 columns: Set[str],
319 context: SqlQueryContext,
320 ) -> Relation:
321 # Docstring inherited from DatasetRecordStorage.
322 collection_types = {collection.type for collection in collections}
323 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
324 TimespanReprClass = self._db.getTimespanRepresentation()
325 #
326 # There are two kinds of table in play here:
327 #
328 # - the static dataset table (with the dataset ID, dataset type ID,
329 # run ID/name, and ingest date);
330 #
331 # - the dynamic tags/calibs table (with the dataset ID, dataset
332 # type ID, collection ID/name, data ID, and possibly validity
333 # range).
334 #
335 # That means that we might want to return a query against either table
336 # or a JOIN of both, depending on which quantities the caller wants.
337 # But the data ID is always included, which means we'll always include
338 # the tags/calibs table and join in the static dataset table only if we
339 # need things from it that we can't get from the tags/calibs table.
340 #
341 # Note that it's important that we include a WHERE constraint on both
342 # tables for any column (e.g. dataset_type_id) that is in both when
343 # it's given explicitly; not doing so can prevent the query planner from
344 # using very important indexes. At present, we don't include those
345 # redundant columns in the JOIN ON expression, however, because the
346 # FOREIGN KEY (and its index) are defined only on dataset_id.
347 tag_relation: Relation | None = None
348 calib_relation: Relation | None = None
349 if collection_types != {CollectionType.CALIBRATION}:
350 # We'll need a subquery for the tags table if any of the given
351 # collections are not a CALIBRATION collection. This intentionally
352 # also fires when the list of collections is empty as a way to
353 # create a dummy subquery that we know will fail.
354 # We give the table an alias because it might appear multiple times
355 # in the same query, for different dataset types.
356 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags"))
357 if "timespan" in columns:
358 tags_parts.columns_available[
359 DatasetColumnTag(self.datasetType.name, "timespan")
360 ] = TimespanReprClass.fromLiteral(Timespan(None, None))
361 tag_relation = self._finish_single_relation(
362 tags_parts,
363 columns,
364 [
365 (record, rank)
366 for rank, record in enumerate(collections)
367 if record.type is not CollectionType.CALIBRATION
368 ],
369 context,
370 )
371 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries."
372 if CollectionType.CALIBRATION in collection_types:
373 # If at least one collection is a CALIBRATION collection, we'll
374 # need a subquery for the calibs table, and could include the
375 # timespan as a result or constraint.
376 assert (
377 self._calibs is not None
378 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
379 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs"))
380 if "timespan" in columns:
381 calibs_parts.columns_available[
382 DatasetColumnTag(self.datasetType.name, "timespan")
383 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns)
384 if "calib_pkey" in columns:
385 # This is a private extension not included in the base class
386 # interface, for internal use only in _buildCalibOverlapQuery,
387 # which needs access to the autoincrement primary key for the
388 # calib association table.
389 calibs_parts.columns_available[
390 DatasetColumnTag(self.datasetType.name, "calib_pkey")
391 ] = calibs_parts.from_clause.columns.id
392 calib_relation = self._finish_single_relation(
393 calibs_parts,
394 columns,
395 [
396 (record, rank)
397 for rank, record in enumerate(collections)
398 if record.type is CollectionType.CALIBRATION
399 ],
400 context,
401 )
402 if tag_relation is not None:
403 if calib_relation is not None:
404 # daf_relation's chain operation does not automatically
405 # deduplicate; it's more like SQL's UNION ALL. To get UNION
406 # in SQL here, we add an explicit deduplication.
407 return tag_relation.chain(calib_relation).without_duplicates()
408 else:
409 return tag_relation
410 elif calib_relation is not None:
411 return calib_relation
412 else:
413 raise AssertionError("Branch should be unreachable.")
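# The tag_relation.chain(calib_relation).without_duplicates() combination
# above corresponds to SQL UNION rather than UNION ALL.  A minimal SQLAlchemy
# Core illustration of that distinction (the two single-column tables are
# hypothetical):
import sqlalchemy


def union_sketch(tags: sqlalchemy.Table, calibs: sqlalchemy.Table):
    tag_query = sqlalchemy.select(tags.c.dataset_id)
    calib_query = sqlalchemy.select(calibs.c.dataset_id)
    # union_all() would return a dataset found by both subqueries twice;
    # union() deduplicates, which is what the relation tree asks for here.
    return sqlalchemy.union(tag_query, calib_query)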
415 def _finish_single_relation(
416 self,
417 payload: sql.Payload[LogicalColumn],
418 requested_columns: Set[str],
419 collections: Sequence[tuple[CollectionRecord, int]],
420 context: SqlQueryContext,
421 ) -> Relation:
422 """Helper method for `make_relation`.
424 This handles adding columns and WHERE terms that are not specific to
425 either the tags or calibs tables.
427 Parameters
428 ----------
429 payload : `lsst.daf.relation.sql.Payload`
430 SQL query parts under construction, to be modified in-place and
431 used to construct the new relation.
432 requested_columns : `~collections.abc.Set` [ `str` ]
433 Columns the relation should include.
434 collections : `~collections.abc.Sequence` [ `tuple` \
435 [ `CollectionRecord`, `int` ] ]
436 Collections to search for the dataset and their ranks.
437 context : `SqlQueryContext`
438 Context that manages engines and state for the query.
440 Returns
441 -------
442 relation : `lsst.daf.relation.Relation`
443 New dataset query relation.
444 """
445 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id)
446 dataset_id_col = payload.from_clause.columns.dataset_id
447 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()]
448 # We always constrain and optionally retrieve the collection(s) via the
449 # tags/calibs table.
450 if len(collections) == 1:
451 payload.where.append(collection_col == collections[0][0].key)
452 if "collection" in requested_columns:
453 payload.columns_available[
454 DatasetColumnTag(self.datasetType.name, "collection")
455 ] = sqlalchemy.sql.literal(collections[0][0].key)
456 else:
457 assert collections, "The no-collections case should be handled by calling code for better diagnostics."
458 payload.where.append(collection_col.in_([collection.key for collection, _ in collections]))
459 if "collection" in requested_columns:
460 payload.columns_available[
461 DatasetColumnTag(self.datasetType.name, "collection")
462 ] = collection_col
463 # Add rank if requested, as a CASE-based calculation on the collection
464 # column.
465 if "rank" in requested_columns:
466 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case(
467 {record.key: rank for record, rank in collections},
468 value=collection_col,
469 )
470 # Add more column definitions, starting with the data ID.
471 for dimension_name in self.datasetType.dimensions.required.names:
472 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
473 dimension_name
474 ]
475 # We can always get the dataset_id from the tags/calibs table.
476 if "dataset_id" in requested_columns:
477 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col
478 # It's possible we now have everything we need, from just the
479 # tags/calibs table. The things we might need to get from the static
480 # dataset table are the run key and the ingest date.
481 need_static_table = False
482 if "run" in requested_columns:
483 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN:
484 # If we are searching exactly one RUN collection, we
485 # know that if we find the dataset in that collection,
486 # then that's the dataset's run; we don't need to
487 # query for it.
488 payload.columns_available[
489 DatasetColumnTag(self.datasetType.name, "run")
490 ] = sqlalchemy.sql.literal(collections[0][0].key)
491 else:
492 payload.columns_available[
493 DatasetColumnTag(self.datasetType.name, "run")
494 ] = self._static.dataset.columns[self._runKeyColumn]
495 need_static_table = True
496 # Ingest date can only come from the static table.
497 if "ingest_date" in requested_columns:
498 need_static_table = True
499 payload.columns_available[
500 DatasetColumnTag(self.datasetType.name, "ingest_date")
501 ] = self._static.dataset.columns.ingest_date
502 # If we need the static table, join it in via dataset_id and
503 # dataset_type_id.
504 if need_static_table:
505 payload.from_clause = payload.from_clause.join(
506 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id)
507 )
508 # Also constrain dataset_type_id in static table in case that helps
509 # generate a better plan.
510 # We could also include this in the JOIN ON clause, but my guess is
511 # that that's a good idea IFF it's in the foreign key, and right
512 # now it isn't.
513 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
514 leaf = context.sql_engine.make_leaf(
515 payload.columns_available.keys(),
516 payload=payload,
517 name=self.datasetType.name,
518 parameters={record.name: rank for record, rank in collections},
519 )
520 return leaf
522 def getDataId(self, id: DatasetId) -> DataCoordinate:
523 """Return DataId for a dataset.
525 Parameters
526 ----------
527 id : `DatasetId`
528 Unique dataset identifier.
530 Returns
531 -------
532 dataId : `DataCoordinate`
533 DataId for the dataset.
534 """
535 # This query could return multiple rows (one for each tagged collection
536 # the dataset is in, plus one for its run collection), and we don't
537 # care which of those we get.
538 sql = (
539 self._tags.select()
540 .where(
541 sqlalchemy.sql.and_(
542 self._tags.columns.dataset_id == id,
543 self._tags.columns.dataset_type_id == self._dataset_type_id,
544 )
545 )
546 .limit(1)
547 )
548 with self._db.query(sql) as sql_result:
549 row = sql_result.mappings().fetchone()
550 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
551 return DataCoordinate.standardize(
552 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
553 graph=self.datasetType.dimensions,
554 )
557 class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
558 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
559 dataset IDs.
560 """
562 idMaker = DatasetIdFactory()
563 """Factory for dataset IDs. In the future this factory may be shared with
564 other classes (e.g. Registry)."""
566 def insert(
567 self,
568 run: RunRecord,
569 dataIds: Iterable[DataCoordinate],
570 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
571 ) -> Iterator[DatasetRef]:
572 # Docstring inherited from DatasetRecordStorage.
574 # Current timestamp, type depends on schema version. Use microsecond
575 # precision for astropy time to keep things consistent with
576 # TIMESTAMP(6) SQL type.
577 timestamp: datetime | astropy.time.Time
578 if self._use_astropy:
579 # Astropy `now()` precision should be the same as `utcnow()` which
580 # should mean microsecond.
581 timestamp = astropy.time.Time.now()
582 else:
583 timestamp = datetime.utcnow()
585 # Iterate over data IDs, transforming a possibly-single-pass iterable
586 # into a list.
587 dataIdList = []
588 rows = []
589 summary = CollectionSummary()
590 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
591 dataIdList.append(dataId)
592 rows.append(
593 {
594 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
595 "dataset_type_id": self._dataset_type_id,
596 self._runKeyColumn: run.key,
597 "ingest_date": timestamp,
598 }
599 )
601 with self._db.transaction():
602 # Insert into the static dataset table.
603 self._db.insert(self._static.dataset, *rows)
604 # Update the summary tables for this collection in case this is the
605 # first time this dataset type or these governor values will be
606 # inserted there.
607 self._summaries.update(run, [self._dataset_type_id], summary)
608 # Combine the generated dataset_id values and data ID fields to
609 # form rows to be inserted into the tags table.
610 protoTagsRow = {
611 "dataset_type_id": self._dataset_type_id,
612 self._collections.getCollectionForeignKeyName(): run.key,
613 }
614 tagsRows = [
615 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
616 for dataId, row in zip(dataIdList, rows)
617 ]
618 # Insert those rows into the tags table.
619 self._db.insert(self._tags, *tagsRows)
621 for dataId, row in zip(dataIdList, rows):
622 yield DatasetRef(
623 datasetType=self.datasetType,
624 dataId=dataId,
625 id=row["id"],
626 run=run.name,
627 )
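# The idMaker call above takes an idMode argument (DatasetIdGenEnum.UNIQUE by
# default).  A rough sketch of how an ID factory of this kind can produce
# either random or deterministic UUIDs, using only the standard uuid module;
# the namespace constant and key format below are hypothetical, not the real
# DatasetIdFactory algorithm:
import uuid

SKETCH_NAMESPACE = uuid.UUID("00000000-0000-0000-0000-000000000000")


def make_dataset_id_sketch(
    run: str, dataset_type: str, data_id_key: str, unique: bool
) -> uuid.UUID:
    if unique:
        # Random UUID4: a brand-new ID on every call.
        return uuid.uuid4()
    # Deterministic UUID5: the same (run, dataset type, data ID) key always
    # maps to the same ID, which makes repeated imports reproducible.
    return uuid.uuid5(SKETCH_NAMESPACE, f"{run}/{dataset_type}/{data_id_key}")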
629 def import_(
630 self,
631 run: RunRecord,
632 datasets: Iterable[DatasetRef],
633 ) -> Iterator[DatasetRef]:
634 # Docstring inherited from DatasetRecordStorage.
636 # Current timestamp, type depends on schema version.
637 if self._use_astropy:
638 # Astropy `now()` precision should be the same as `utcnow()` which
639 # should mean microsecond.
640 timestamp = sqlalchemy.sql.literal(astropy.time.Time.now(), type_=ddl.AstropyTimeNsecTai)
641 else:
642 timestamp = sqlalchemy.sql.literal(datetime.utcnow())
644 # Iterate over data IDs, transforming a possibly-single-pass iterable
645 # into a dict keyed by dataset ID.
646 dataIds = {}
647 summary = CollectionSummary()
648 for dataset in summary.add_datasets_generator(datasets):
649 dataIds[dataset.id] = dataset.dataId
651 # We'll insert all new rows into a temporary table
652 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False)
653 collFkName = self._collections.getCollectionForeignKeyName()
654 protoTagsRow = {
655 "dataset_type_id": self._dataset_type_id,
656 collFkName: run.key,
657 }
658 tmpRows = [
659 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
660 for dataset_id, dataId in dataIds.items()
661 ]
662 with self._db.transaction(for_temp_tables=True):
663 with self._db.temporary_table(tableSpec) as tmp_tags:
664 # store all incoming data in a temporary table
665 self._db.insert(tmp_tags, *tmpRows)
667 # There are some checks that we want to make for consistency
668 # of the new datasets with existing ones.
669 self._validateImport(tmp_tags, run)
671 # Before we merge the temporary table into dataset/tags we need to
672 # drop datasets which are already there (and do not conflict).
673 self._db.deleteWhere(
674 tmp_tags,
675 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
676 )
678 # Copy it into the dataset table; we need to re-label some columns.
679 self._db.insert(
680 self._static.dataset,
681 select=sqlalchemy.sql.select(
682 tmp_tags.columns.dataset_id.label("id"),
683 tmp_tags.columns.dataset_type_id,
684 tmp_tags.columns[collFkName].label(self._runKeyColumn),
685 timestamp.label("ingest_date"),
686 ),
687 )
689 # Update the summary tables for this collection in case this
690 # is the first time this dataset type or these governor values
691 # will be inserted there.
692 self._summaries.update(run, [self._dataset_type_id], summary)
694 # Copy it into the tags table.
695 self._db.insert(self._tags, select=tmp_tags.select())
697 # Return refs in the same order as in the input list.
698 for dataset_id, dataId in dataIds.items():
699 yield DatasetRef(
700 datasetType=self.datasetType,
701 id=dataset_id,
702 dataId=dataId,
703 run=run.name,
704 )
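# The temporary-table dance in import_() amounts to "copy over only the rows
# whose dataset IDs are not already present": already-present IDs are deleted
# from the temporary table and the remainder is copied with INSERT ... SELECT.
# An equivalent single-statement sketch of that step in plain SQLAlchemy Core
# (both tables and their "id" column are hypothetical):
import sqlalchemy


def copy_new_rows_sketch(
    connection: sqlalchemy.engine.Connection,
    tmp: sqlalchemy.Table,
    target: sqlalchemy.Table,
) -> None:
    # Anti-join: select only the temporary rows whose id is not in the target.
    new_rows = sqlalchemy.select(tmp.c.id).where(
        ~tmp.c.id.in_(sqlalchemy.select(target.c.id))
    )
    connection.execute(sqlalchemy.insert(target).from_select(["id"], new_rows))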
706 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
707 """Validate imported refs against existing datasets.
709 Parameters
710 ----------
711 tmp_tags : `sqlalchemy.schema.Table`
712 Temporary table with new datasets and the same schema as tags
713 table.
714 run : `RunRecord`
715 The record object describing the `~CollectionType.RUN` collection.
717 Raises
718 ------
719 ConflictingDefinitionError
720 Raised if new datasets conflict with existing ones.
721 """
722 dataset = self._static.dataset
723 tags = self._tags
724 collFkName = self._collections.getCollectionForeignKeyName()
726 # Check that existing datasets have the same dataset type and
727 # run.
728 query = (
729 sqlalchemy.sql.select(
730 dataset.columns.id.label("dataset_id"),
731 dataset.columns.dataset_type_id.label("dataset_type_id"),
732 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
733 dataset.columns[self._runKeyColumn].label("run"),
734 tmp_tags.columns[collFkName].label("new run"),
735 )
736 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
737 .where(
738 sqlalchemy.sql.or_(
739 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
740 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
741 )
742 )
743 .limit(1)
744 )
745 with self._db.query(query) as result:
746 if (row := result.first()) is not None:
747 # Only include the first one in the exception message
748 raise ConflictingDefinitionError(
749 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
750 )
752 # Check that matching dataset in tags table has the same DataId.
753 query = (
754 sqlalchemy.sql.select(
755 tags.columns.dataset_id,
756 tags.columns.dataset_type_id.label("type_id"),
757 tmp_tags.columns.dataset_type_id.label("new type_id"),
758 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
759 *[
760 tmp_tags.columns[dim].label(f"new {dim}")
761 for dim in self.datasetType.dimensions.required.names
762 ],
763 )
764 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
765 .where(
766 sqlalchemy.sql.or_(
767 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
768 *[
769 tags.columns[dim] != tmp_tags.columns[dim]
770 for dim in self.datasetType.dimensions.required.names
771 ],
772 )
773 )
774 .limit(1)
775 )
777 with self._db.query(query) as result:
778 if (row := result.first()) is not None:
779 # Only include the first one in the exception message
780 raise ConflictingDefinitionError(
781 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
782 )
784 # Check that matching run+dataId have the same dataset ID.
785 query = (
786 sqlalchemy.sql.select(
787 tags.columns.dataset_type_id.label("dataset_type_id"),
788 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
789 tags.columns.dataset_id,
790 tmp_tags.columns.dataset_id.label("new dataset_id"),
791 tags.columns[collFkName],
792 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
793 )
794 .select_from(
795 tags.join(
796 tmp_tags,
797 sqlalchemy.sql.and_(
798 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
799 tags.columns[collFkName] == tmp_tags.columns[collFkName],
800 *[
801 tags.columns[dim] == tmp_tags.columns[dim]
802 for dim in self.datasetType.dimensions.required.names
803 ],
804 ),
805 )
806 )
807 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
808 .limit(1)
809 )
810 with self._db.query(query) as result:
811 if (row := result.first()) is not None:
812 # Only include the first one in the exception message
813 raise ConflictingDefinitionError(
814 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
815 )