Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%
251 statements
coverage.py v7.3.2, created at 2023-10-27 09:43 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
29from __future__ import annotations
31from .... import ddl
33__all__ = ("ByDimensionsDatasetRecordStorage",)
35from collections.abc import Iterable, Iterator, Sequence, Set
36from datetime import datetime
37from typing import TYPE_CHECKING
39import astropy.time
40import sqlalchemy
41from lsst.daf.relation import Relation, sql
43from ...._column_tags import DatasetColumnTag, DimensionKeyColumnTag
44from ...._column_type_info import LogicalColumn
45from ...._dataset_ref import DatasetId, DatasetIdFactory, DatasetIdGenEnum, DatasetRef
46from ...._dataset_type import DatasetType
47from ...._timespan import Timespan
48from ....dimensions import DataCoordinate
49from ..._collection_summary import CollectionSummary
50from ..._collection_type import CollectionType
51from ..._exceptions import CollectionTypeError, ConflictingDefinitionError
52from ...interfaces import DatasetRecordStorage
53from ...queries import SqlQueryContext
54from .tables import makeTagTableSpec
56if TYPE_CHECKING:
57 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
58 from .summaries import CollectionSummaryManager
59 from .tables import StaticDatasetTablesTuple
62class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
63 """Dataset record storage implementation paired with
64 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more
65 information.
67 Instances of this class should never be constructed directly; use
68 `DatasetRecordStorageManager.register` instead.
69 """
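# Typical construction path (a schematic note, not taken from this file): the
# dataset record storage manager creates and caches one of these objects per
# dataset type and hands it out through its register/find style methods;
# registry code then calls ``insert``, ``associate``, ``certify``, etc. on it,
# and user code never instantiates this class directly.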
71 def __init__(
72 self,
73 *,
74 datasetType: DatasetType,
75 db: Database,
76 dataset_type_id: int,
77 collections: CollectionManager,
78 static: StaticDatasetTablesTuple,
79 summaries: CollectionSummaryManager,
80 tags: sqlalchemy.schema.Table,
81 use_astropy_ingest_date: bool,
82 calibs: sqlalchemy.schema.Table | None,
83 ):
84 super().__init__(datasetType=datasetType)
85 self._dataset_type_id = dataset_type_id
86 self._db = db
87 self._collections = collections
88 self._static = static
89 self._summaries = summaries
90 self._tags = tags
91 self._calibs = calibs
92 self._runKeyColumn = collections.getRunForeignKeyName()
93 self._use_astropy = use_astropy_ingest_date
95 def delete(self, datasets: Iterable[DatasetRef]) -> None:
96 # Docstring inherited from DatasetRecordStorage.
97 # Only delete from common dataset table; ON DELETE foreign key clauses
98 # will handle the rest.
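# (Schematic illustration of the assumption above, not the actual DDL: the
#  tags/calibs tables are expected to declare their dataset_id foreign keys
#  roughly as
#      FOREIGN KEY (dataset_id) REFERENCES dataset (id) ON DELETE CASCADE
#  so deleting the parent rows below removes the associated rows as well.)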
99 self._db.delete(
100 self._static.dataset,
101 ["id"],
102 *[{"id": dataset.id} for dataset in datasets],
103 )
105 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
106 # Docstring inherited from DatasetRecordStorage.
107 if collection.type is not CollectionType.TAGGED: 107 ↛ 108 (line 107 didn't jump to line 108 because the condition was never true)
108 raise TypeError(
109 f"Cannot associate into collection '{collection.name}' "
110 f"of type {collection.type.name}; must be TAGGED."
111 )
112 protoRow = {
113 self._collections.getCollectionForeignKeyName(): collection.key,
114 "dataset_type_id": self._dataset_type_id,
115 }
116 rows = []
117 summary = CollectionSummary()
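# Note: as used here, ``add_datasets_generator`` yields each input dataset
# back while also accumulating its dataset type and governor dimension
# values into ``summary`` as a side effect.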
118 for dataset in summary.add_datasets_generator(datasets):
119 row = dict(protoRow, dataset_id=dataset.id)
120 for dimension, value in dataset.dataId.items():
121 row[dimension.name] = value
122 rows.append(row)
123 # Update the summary tables for this collection in case this is the
124 # first time this dataset type or these governor values will be
125 # inserted there.
126 self._summaries.update(collection, [self._dataset_type_id], summary)
127 # Update the tag table itself.
128 self._db.replace(self._tags, *rows)
130 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
131 # Docstring inherited from DatasetRecordStorage.
132 if collection.type is not CollectionType.TAGGED: 132 ↛ 133 (line 132 didn't jump to line 133 because the condition was never true)
133 raise TypeError(
134 f"Cannot disassociate from collection '{collection.name}' "
135 f"of type {collection.type.name}; must be TAGGED."
136 )
137 rows = [
138 {
139 "dataset_id": dataset.id,
140 self._collections.getCollectionForeignKeyName(): collection.key,
141 }
142 for dataset in datasets
143 ]
144 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
146 def _buildCalibOverlapQuery(
147 self,
148 collection: CollectionRecord,
149 data_ids: set[DataCoordinate] | None,
150 timespan: Timespan,
151 context: SqlQueryContext,
152 ) -> Relation:
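# Builds a relation over existing calib-association rows in ``collection``
# whose validity ranges overlap ``timespan``, optionally joined against the
# given data IDs; used by ``certify`` and ``decertify`` below.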
153 relation = self.make_relation(
154 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context
155 ).with_rows_satisfying(
156 context.make_timespan_overlap_predicate(
157 DatasetColumnTag(self.datasetType.name, "timespan"), timespan
158 ),
159 )
160 if data_ids is not None:
161 relation = relation.join(
162 context.make_data_id_relation(
163 data_ids, self.datasetType.dimensions.required.names
164 ).transferred_to(context.sql_engine),
165 )
166 return relation
168 def certify(
169 self,
170 collection: CollectionRecord,
171 datasets: Iterable[DatasetRef],
172 timespan: Timespan,
173 context: SqlQueryContext,
174 ) -> None:
175 # Docstring inherited from DatasetRecordStorage.
176 if self._calibs is None: 176 ↛ 177 (line 176 didn't jump to line 177 because the condition was never true)
177 raise CollectionTypeError(
178 f"Cannot certify datasets of type {self.datasetType.name}, for which "
179 "DatasetType.isCalibration() is False."
180 )
181 if collection.type is not CollectionType.CALIBRATION: 181 ↛ 182 (line 181 didn't jump to line 182 because the condition was never true)
182 raise CollectionTypeError(
183 f"Cannot certify into collection '{collection.name}' "
184 f"of type {collection.type.name}; must be CALIBRATION."
185 )
186 TimespanReprClass = self._db.getTimespanRepresentation()
187 protoRow = {
188 self._collections.getCollectionForeignKeyName(): collection.key,
189 "dataset_type_id": self._dataset_type_id,
190 }
191 rows = []
192 dataIds: set[DataCoordinate] | None = (
193 set() if not TimespanReprClass.hasExclusionConstraint() else None
194 )
195 summary = CollectionSummary()
196 for dataset in summary.add_datasets_generator(datasets):
197 row = dict(protoRow, dataset_id=dataset.id)
198 for dimension, value in dataset.dataId.items():
199 row[dimension.name] = value
200 TimespanReprClass.update(timespan, result=row)
201 rows.append(row)
202 if dataIds is not None: 202 ↛ 196 (line 202 didn't jump to line 196 because the condition was never false)
203 dataIds.add(dataset.dataId)
204 # Update the summary tables for this collection in case this is the
205 # first time this dataset type or these governor values will be
206 # inserted there.
207 self._summaries.update(collection, [self._dataset_type_id], summary)
208 # Update the association table itself.
209 if TimespanReprClass.hasExclusionConstraint(): 209 ↛ 212 (line 209 didn't jump to line 212 because the condition was never true)
210 # Rely on the database constraint to enforce invariants; we just
211 # translate the resulting exception for consistency across DB engines.
212 try:
213 self._db.insert(self._calibs, *rows)
214 except sqlalchemy.exc.IntegrityError as err:
215 raise ConflictingDefinitionError(
216 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
217 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
218 ) from err
219 else:
220 # Have to implement exclusion constraint ourselves.
221 # Start by building a SELECT query for any rows that would overlap
222 # this one.
223 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context)
224 # Acquire a table lock to ensure that no concurrent writes can
225 # invalidate our checking before we finish the inserts. We
226 # use a SAVEPOINT in case there is an outer transaction that a
227 # failure here should not roll back.
228 with self._db.transaction(lock=[self._calibs], savepoint=True):
229 # Enter SqlQueryContext in case we need to use a temporary
230 # table to include the given data IDs in the query. Note that
231 # by doing this inside the transaction, we make sure it doesn't
232 # attempt to close the session when it's done, since it just
233 # sees an already-open session that it knows it shouldn't
234 # manage.
235 with context:
236 # Run the check SELECT query.
237 conflicting = context.count(context.process(relation))
238 if conflicting > 0:
239 raise ConflictingDefinitionError(
240 f"{conflicting} validity range conflicts certifying datasets of type "
241 f"{self.datasetType.name} into {collection.name} for range "
242 f"[{timespan.begin}, {timespan.end})."
243 )
244 # Proceed with the insert.
245 self._db.insert(self._calibs, *rows)
247 def decertify(
248 self,
249 collection: CollectionRecord,
250 timespan: Timespan,
251 *,
252 dataIds: Iterable[DataCoordinate] | None = None,
253 context: SqlQueryContext,
254 ) -> None:
255 # Docstring inherited from DatasetRecordStorage.
256 if self._calibs is None: 256 ↛ 257 (line 256 didn't jump to line 257 because the condition was never true)
257 raise CollectionTypeError(
258 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
259 "DatasetType.isCalibration() is False."
260 )
261 if collection.type is not CollectionType.CALIBRATION: 261 ↛ 262 (line 261 didn't jump to line 262 because the condition was never true)
262 raise CollectionTypeError(
263 f"Cannot decertify from collection '{collection.name}' "
264 f"of type {collection.type.name}; must be CALIBRATION."
265 )
266 TimespanReprClass = self._db.getTimespanRepresentation()
267 # Construct a SELECT query to find all rows that overlap our inputs.
268 dataIdSet: set[DataCoordinate] | None
269 if dataIds is not None:
270 dataIdSet = set(dataIds)
271 else:
272 dataIdSet = None
273 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context)
274 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
275 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
276 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
277 data_id_tags = [
278 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
279 ]
280 # Set up collections to populate with the rows we'll want to modify.
281 # The insert rows will have the same values for collection and
282 # dataset type.
283 protoInsertRow = {
284 self._collections.getCollectionForeignKeyName(): collection.key,
285 "dataset_type_id": self._dataset_type_id,
286 }
287 rowsToDelete = []
288 rowsToInsert = []
289 # Acquire a table lock to ensure there are no concurrent writes
290 # between the SELECT and the DELETE and INSERT queries based on it.
291 with self._db.transaction(lock=[self._calibs], savepoint=True):
292 # Enter SqlQueryContext in case we need to use a temporary table to
293 # include the given data IDs in the query (see similar block in
294 # certify for details).
295 with context:
296 for row in context.fetch_iterable(relation):
297 rowsToDelete.append({"id": row[calib_pkey_tag]})
298 # Construct the insert row(s) by copying the prototype row,
299 # then adding the dimension column values, then adding
300 # what's left of the timespan from that row after we
301 # subtract the given timespan.
302 newInsertRow = protoInsertRow.copy()
303 newInsertRow["dataset_id"] = row[dataset_id_tag]
304 for name, tag in data_id_tags:
305 newInsertRow[name] = row[tag]
306 rowTimespan = row[timespan_tag]
307 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
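# For illustration: if the existing row covers [t0, t4) and we decertify
# [t1, t3), ``difference`` yields the leftover pieces [t0, t1) and [t3, t4),
# so up to two replacement rows are inserted for that dataset.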
308 for diffTimespan in rowTimespan.difference(timespan):
309 rowsToInsert.append(
310 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())
311 )
312 # Run the DELETE and INSERT queries.
313 self._db.delete(self._calibs, ["id"], *rowsToDelete)
314 self._db.insert(self._calibs, *rowsToInsert)
316 def make_relation(
317 self,
318 *collections: CollectionRecord,
319 columns: Set[str],
320 context: SqlQueryContext,
321 ) -> Relation:
322 # Docstring inherited from DatasetRecordStorage.
323 collection_types = {collection.type for collection in collections}
324 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
325 TimespanReprClass = self._db.getTimespanRepresentation()
326 #
327 # There are two kinds of table in play here:
328 #
329 # - the static dataset table (with the dataset ID, dataset type ID,
330 # run ID/name, and ingest date);
331 #
332 # - the dynamic tags/calibs table (with the dataset ID, dataset
333 # type ID, collection ID/name, data ID, and possibly validity
334 # range).
335 #
336 # That means that we might want to return a query against either table
337 # or a JOIN of both, depending on which quantities the caller wants.
338 # But the data ID is always included, which means we'll always include
339 # the tags/calibs table and join in the static dataset table only if we
340 # need things from it that we can't get from the tags/calibs table.
341 #
342 # Note that it's important that we include a WHERE constraint on both
343 # tables for any column (e.g. dataset_type_id) that is in both when
344 # it's given explicitly; not doing so can prevent the query planner from
345 # using very important indexes. At present, we don't include those
346 # redundant columns in the JOIN ON expression, however, because the
347 # FOREIGN KEY (and its index) are defined only on dataset_id.
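# Rough shape of the SQL this can produce for a single TAGGED collection when
# the run and data ID are requested (schematic names, not the exact schema):
#     SELECT t.dataset_id, t.<data ID columns>, d.<run key>
#     FROM <dataset_type>_tags AS t
#     JOIN dataset AS d ON d.id = t.dataset_id
#     WHERE t.dataset_type_id = :dtid
#       AND t.<collection FK> = :collection_key
#       AND d.dataset_type_id = :dtid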
348 tag_relation: Relation | None = None
349 calib_relation: Relation | None = None
350 if collection_types != {CollectionType.CALIBRATION}:
351 # We'll need a subquery for the tags table if any of the given
352 # collections are not a CALIBRATION collection. This intentionally
353 # also fires when the list of collections is empty as a way to
354 # create a dummy subquery that we know will fail.
355 # We give the table an alias because it might appear multiple times
356 # in the same query, for different dataset types.
357 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags"))
358 if "timespan" in columns:
359 tags_parts.columns_available[
360 DatasetColumnTag(self.datasetType.name, "timespan")
361 ] = TimespanReprClass.fromLiteral(Timespan(None, None))
362 tag_relation = self._finish_single_relation(
363 tags_parts,
364 columns,
365 [
366 (record, rank)
367 for rank, record in enumerate(collections)
368 if record.type is not CollectionType.CALIBRATION
369 ],
370 context,
371 )
372 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries."
373 if CollectionType.CALIBRATION in collection_types:
374 # If at least one collection is a CALIBRATION collection, we'll
375 # need a subquery for the calibs table, and could include the
376 # timespan as a result or constraint.
377 assert (
378 self._calibs is not None
379 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
380 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs"))
381 if "timespan" in columns:
382 calibs_parts.columns_available[
383 DatasetColumnTag(self.datasetType.name, "timespan")
384 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns)
385 if "calib_pkey" in columns:
386 # This is a private extension not included in the base class
387 # interface, for internal use only in _buildCalibOverlapQuery,
388 # which needs access to the autoincrement primary key for the
389 # calib association table.
390 calibs_parts.columns_available[
391 DatasetColumnTag(self.datasetType.name, "calib_pkey")
392 ] = calibs_parts.from_clause.columns.id
393 calib_relation = self._finish_single_relation(
394 calibs_parts,
395 columns,
396 [
397 (record, rank)
398 for rank, record in enumerate(collections)
399 if record.type is CollectionType.CALIBRATION
400 ],
401 context,
402 )
403 if tag_relation is not None:
404 if calib_relation is not None:
405 # daf_relation's chain operation does not automatically
406 # deduplicate; it's more like SQL's UNION ALL. To get UNION
407 # in SQL here, we add an explicit deduplication.
408 return tag_relation.chain(calib_relation).without_duplicates()
409 else:
410 return tag_relation
411 elif calib_relation is not None:
412 return calib_relation
413 else:
414 raise AssertionError("Branch should be unreachable.")
416 def _finish_single_relation(
417 self,
418 payload: sql.Payload[LogicalColumn],
419 requested_columns: Set[str],
420 collections: Sequence[tuple[CollectionRecord, int]],
421 context: SqlQueryContext,
422 ) -> Relation:
423 """Handle adding columns and WHERE terms that are not specific to
424 either the tags or calibs tables.
426 Helper method for `make_relation`.
428 Parameters
429 ----------
430 payload : `lsst.daf.relation.sql.Payload`
431 SQL query parts under construction, to be modified in-place and
432 used to construct the new relation.
433 requested_columns : `~collections.abc.Set` [ `str` ]
434 Columns the relation should include.
435 collections : `~collections.abc.Sequence` [ `tuple` \
436 [ `CollectionRecord`, `int` ] ]
437 Collections to search for the dataset and their ranks.
438 context : `SqlQueryContext`
439 Context that manages engines and state for the query.
441 Returns
442 -------
443 relation : `lsst.daf.relation.Relation`
444 New dataset query relation.
445 """
446 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id)
447 dataset_id_col = payload.from_clause.columns.dataset_id
448 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()]
449 # We always constrain and optionally retrieve the collection(s) via the
450 # tags/calibs table.
451 if len(collections) == 1:
452 payload.where.append(collection_col == collections[0][0].key)
453 if "collection" in requested_columns:
454 payload.columns_available[
455 DatasetColumnTag(self.datasetType.name, "collection")
456 ] = sqlalchemy.sql.literal(collections[0][0].key)
457 else:
458 assert collections, "The no-collections case should be handled by calling code for better diagnostics."
459 payload.where.append(collection_col.in_([collection.key for collection, _ in collections]))
460 if "collection" in requested_columns:
461 payload.columns_available[
462 DatasetColumnTag(self.datasetType.name, "collection")
463 ] = collection_col
464 # Add rank, if requested, as a CASE-based calculation on the collection
465 # column.
466 if "rank" in requested_columns:
467 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case(
468 {record.key: rank for record, rank in collections},
469 value=collection_col,
470 )
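# (Schematically, for two collections this renders as something like
#  CASE <collection FK> WHEN :key_a THEN 0 WHEN :key_b THEN 1 END.)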
471 # Add more column definitions, starting with the data ID.
472 for dimension_name in self.datasetType.dimensions.required.names:
473 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
474 dimension_name
475 ]
476 # We can always get the dataset_id from the tags/calibs table.
477 if "dataset_id" in requested_columns:
478 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col
479 # It's possible we now have everything we need, from just the
480 # tags/calibs table. The things we might need to get from the static
481 # dataset table are the run key and the ingest date.
482 need_static_table = False
483 if "run" in requested_columns:
484 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN:
485 # If we are searching exactly one RUN collection, we
486 # know that if we find the dataset in that collection,
487 # then that's the dataset's run; we don't need to
488 # query for it.
489 payload.columns_available[
490 DatasetColumnTag(self.datasetType.name, "run")
491 ] = sqlalchemy.sql.literal(collections[0][0].key)
492 else:
493 payload.columns_available[
494 DatasetColumnTag(self.datasetType.name, "run")
495 ] = self._static.dataset.columns[self._runKeyColumn]
496 need_static_table = True
497 # Ingest date can only come from the static table.
498 if "ingest_date" in requested_columns:
499 need_static_table = True
500 payload.columns_available[
501 DatasetColumnTag(self.datasetType.name, "ingest_date")
502 ] = self._static.dataset.columns.ingest_date
503 # If we need the static table, join it in via dataset_id and
504 # dataset_type_id.
505 if need_static_table:
506 payload.from_clause = payload.from_clause.join(
507 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id)
508 )
509 # Also constrain dataset_type_id in static table in case that helps
510 # generate a better plan.
511 # We could also include this in the JOIN ON clause, but my guess is
512 # that that's a good idea IFF it's in the foreign key, and right
513 # now it isn't.
514 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
515 leaf = context.sql_engine.make_leaf(
516 payload.columns_available.keys(),
517 payload=payload,
518 name=self.datasetType.name,
519 parameters={record.name: rank for record, rank in collections},
520 )
521 return leaf
523 def getDataId(self, id: DatasetId) -> DataCoordinate:
524 """Return DataId for a dataset.
526 Parameters
527 ----------
528 id : `DatasetId`
529 Unique dataset identifier.
531 Returns
532 -------
533 dataId : `DataCoordinate`
534 DataId for the dataset.
535 """
536 # This query could return multiple rows (one for each tagged collection
537 # the dataset is in, plus one for its run collection), and we don't
538 # care which of those we get.
539 sql = (
540 self._tags.select()
541 .where(
542 sqlalchemy.sql.and_(
543 self._tags.columns.dataset_id == id,
544 self._tags.columns.dataset_type_id == self._dataset_type_id,
545 )
546 )
547 .limit(1)
548 )
549 with self._db.query(sql) as sql_result:
550 row = sql_result.mappings().fetchone()
551 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
552 return DataCoordinate.standardize(
553 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
554 graph=self.datasetType.dimensions,
555 )
558class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
559 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
560 dataset IDs.
561 """
563 idMaker = DatasetIdFactory()
564 """Factory for dataset IDs. In the future this factory may be shared with
565 other classes (e.g. Registry)."""
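# Sketch of the ID-generation behavior assumed here (inferred from how the
# factory is called below, not restated from its implementation):
# ``DatasetIdGenEnum.UNIQUE`` produces a random UUID, while the deterministic
# modes derive a reproducible UUID from the run, dataset type, and data ID.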
567 def insert(
568 self,
569 run: RunRecord,
570 dataIds: Iterable[DataCoordinate],
571 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
572 ) -> Iterator[DatasetRef]:
573 # Docstring inherited from DatasetRecordStorage.
575 # Current timestamp, type depends on schema version. Use microsecond
576 # precision for astropy time to keep things consistent with
577 # TIMESTAMP(6) SQL type.
578 timestamp: datetime | astropy.time.Time
579 if self._use_astropy:
580 # Astropy `now()` precision should be the same as `utcnow()` which
581 # should mean microsecond.
582 timestamp = astropy.time.Time.now()
583 else:
584 timestamp = datetime.utcnow()
586 # Iterate over data IDs, transforming a possibly-single-pass iterable
587 # into a list.
588 dataIdList = []
589 rows = []
590 summary = CollectionSummary()
591 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
592 dataIdList.append(dataId)
593 rows.append(
594 {
595 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
596 "dataset_type_id": self._dataset_type_id,
597 self._runKeyColumn: run.key,
598 "ingest_date": timestamp,
599 }
600 )
602 with self._db.transaction():
603 # Insert into the static dataset table.
604 self._db.insert(self._static.dataset, *rows)
605 # Update the summary tables for this collection in case this is the
606 # first time this dataset type or these governor values will be
607 # inserted there.
608 self._summaries.update(run, [self._dataset_type_id], summary)
609 # Combine the generated dataset_id values and data ID fields to
610 # form rows to be inserted into the tags table.
611 protoTagsRow = {
612 "dataset_type_id": self._dataset_type_id,
613 self._collections.getCollectionForeignKeyName(): run.key,
614 }
615 tagsRows = [
616 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
617 for dataId, row in zip(dataIdList, rows, strict=True)
618 ]
619 # Insert those rows into the tags table.
620 self._db.insert(self._tags, *tagsRows)
622 for dataId, row in zip(dataIdList, rows, strict=True):
623 yield DatasetRef(
624 datasetType=self.datasetType,
625 dataId=dataId,
626 id=row["id"],
627 run=run.name,
628 )
630 def import_(
631 self,
632 run: RunRecord,
633 datasets: Iterable[DatasetRef],
634 ) -> Iterator[DatasetRef]:
635 # Docstring inherited from DatasetRecordStorage.
637 # Current timestamp, type depends on schema version.
638 if self._use_astropy:
639 # Astropy `now()` precision should be the same as `utcnow()` which
640 # should mean microsecond.
641 timestamp = sqlalchemy.sql.literal(astropy.time.Time.now(), type_=ddl.AstropyTimeNsecTai)
642 else:
643 timestamp = sqlalchemy.sql.literal(datetime.utcnow())
645 # Iterate over the datasets, transforming a possibly-single-pass iterable
646 # into a mapping from dataset ID to data ID.
647 dataIds = {}
648 summary = CollectionSummary()
649 for dataset in summary.add_datasets_generator(datasets):
650 dataIds[dataset.id] = dataset.dataId
652 # We'll insert all new rows into a temporary table
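# (Overall flow, mirroring the steps below: load everything into the
#  temporary table, validate it against existing rows, delete the dataset IDs
#  that are already present in the static dataset table, then INSERT ... SELECT
#  the remaining rows into the dataset table and the tags table.)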
653 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False)
654 collFkName = self._collections.getCollectionForeignKeyName()
655 protoTagsRow = {
656 "dataset_type_id": self._dataset_type_id,
657 collFkName: run.key,
658 }
659 tmpRows = [
660 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
661 for dataset_id, dataId in dataIds.items()
662 ]
663 with self._db.transaction(for_temp_tables=True), self._db.temporary_table(tableSpec) as tmp_tags:
664 # store all incoming data in a temporary table
665 self._db.insert(tmp_tags, *tmpRows)
667 # There are some checks that we want to make for consistency
668 # of the new datasets with existing ones.
669 self._validateImport(tmp_tags, run)
671 # Before we merge the temporary table into dataset/tags we need to
672 # drop datasets which are already there (and do not conflict).
673 self._db.deleteWhere(
674 tmp_tags,
675 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
676 )
678 # Copy it into the dataset table; we need to re-label some columns.
679 self._db.insert(
680 self._static.dataset,
681 select=sqlalchemy.sql.select(
682 tmp_tags.columns.dataset_id.label("id"),
683 tmp_tags.columns.dataset_type_id,
684 tmp_tags.columns[collFkName].label(self._runKeyColumn),
685 timestamp.label("ingest_date"),
686 ),
687 )
689 # Update the summary tables for this collection in case this
690 # is the first time this dataset type or these governor values
691 # will be inserted there.
692 self._summaries.update(run, [self._dataset_type_id], summary)
694 # Copy it into tags table.
695 self._db.insert(self._tags, select=tmp_tags.select())
697 # Return refs in the same order as in the input list.
698 for dataset_id, dataId in dataIds.items():
699 yield DatasetRef(
700 datasetType=self.datasetType,
701 id=dataset_id,
702 dataId=dataId,
703 run=run.name,
704 )
706 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
707 """Validate imported refs against existing datasets.
709 Parameters
710 ----------
711 tmp_tags : `sqlalchemy.schema.Table`
712 Temporary table with new datasets and the same schema as tags
713 table.
714 run : `RunRecord`
715 The record object describing the `~CollectionType.RUN` collection.
717 Raises
718 ------
719 ConflictingDefinitionError
720 Raised if new datasets conflict with existing ones.
721 """
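# Three consistency checks follow, each raising on the first offending row:
#   1. datasets already in the static dataset table must have the same
#      dataset type and run as the incoming ones;
#   2. datasets already in the tags table must have the same dataset type
#      and data ID;
#   3. an existing tags row with the same run and data ID must refer to the
#      same dataset ID.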
722 dataset = self._static.dataset
723 tags = self._tags
724 collFkName = self._collections.getCollectionForeignKeyName()
726 # Check that existing datasets have the same dataset type and
727 # run.
728 query = (
729 sqlalchemy.sql.select(
730 dataset.columns.id.label("dataset_id"),
731 dataset.columns.dataset_type_id.label("dataset_type_id"),
732 tmp_tags.columns.dataset_type_id.label("new_dataset_type_id"),
733 dataset.columns[self._runKeyColumn].label("run"),
734 tmp_tags.columns[collFkName].label("new_run"),
735 )
736 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
737 .where(
738 sqlalchemy.sql.or_(
739 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
740 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
741 )
742 )
743 .limit(1)
744 )
745 with self._db.query(query) as result:
746 # Only include the first one in the exception message
747 if (row := result.first()) is not None:
748 existing_run = self._collections[row.run].name
749 new_run = self._collections[row.new_run].name
750 if row.dataset_type_id == self._dataset_type_id:
751 if row.new_dataset_type_id == self._dataset_type_id: 751 ↛ 757 (line 751 didn't jump to line 757 because the condition was never false)
752 raise ConflictingDefinitionError(
753 f"Current run {existing_run!r} and new run {new_run!r} do not agree for "
754 f"dataset {row.dataset_id}."
755 )
756 else:
757 raise ConflictingDefinitionError(
758 f"Dataset {row.dataset_id} was provided with type {self.datasetType.name!r} "
759 f"in run {new_run!r}, but was already defined with type ID {row.dataset_type_id} "
760 f"in run {run!r}."
761 )
762 else:
763 raise ConflictingDefinitionError(
764 f"Dataset {row.dataset_id} was provided with type ID {row.new_dataset_type_id} "
765 f"in run {new_run!r}, but was already defined with type {self.datasetType.name!r} "
766 f"in run {run!r}."
767 )
769 # Check that matching dataset in tags table has the same DataId.
770 query = (
771 sqlalchemy.sql.select(
772 tags.columns.dataset_id,
773 tags.columns.dataset_type_id.label("type_id"),
774 tmp_tags.columns.dataset_type_id.label("new_type_id"),
775 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
776 *[
777 tmp_tags.columns[dim].label(f"new_{dim}")
778 for dim in self.datasetType.dimensions.required.names
779 ],
780 )
781 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
782 .where(
783 sqlalchemy.sql.or_(
784 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
785 *[
786 tags.columns[dim] != tmp_tags.columns[dim]
787 for dim in self.datasetType.dimensions.required.names
788 ],
789 )
790 )
791 .limit(1)
792 )
794 with self._db.query(query) as result:
795 if (row := result.first()) is not None:
796 # Only include the first one in the exception message
797 raise ConflictingDefinitionError(
798 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
799 )
801 # Check that matching run+dataId have the same dataset ID.
802 query = (
803 sqlalchemy.sql.select(
804 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
805 tags.columns.dataset_id,
806 tmp_tags.columns.dataset_id.label("new_dataset_id"),
807 tags.columns[collFkName],
808 tmp_tags.columns[collFkName].label(f"new_{collFkName}"),
809 )
810 .select_from(
811 tags.join(
812 tmp_tags,
813 sqlalchemy.sql.and_(
814 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
815 tags.columns[collFkName] == tmp_tags.columns[collFkName],
816 *[
817 tags.columns[dim] == tmp_tags.columns[dim]
818 for dim in self.datasetType.dimensions.required.names
819 ],
820 ),
821 )
822 )
823 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
824 .limit(1)
825 )
826 with self._db.query(query) as result:
827 # only include the first one in the exception message
828 if (row := result.first()) is not None:
829 data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names}
830 existing_collection = self._collections[getattr(row, collFkName)].name
831 new_collection = self._collections[getattr(row, f"new_{collFkName}")].name
832 raise ConflictingDefinitionError(
833 f"Dataset with type {self.datasetType.name!r} and data ID {data_id} "
834 f"has ID {row.dataset_id} in existing collection {existing_collection!r} "
835 f"but ID {row.new_dataset_id} in new collection {new_collection!r}."
836 )