Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%
260 statements
coverage.py v7.3.2, created at 2023-12-01 10:59 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
29from __future__ import annotations
31from .... import ddl
33__all__ = ("ByDimensionsDatasetRecordStorage",)
35import datetime
36from collections.abc import Callable, Iterable, Iterator, Sequence, Set
37from typing import TYPE_CHECKING
39import astropy.time
40import sqlalchemy
41from lsst.daf.relation import Relation, sql
43from ...._column_tags import DatasetColumnTag, DimensionKeyColumnTag
44from ...._column_type_info import LogicalColumn
45from ...._dataset_ref import DatasetId, DatasetIdFactory, DatasetIdGenEnum, DatasetRef
46from ...._dataset_type import DatasetType
47from ...._timespan import Timespan
48from ....dimensions import DataCoordinate
49from ..._collection_summary import CollectionSummary
50from ..._collection_type import CollectionType
51from ..._exceptions import CollectionTypeError, ConflictingDefinitionError
52from ...interfaces import DatasetRecordStorage
53from ...queries import SqlQueryContext
54from .tables import makeTagTableSpec
56if TYPE_CHECKING:
57 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
58 from .summaries import CollectionSummaryManager
59 from .tables import StaticDatasetTablesTuple
62class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
63 """Dataset record storage implementation paired with
64 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more
65 information.
67 Instances of this class should never be constructed directly; use
68 `DatasetRecordStorageManager.register` instead.
69 """
71 def __init__(
72 self,
73 *,
74 datasetType: DatasetType,
75 db: Database,
76 dataset_type_id: int,
77 collections: CollectionManager,
78 static: StaticDatasetTablesTuple,
79 summaries: CollectionSummaryManager,
80 tags_table_factory: Callable[[], sqlalchemy.schema.Table],
81 use_astropy_ingest_date: bool,
82 calibs_table_factory: Callable[[], sqlalchemy.schema.Table] | None,
83 ):
84 super().__init__(datasetType=datasetType)
85 self._dataset_type_id = dataset_type_id
86 self._db = db
87 self._collections = collections
88 self._static = static
89 self._summaries = summaries
90 self._tags_table_factory = tags_table_factory
91 self._calibs_table_factory = calibs_table_factory
92 self._runKeyColumn = collections.getRunForeignKeyName()
93 self._use_astropy = use_astropy_ingest_date
94 self._tags_table: sqlalchemy.schema.Table | None = None
95 self._calibs_table: sqlalchemy.schema.Table | None = None
97 @property
98 def _tags(self) -> sqlalchemy.schema.Table:
99 if self._tags_table is None:
100 self._tags_table = self._tags_table_factory()
101 return self._tags_table
103 @property
104 def _calibs(self) -> sqlalchemy.schema.Table | None:
105 if self._calibs_table is None:
106 if self._calibs_table_factory is None: 106 ↛ 107
107 return None
108 self._calibs_table = self._calibs_table_factory()
109 return self._calibs_table
111 def delete(self, datasets: Iterable[DatasetRef]) -> None:
112 # Docstring inherited from DatasetRecordStorage.
113 # Only delete from common dataset table; ON DELETE foreign key clauses
114 # will handle the rest.
115 self._db.delete(
116 self._static.dataset,
117 ["id"],
118 *[{"id": dataset.id} for dataset in datasets],
119 )
121 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
122 # Docstring inherited from DatasetRecordStorage.
123 if collection.type is not CollectionType.TAGGED: 123 ↛ 124
124 raise TypeError(
125 f"Cannot associate into collection '{collection.name}' "
126 f"of type {collection.type.name}; must be TAGGED."
127 )
128 protoRow = {
129 self._collections.getCollectionForeignKeyName(): collection.key,
130 "dataset_type_id": self._dataset_type_id,
131 }
132 rows = []
133 summary = CollectionSummary()
134 for dataset in summary.add_datasets_generator(datasets):
135 rows.append(dict(protoRow, dataset_id=dataset.id, **dataset.dataId.required))
136 # Update the summary tables for this collection in case this is the
137 # first time this dataset type or these governor values will be
138 # inserted there.
139 self._summaries.update(collection, [self._dataset_type_id], summary)
140 # Update the tag table itself.
141 self._db.replace(self._tags, *rows)
143 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
144 # Docstring inherited from DatasetRecordStorage.
145 if collection.type is not CollectionType.TAGGED: 145 ↛ 146
146 raise TypeError(
147 f"Cannot disassociate from collection '{collection.name}' "
148 f"of type {collection.type.name}; must be TAGGED."
149 )
150 rows = [
151 {
152 "dataset_id": dataset.id,
153 self._collections.getCollectionForeignKeyName(): collection.key,
154 }
155 for dataset in datasets
156 ]
157 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
159 def _buildCalibOverlapQuery(
160 self,
161 collection: CollectionRecord,
162 data_ids: set[DataCoordinate] | None,
163 timespan: Timespan,
164 context: SqlQueryContext,
165 ) -> Relation:
166 relation = self.make_relation(
167 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context
168 ).with_rows_satisfying(
169 context.make_timespan_overlap_predicate(
170 DatasetColumnTag(self.datasetType.name, "timespan"), timespan
171 ),
172 )
173 if data_ids is not None:
174 relation = relation.join(
175 context.make_data_id_relation(
176 data_ids, self.datasetType.dimensions.required.names
177 ).transferred_to(context.sql_engine),
178 )
179 return relation
181 def certify(
182 self,
183 collection: CollectionRecord,
184 datasets: Iterable[DatasetRef],
185 timespan: Timespan,
186 context: SqlQueryContext,
187 ) -> None:
188 # Docstring inherited from DatasetRecordStorage.
189 if self._calibs is None: 189 ↛ 190
190 raise CollectionTypeError(
191 f"Cannot certify datasets of type {self.datasetType.name}, for which "
192 "DatasetType.isCalibration() is False."
193 )
194 if collection.type is not CollectionType.CALIBRATION: 194 ↛ 195
195 raise CollectionTypeError(
196 f"Cannot certify into collection '{collection.name}' "
197 f"of type {collection.type.name}; must be CALIBRATION."
198 )
199 TimespanReprClass = self._db.getTimespanRepresentation()
200 protoRow = {
201 self._collections.getCollectionForeignKeyName(): collection.key,
202 "dataset_type_id": self._dataset_type_id,
203 }
204 rows = []
205 dataIds: set[DataCoordinate] | None = (
206 set() if not TimespanReprClass.hasExclusionConstraint() else None
207 )
208 summary = CollectionSummary()
209 for dataset in summary.add_datasets_generator(datasets):
210 row = dict(protoRow, dataset_id=dataset.id, **dataset.dataId.required)
211 TimespanReprClass.update(timespan, result=row)
212 rows.append(row)
213 if dataIds is not None: 213 ↛ 209
214 dataIds.add(dataset.dataId)
215 # Update the summary tables for this collection in case this is the
216 # first time this dataset type or these governor values will be
217 # inserted there.
218 self._summaries.update(collection, [self._dataset_type_id], summary)
219 # Update the association table itself.
220 if TimespanReprClass.hasExclusionConstraint(): 220 ↛ 223
221 # Rely on the database constraint to enforce invariants; we just
222 # translate the IntegrityError into ConflictingDefinitionError for consistency across DB engines.
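# (A minimal sketch, assuming a PostgreSQL-style backend and a range column
# for the validity range; the names here are illustrative, not the actual
# schema:
#     ALTER TABLE <dataset_type>_calibs
#         ADD EXCLUDE USING gist (collection_id WITH =, timespan WITH &&);
# The real DDL, if any, is determined by the timespan representation and the
# table specifications, not by this method.)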
223 try:
224 self._db.insert(self._calibs, *rows)
225 except sqlalchemy.exc.IntegrityError as err:
226 raise ConflictingDefinitionError(
227 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
228 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
229 ) from err
230 else:
231 # Have to implement exclusion constraint ourselves.
232 # Start by building a SELECT query for any rows that would overlap
233 # this one.
234 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context)
235 # Acquire a table lock to ensure there are no concurrent writes that
236 # could invalidate our check before we finish the inserts. We
237 # use a SAVEPOINT in case there is an outer transaction that a
238 # failure here should not roll back.
239 with self._db.transaction(lock=[self._calibs], savepoint=True):
240 # Enter SqlQueryContext in case we need to use a temporary
241 # table to include the given data IDs in the query. Note that
242 # by doing this inside the transaction, we make sure it doesn't
243 # attempt to close the session when it's done, since it just
244 # sees an already-open session that it knows it shouldn't
245 # manage.
246 with context:
247 # Run the check SELECT query.
248 conflicting = context.count(context.process(relation))
249 if conflicting > 0:
250 raise ConflictingDefinitionError(
251 f"{conflicting} validity range conflicts certifying datasets of type "
252 f"{self.datasetType.name} into {collection.name} for range "
253 f"[{timespan.begin}, {timespan.end})."
254 )
255 # Proceed with the insert.
256 self._db.insert(self._calibs, *rows)
258 def decertify(
259 self,
260 collection: CollectionRecord,
261 timespan: Timespan,
262 *,
263 dataIds: Iterable[DataCoordinate] | None = None,
264 context: SqlQueryContext,
265 ) -> None:
266 # Docstring inherited from DatasetRecordStorage.
267 if self._calibs is None: 267 ↛ 268
268 raise CollectionTypeError(
269 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
270 "DatasetType.isCalibration() is False."
271 )
272 if collection.type is not CollectionType.CALIBRATION: 272 ↛ 273
273 raise CollectionTypeError(
274 f"Cannot decertify from collection '{collection.name}' "
275 f"of type {collection.type.name}; must be CALIBRATION."
276 )
277 TimespanReprClass = self._db.getTimespanRepresentation()
278 # Construct a SELECT query to find all rows that overlap our inputs.
279 dataIdSet: set[DataCoordinate] | None
280 if dataIds is not None:
281 dataIdSet = set(dataIds)
282 else:
283 dataIdSet = None
284 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context)
285 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey")
286 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id")
287 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan")
288 data_id_tags = [
289 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names
290 ]
291 # Set up collections to populate with the rows we'll want to modify.
292 # The insert rows will have the same values for collection and
293 # dataset type.
294 protoInsertRow = {
295 self._collections.getCollectionForeignKeyName(): collection.key,
296 "dataset_type_id": self._dataset_type_id,
297 }
298 rowsToDelete = []
299 rowsToInsert = []
300 # Acquire a table lock to ensure there are no concurrent writes
301 # between the SELECT and the DELETE and INSERT queries based on it.
302 with self._db.transaction(lock=[self._calibs], savepoint=True):
303 # Enter SqlQueryContext in case we need to use a temporary table to
304 # include the given data IDs in the query (see the similar block in
305 # certify for details).
306 with context:
307 for row in context.fetch_iterable(relation):
308 rowsToDelete.append({"id": row[calib_pkey_tag]})
309 # Construct the insert row(s) by copying the prototype row,
310 # then adding the dimension column values, then adding
311 # what's left of the timespan from that row after we
312 # subtract the given timespan.
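# (For example, if a dataset was certified over [t1, t4) and we decertify
# [t2, t3) with t1 < t2 < t3 < t4, the difference yields two remainders,
# so the single deleted row is replaced by two new rows covering
# [t1, t2) and [t3, t4).)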
313 newInsertRow = protoInsertRow.copy()
314 newInsertRow["dataset_id"] = row[dataset_id_tag]
315 for name, tag in data_id_tags:
316 newInsertRow[name] = row[tag]
317 rowTimespan = row[timespan_tag]
318 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
319 for diffTimespan in rowTimespan.difference(timespan):
320 rowsToInsert.append(
321 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())
322 )
323 # Run the DELETE and INSERT queries.
324 self._db.delete(self._calibs, ["id"], *rowsToDelete)
325 self._db.insert(self._calibs, *rowsToInsert)
327 def make_relation(
328 self,
329 *collections: CollectionRecord,
330 columns: Set[str],
331 context: SqlQueryContext,
332 ) -> Relation:
333 # Docstring inherited from DatasetRecordStorage.
334 collection_types = {collection.type for collection in collections}
335 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
336 TimespanReprClass = self._db.getTimespanRepresentation()
337 #
338 # There are two kinds of table in play here:
339 #
340 # - the static dataset table (with the dataset ID, dataset type ID,
341 # run ID/name, and ingest date);
342 #
343 # - the dynamic tags/calibs table (with the dataset ID, dataset
344 # type ID, collection ID/name, data ID, and possibly validity
345 # range).
346 #
347 # That means that we might want to return a query against either table
348 # or a JOIN of both, depending on which quantities the caller wants.
349 # But the data ID is always included, which means we'll always include
350 # the tags/calibs table and join in the static dataset table only if we
351 # need things from it that we can't get from the tags/calibs table.
352 #
353 # Note that it's important that we include a WHERE constraint on both
354 # tables for any column (e.g. dataset_type_id) that is in both when
355 # it's given explicitly; not doing so can prevent the query planner from
356 # using very important indexes. At present, we don't include those
357 # redundant columns in the JOIN ON expression, however, because the
358 # FOREIGN KEY (and its index) are defined only on dataset_id.
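# As a rough illustration only (the real SQL is produced by the relation
# machinery and depends on the requested columns and collections), a
# single-collection, non-calibration query has the shape:
#
#     SELECT <required data ID columns>, dataset_id, ...
#     FROM <dataset_type>_tags
#     [JOIN dataset ON dataset.id = <dataset_type>_tags.dataset_id]
#     WHERE <dataset_type>_tags.dataset_type_id = :dataset_type_id
#       AND <collection FK column> = :collection_key
#       [AND dataset.dataset_type_id = :dataset_type_id]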
359 tag_relation: Relation | None = None
360 calib_relation: Relation | None = None
361 if collection_types != {CollectionType.CALIBRATION}:
362 # We'll need a subquery for the tags table if any of the given
363 # collections are not a CALIBRATION collection. This intentionally
364 # also fires when the list of collections is empty as a way to
365 # create a dummy subquery that we know will fail.
366 # We give the table an alias because it might appear multiple times
367 # in the same query, for different dataset types.
368 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags"))
369 if "timespan" in columns:
370 tags_parts.columns_available[
371 DatasetColumnTag(self.datasetType.name, "timespan")
372 ] = TimespanReprClass.fromLiteral(Timespan(None, None))
373 tag_relation = self._finish_single_relation(
374 tags_parts,
375 columns,
376 [
377 (record, rank)
378 for rank, record in enumerate(collections)
379 if record.type is not CollectionType.CALIBRATION
380 ],
381 context,
382 )
383 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries."
384 if CollectionType.CALIBRATION in collection_types:
385 # If at least one collection is a CALIBRATION collection, we'll
386 # need a subquery for the calibs table, and could include the
387 # timespan as a result or constraint.
388 assert (
389 self._calibs is not None
390 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
391 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs"))
392 if "timespan" in columns:
393 calibs_parts.columns_available[
394 DatasetColumnTag(self.datasetType.name, "timespan")
395 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns)
396 if "calib_pkey" in columns:
397 # This is a private extension not included in the base class
398 # interface, for internal use only in _buildCalibOverlapQuery,
399 # which needs access to the autoincrement primary key for the
400 # calib association table.
401 calibs_parts.columns_available[
402 DatasetColumnTag(self.datasetType.name, "calib_pkey")
403 ] = calibs_parts.from_clause.columns.id
404 calib_relation = self._finish_single_relation(
405 calibs_parts,
406 columns,
407 [
408 (record, rank)
409 for rank, record in enumerate(collections)
410 if record.type is CollectionType.CALIBRATION
411 ],
412 context,
413 )
414 if tag_relation is not None:
415 if calib_relation is not None:
416 # daf_relation's chain operation does not automatically
417 # deduplicate; it's more like SQL's UNION ALL. To get UNION
418 # in SQL here, we add an explicit deduplication.
419 return tag_relation.chain(calib_relation).without_duplicates()
420 else:
421 return tag_relation
422 elif calib_relation is not None:
423 return calib_relation
424 else:
425 raise AssertionError("Branch should be unreachable.")
427 def _finish_single_relation(
428 self,
429 payload: sql.Payload[LogicalColumn],
430 requested_columns: Set[str],
431 collections: Sequence[tuple[CollectionRecord, int]],
432 context: SqlQueryContext,
433 ) -> Relation:
434 """Handle adding columns and WHERE terms that are not specific to
435 either the tags or calibs tables.
437 Helper method for `make_relation`.
439 Parameters
440 ----------
441 payload : `lsst.daf.relation.sql.Payload`
442 SQL query parts under construction, to be modified in-place and
443 used to construct the new relation.
444 requested_columns : `~collections.abc.Set` [ `str` ]
445 Columns the relation should include.
446 collections : `~collections.abc.Sequence` [ `tuple` \
447 [ `CollectionRecord`, `int` ] ]
448 Collections to search for the dataset and their ranks.
449 context : `SqlQueryContext`
450 Context that manages engines and state for the query.
452 Returns
453 -------
454 relation : `lsst.daf.relation.Relation`
455 New dataset query relation.
456 """
457 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id)
458 dataset_id_col = payload.from_clause.columns.dataset_id
459 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()]
460 # We always constrain and optionally retrieve the collection(s) via the
461 # tags/calibs table.
462 if len(collections) == 1:
463 payload.where.append(collection_col == collections[0][0].key)
464 if "collection" in requested_columns:
465 payload.columns_available[
466 DatasetColumnTag(self.datasetType.name, "collection")
467 ] = sqlalchemy.sql.literal(collections[0][0].key)
468 else:
469 assert collections, "The no-collections case should be handled by calling code for better diagnostics."
470 payload.where.append(collection_col.in_([collection.key for collection, _ in collections]))
471 if "collection" in requested_columns:
472 payload.columns_available[
473 DatasetColumnTag(self.datasetType.name, "collection")
474 ] = collection_col
475 # Add rank if requested as a CASE-based calculation on the collection
476 # column.
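# (Illustrative only: for two collections this is roughly
# ``CASE <collection FK column> WHEN :key0 THEN 0 WHEN :key1 THEN 1 END``,
# which is the form ``sqlalchemy.sql.case`` emits when given ``value=``.)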
477 if "rank" in requested_columns:
478 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case(
479 {record.key: rank for record, rank in collections},
480 value=collection_col,
481 )
482 # Add more column definitions, starting with the data ID.
483 for dimension_name in self.datasetType.dimensions.required.names:
484 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[
485 dimension_name
486 ]
487 # We can always get the dataset_id from the tags/calibs table.
488 if "dataset_id" in requested_columns:
489 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col
490 # It's possible we now have everything we need, from just the
491 # tags/calibs table. The things we might need to get from the static
492 # dataset table are the run key and the ingest date.
493 need_static_table = False
494 if "run" in requested_columns:
495 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN:
496 # If we are searching exactly one RUN collection, we
497 # know that if we find the dataset in that collection,
498 # then that's the dataset's run; we don't need to
499 # query for it.
500 payload.columns_available[
501 DatasetColumnTag(self.datasetType.name, "run")
502 ] = sqlalchemy.sql.literal(collections[0][0].key)
503 else:
504 payload.columns_available[
505 DatasetColumnTag(self.datasetType.name, "run")
506 ] = self._static.dataset.columns[self._runKeyColumn]
507 need_static_table = True
508 # Ingest date can only come from the static table.
509 if "ingest_date" in requested_columns:
510 need_static_table = True
511 payload.columns_available[
512 DatasetColumnTag(self.datasetType.name, "ingest_date")
513 ] = self._static.dataset.columns.ingest_date
514 # If we need the static table, join it in via dataset_id and
515 # dataset_type_id
516 if need_static_table:
517 payload.from_clause = payload.from_clause.join(
518 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id)
519 )
520 # Also constrain dataset_type_id in static table in case that helps
521 # generate a better plan.
522 # We could also include this in the JOIN ON clause, but my guess is
523 # that that's a good idea IFF it's in the foreign key, and right
524 # now it isn't.
525 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
526 leaf = context.sql_engine.make_leaf(
527 payload.columns_available.keys(),
528 payload=payload,
529 name=self.datasetType.name,
530 parameters={record.name: rank for record, rank in collections},
531 )
532 return leaf
534 def getDataId(self, id: DatasetId) -> DataCoordinate:
535 """Return DataId for a dataset.
537 Parameters
538 ----------
539 id : `DatasetId`
540 Unique dataset identifier.
542 Returns
543 -------
544 dataId : `DataCoordinate`
545 DataId for the dataset.
546 """
547 # This query could return multiple rows (one for each tagged collection
548 # the dataset is in, plus one for its run collection), and we don't
549 # care which of those we get.
550 sql = (
551 self._tags.select()
552 .where(
553 sqlalchemy.sql.and_(
554 self._tags.columns.dataset_id == id,
555 self._tags.columns.dataset_type_id == self._dataset_type_id,
556 )
557 )
558 .limit(1)
559 )
560 with self._db.query(sql) as sql_result:
561 row = sql_result.mappings().fetchone()
562 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
563 return DataCoordinate.from_required_values(
564 self.datasetType.dimensions.as_group(),
565 tuple(row[dimension] for dimension in self.datasetType.dimensions.required.names),
566 )
569class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
570 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
571 dataset IDs.
572 """
574 idMaker = DatasetIdFactory()
575 """Factory for dataset IDs. In the future this factory may be shared with
576 other classes (e.g. Registry)."""
578 def insert(
579 self,
580 run: RunRecord,
581 dataIds: Iterable[DataCoordinate],
582 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
583 ) -> Iterator[DatasetRef]:
584 # Docstring inherited from DatasetRecordStorage.
586 # Current timestamp; its type depends on the schema version. Use microsecond
587 # precision for astropy time to keep things consistent with
588 # TIMESTAMP(6) SQL type.
589 timestamp: datetime.datetime | astropy.time.Time
590 if self._use_astropy:
591 # Astropy `Time.now()` precision should match `datetime.now()`, which
592 # should mean microsecond.
593 timestamp = astropy.time.Time.now()
594 else:
595 timestamp = datetime.datetime.now(datetime.UTC)
597 # Iterate over data IDs, transforming a possibly-single-pass iterable
598 # into a list.
599 dataIdList: list[DataCoordinate] = []
600 rows = []
601 summary = CollectionSummary()
602 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
603 dataIdList.append(dataId)
604 rows.append(
605 {
606 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
607 "dataset_type_id": self._dataset_type_id,
608 self._runKeyColumn: run.key,
609 "ingest_date": timestamp,
610 }
611 )
613 with self._db.transaction():
614 # Insert into the static dataset table.
615 self._db.insert(self._static.dataset, *rows)
616 # Update the summary tables for this collection in case this is the
617 # first time this dataset type or these governor values will be
618 # inserted there.
619 self._summaries.update(run, [self._dataset_type_id], summary)
620 # Combine the generated dataset_id values and data ID fields to
621 # form rows to be inserted into the tags table.
622 protoTagsRow = {
623 "dataset_type_id": self._dataset_type_id,
624 self._collections.getCollectionForeignKeyName(): run.key,
625 }
626 tagsRows = [
627 dict(protoTagsRow, dataset_id=row["id"], **dataId.required)
628 for dataId, row in zip(dataIdList, rows, strict=True)
629 ]
630 # Insert those rows into the tags table.
631 self._db.insert(self._tags, *tagsRows)
633 for dataId, row in zip(dataIdList, rows, strict=True):
634 yield DatasetRef(
635 datasetType=self.datasetType,
636 dataId=dataId,
637 id=row["id"],
638 run=run.name,
639 )
641 def import_(
642 self,
643 run: RunRecord,
644 datasets: Iterable[DatasetRef],
645 ) -> Iterator[DatasetRef]:
646 # Docstring inherited from DatasetRecordStorage.
648 # Current timestamp; its type depends on the schema version.
649 if self._use_astropy:
650 # Astropy `Time.now()` precision should match `datetime.now()`, which
651 # should mean microsecond.
652 timestamp = sqlalchemy.sql.literal(astropy.time.Time.now(), type_=ddl.AstropyTimeNsecTai)
653 else:
654 timestamp = sqlalchemy.sql.literal(datetime.datetime.now(datetime.UTC))
656 # Iterate over data IDs, transforming a possibly-single-pass iterable
657 # into a dict keyed by dataset ID.
658 dataIds: dict[DatasetId, DataCoordinate] = {}
659 summary = CollectionSummary()
660 for dataset in summary.add_datasets_generator(datasets):
661 dataIds[dataset.id] = dataset.dataId
663 # We'll insert all new rows into a temporary table
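# The temporary rows are then merged into the real tables inside a single
# transaction: validate them against existing datasets, drop rows whose
# dataset IDs are already present, copy what is left into the static
# dataset table, update the collection summaries, and finally copy the
# same remaining rows into the tags table.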
664 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False)
665 collFkName = self._collections.getCollectionForeignKeyName()
666 protoTagsRow = {
667 "dataset_type_id": self._dataset_type_id,
668 collFkName: run.key,
669 }
670 tmpRows = [
671 dict(protoTagsRow, dataset_id=dataset_id, **dataId.required)
672 for dataset_id, dataId in dataIds.items()
673 ]
674 with self._db.transaction(for_temp_tables=True), self._db.temporary_table(tableSpec) as tmp_tags:
675 # store all incoming data in a temporary table
676 self._db.insert(tmp_tags, *tmpRows)
678 # There are some checks that we want to make for consistency
679 # of the new datasets with existing ones.
680 self._validateImport(tmp_tags, run)
682 # Before we merge temporary table into dataset/tags we need to
683 # drop datasets which are already there (and do not conflict).
684 self._db.deleteWhere(
685 tmp_tags,
686 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
687 )
689 # Copy it into the dataset table; we need to re-label some columns.
690 self._db.insert(
691 self._static.dataset,
692 select=sqlalchemy.sql.select(
693 tmp_tags.columns.dataset_id.label("id"),
694 tmp_tags.columns.dataset_type_id,
695 tmp_tags.columns[collFkName].label(self._runKeyColumn),
696 timestamp.label("ingest_date"),
697 ),
698 )
700 # Update the summary tables for this collection in case this
701 # is the first time this dataset type or these governor values
702 # will be inserted there.
703 self._summaries.update(run, [self._dataset_type_id], summary)
705 # Copy it into tags table.
706 self._db.insert(self._tags, select=tmp_tags.select())
708 # Return refs in the same order as in the input list.
709 for dataset_id, dataId in dataIds.items():
710 yield DatasetRef(
711 datasetType=self.datasetType,
712 id=dataset_id,
713 dataId=dataId,
714 run=run.name,
715 )
717 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
718 """Validate imported refs against existing datasets.
720 Parameters
721 ----------
722 tmp_tags : `sqlalchemy.schema.Table`
723 Temporary table with new datasets and the same schema as tags
724 table.
725 run : `RunRecord`
726 The record object describing the `~CollectionType.RUN` collection.
728 Raises
729 ------
730 ConflictingDefinitionError
731 Raised if new datasets conflict with existing ones.
732 """
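# Three consistency checks follow, each a SELECT ... LIMIT 1 over a join of
# the temporary table with an existing table:
# 1. a dataset ID already present in the dataset table must have the same
#    dataset type and run as the incoming row;
# 2. a dataset ID already present in the tags table must have the same
#    dataset type and data ID;
# 3. an existing tags row with the same dataset type, collection, and
#    data ID must have the same dataset ID.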
733 dataset = self._static.dataset
734 tags = self._tags
735 collFkName = self._collections.getCollectionForeignKeyName()
737 # Check that existing datasets have the same dataset type and
738 # run.
739 query = (
740 sqlalchemy.sql.select(
741 dataset.columns.id.label("dataset_id"),
742 dataset.columns.dataset_type_id.label("dataset_type_id"),
743 tmp_tags.columns.dataset_type_id.label("new_dataset_type_id"),
744 dataset.columns[self._runKeyColumn].label("run"),
745 tmp_tags.columns[collFkName].label("new_run"),
746 )
747 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
748 .where(
749 sqlalchemy.sql.or_(
750 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
751 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
752 )
753 )
754 .limit(1)
755 )
756 with self._db.query(query) as result:
757 # Only include the first one in the exception message
758 if (row := result.first()) is not None:
759 existing_run = self._collections[row.run].name
760 new_run = self._collections[row.new_run].name
761 if row.dataset_type_id == self._dataset_type_id:
762 if row.new_dataset_type_id == self._dataset_type_id: 762 ↛ 768
763 raise ConflictingDefinitionError(
764 f"Current run {existing_run!r} and new run {new_run!r} do not agree for "
765 f"dataset {row.dataset_id}."
766 )
767 else:
768 raise ConflictingDefinitionError(
769 f"Dataset {row.dataset_id} was provided with type {self.datasetType.name!r} "
770 f"in run {new_run!r}, but was already defined with type ID {row.dataset_type_id} "
771 f"in run {run!r}."
772 )
773 else:
774 raise ConflictingDefinitionError(
775 f"Dataset {row.dataset_id} was provided with type ID {row.new_dataset_type_id} "
776 f"in run {new_run!r}, but was already defined with type {self.datasetType.name!r} "
777 f"in run {run!r}."
778 )
780 # Check that matching dataset in tags table has the same DataId.
781 query = (
782 sqlalchemy.sql.select(
783 tags.columns.dataset_id,
784 tags.columns.dataset_type_id.label("type_id"),
785 tmp_tags.columns.dataset_type_id.label("new_type_id"),
786 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
787 *[
788 tmp_tags.columns[dim].label(f"new_{dim}")
789 for dim in self.datasetType.dimensions.required.names
790 ],
791 )
792 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
793 .where(
794 sqlalchemy.sql.or_(
795 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
796 *[
797 tags.columns[dim] != tmp_tags.columns[dim]
798 for dim in self.datasetType.dimensions.required.names
799 ],
800 )
801 )
802 .limit(1)
803 )
805 with self._db.query(query) as result:
806 if (row := result.first()) is not None:
807 # Only include the first one in the exception message
808 raise ConflictingDefinitionError(
809 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
810 )
812 # Check that matching run+dataId have the same dataset ID.
813 query = (
814 sqlalchemy.sql.select(
815 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
816 tags.columns.dataset_id,
817 tmp_tags.columns.dataset_id.label("new_dataset_id"),
818 tags.columns[collFkName],
819 tmp_tags.columns[collFkName].label(f"new_{collFkName}"),
820 )
821 .select_from(
822 tags.join(
823 tmp_tags,
824 sqlalchemy.sql.and_(
825 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
826 tags.columns[collFkName] == tmp_tags.columns[collFkName],
827 *[
828 tags.columns[dim] == tmp_tags.columns[dim]
829 for dim in self.datasetType.dimensions.required.names
830 ],
831 ),
832 )
833 )
834 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
835 .limit(1)
836 )
837 with self._db.query(query) as result:
838 # only include the first one in the exception message
839 if (row := result.first()) is not None:
840 data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names}
841 existing_collection = self._collections[getattr(row, collFkName)].name
842 new_collection = self._collections[getattr(row, f"new_{collFkName}")].name
843 raise ConflictingDefinitionError(
844 f"Dataset with type {self.datasetType.name!r} and data ID {data_id} "
845 f"has ID {row.dataset_id} in existing collection {existing_collection!r} "
846 f"but ID {row.new_dataset_id} in new collection {new_collection!r}."
847 )