Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 89%
301 statements
coverage.py v6.5.0, created at 2022-10-07 09:46 +0000
1# This file is part of daf_butler.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
23from __future__ import annotations
25__all__ = ("ByDimensionsDatasetRecordStorage",)
27import uuid
28from collections.abc import Iterable, Iterator, Sequence
29from typing import TYPE_CHECKING, Any
31import sqlalchemy
33from ....core import (
34 DataCoordinate,
35 DataCoordinateSet,
36 DatasetId,
37 DatasetRef,
38 DatasetType,
39 SimpleQuery,
40 Timespan,
41 ddl,
42)
43from ..._collection_summary import CollectionSummary
44from ..._collectionType import CollectionType
45from ..._exceptions import CollectionTypeError, ConflictingDefinitionError, UnsupportedIdGeneratorError
46from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage
47from .tables import makeTagTableSpec
49if TYPE_CHECKING:  # 49 ↛ 50: line 49 didn't jump to line 50, because the condition on line 49 was never true
50 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
51 from .summaries import CollectionSummaryManager
52 from .tables import StaticDatasetTablesTuple
55class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
56 """Dataset record storage implementation paired with
57 `ByDimensionsDatasetRecordStorageManager`; see that class for more
58 information.
60 Instances of this class should never be constructed directly; use
61 `DatasetRecordStorageManager.register` instead.
62 """
64 def __init__(
65 self,
66 *,
67 datasetType: DatasetType,
68 db: Database,
69 dataset_type_id: int,
70 collections: CollectionManager,
71 static: StaticDatasetTablesTuple,
72 summaries: CollectionSummaryManager,
73 tags: sqlalchemy.schema.Table,
74 calibs: sqlalchemy.schema.Table | None,
75 ):
76 super().__init__(datasetType=datasetType)
77 self._dataset_type_id = dataset_type_id
78 self._db = db
79 self._collections = collections
80 self._static = static
81 self._summaries = summaries
82 self._tags = tags
83 self._calibs = calibs
84 self._runKeyColumn = collections.getRunForeignKeyName()
86 def find(
87 self, collection: CollectionRecord, dataId: DataCoordinate, timespan: Timespan | None = None
88 ) -> DatasetRef | None:
89 # Docstring inherited from DatasetRecordStorage.
90 assert dataId.graph == self.datasetType.dimensions
91 if collection.type is CollectionType.CALIBRATION and timespan is None:  # 91 ↛ 92: line 91 didn't jump to line 92, because the condition on line 91 was never true
92 raise TypeError(
93 f"Cannot search for dataset in CALIBRATION collection {collection.name} "
94 f"without an input timespan."
95 )
96 sql = self.select(
97 collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan
98 )
99 results = self._db.query(sql)
100 row = results.fetchone()
101 if row is None:
102 return None
103 if collection.type is CollectionType.CALIBRATION:
104 # For temporal calibration lookups (only!) our invariants do not
105 # guarantee that the number of result rows is <= 1.
106 # They would if `select` constrained the given timespan to be
107 # _contained_ by the validity range in the self._calibs table,
108 # instead of simply _overlapping_ it, because we do guarantee that
109 # the validity ranges are disjoint for a particular dataset type,
110 # collection, and data ID. But using an overlap test and a check
111 # for multiple result rows here allows us to provide a more useful
112 # diagnostic, as well as allowing `select` to support more general
113 # queries where multiple results are not an error.
114 if results.fetchone() is not None:
115 raise RuntimeError(
116 f"Multiple matches found for calibration lookup in {collection.name} for "
117 f"{self.datasetType.name} with {dataId} overlapping {timespan}. "
118 )
119 return DatasetRef(
120 datasetType=self.datasetType,
121 dataId=dataId,
122 id=row.id,
123 run=self._collections[row._mapping[self._runKeyColumn]].name,
124 )
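# Illustrative usage sketch (not part of the module; `storage`, `run_record`,
# `calib_record`, `data_id`, and the timespan bounds are hypothetical objects
# normally obtained through the Registry):
#
#     ref = storage.find(run_record, data_id)
#     # CALIBRATION collections additionally require a timespan:
#     ref = storage.find(calib_record, data_id, timespan=Timespan(t_begin, t_end))
#     if ref is None:
#         ...  # no matching dataset in that collection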
126 def delete(self, datasets: Iterable[DatasetRef]) -> None:
127 # Docstring inherited from DatasetRecordStorage.
128 # Only delete from common dataset table; ON DELETE foreign key clauses
129 # will handle the rest.
130 self._db.delete(
131 self._static.dataset,
132 ["id"],
133 *[{"id": dataset.getCheckedId()} for dataset in datasets],
134 )
136 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
137 # Docstring inherited from DatasetRecordStorage.
138 if collection.type is not CollectionType.TAGGED:  # 138 ↛ 139: line 138 didn't jump to line 139, because the condition on line 138 was never true
139 raise TypeError(
140 f"Cannot associate into collection '{collection.name}' "
141 f"of type {collection.type.name}; must be TAGGED."
142 )
143 protoRow = {
144 self._collections.getCollectionForeignKeyName(): collection.key,
145 "dataset_type_id": self._dataset_type_id,
146 }
147 rows = []
148 summary = CollectionSummary()
149 for dataset in summary.add_datasets_generator(datasets):
150 row = dict(protoRow, dataset_id=dataset.getCheckedId())
151 for dimension, value in dataset.dataId.items():
152 row[dimension.name] = value
153 rows.append(row)
154 # Update the summary tables for this collection in case this is the
155 # first time this dataset type or these governor values will be
156 # inserted there.
157 self._summaries.update(collection, [self._dataset_type_id], summary)
158 # Update the tag table itself.
159 self._db.replace(self._tags, *rows)
161 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
162 # Docstring inherited from DatasetRecordStorage.
163 if collection.type is not CollectionType.TAGGED:  # 163 ↛ 164: line 163 didn't jump to line 164, because the condition on line 163 was never true
164 raise TypeError(
165 f"Cannot disassociate from collection '{collection.name}' "
166 f"of type {collection.type.name}; must be TAGGED."
167 )
168 rows = [
169 {
170 "dataset_id": dataset.getCheckedId(),
171 self._collections.getCollectionForeignKeyName(): collection.key,
172 }
173 for dataset in datasets
174 ]
175 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
177 def _buildCalibOverlapQuery(
178 self, collection: CollectionRecord, dataIds: DataCoordinateSet | None, timespan: Timespan
179 ) -> SimpleQuery:
180 assert self._calibs is not None
181 # Start by building a SELECT query for any rows that would overlap
182 # this one.
183 query = SimpleQuery()
184 query.join(self._calibs)
185 # Add a WHERE clause matching the dataset type and collection.
186 query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id)
187 query.where.append(
188 self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key
189 )
190 # Add a WHERE clause matching any of the given data IDs.
191 if dataIds is not None:
192 dataIds.constrain(
193 query,
194 lambda name: self._calibs.columns[name], # type: ignore
195 )
196 # Add WHERE clause for timespan overlaps.
197 TimespanReprClass = self._db.getTimespanRepresentation()
198 query.where.append(
199 TimespanReprClass.from_columns(self._calibs.columns).overlaps(
200 TimespanReprClass.fromLiteral(timespan)
201 )
202 )
203 return query
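# Sketch of the query assembled above (actual column names depend on the
# collection manager and the timespan representation in use):
#
#     SELECT ... FROM <calibs table>
#     WHERE dataset_type_id = :dataset_type_id
#       AND <collection FK> = :collection_key
#       AND <data ID columns match the given data IDs>   -- only if dataIds is given
#       AND <validity range> OVERLAPS :timespan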
205 def certify(
206 self, collection: CollectionRecord, datasets: Iterable[DatasetRef], timespan: Timespan
207 ) -> None:
208 # Docstring inherited from DatasetRecordStorage.
209 if self._calibs is None:  # 209 ↛ 210: line 209 didn't jump to line 210, because the condition on line 209 was never true
210 raise CollectionTypeError(
211 f"Cannot certify datasets of type {self.datasetType.name}, for which "
212 f"DatasetType.isCalibration() is False."
213 )
214 if collection.type is not CollectionType.CALIBRATION:  # 214 ↛ 215: line 214 didn't jump to line 215, because the condition on line 214 was never true
215 raise CollectionTypeError(
216 f"Cannot certify into collection '{collection.name}' "
217 f"of type {collection.type.name}; must be CALIBRATION."
218 )
219 TimespanReprClass = self._db.getTimespanRepresentation()
220 protoRow = {
221 self._collections.getCollectionForeignKeyName(): collection.key,
222 "dataset_type_id": self._dataset_type_id,
223 }
224 rows = []
225 dataIds: set[DataCoordinate] | None = (
226 set() if not TimespanReprClass.hasExclusionConstraint() else None
227 )
228 summary = CollectionSummary()
229 for dataset in summary.add_datasets_generator(datasets):
230 row = dict(protoRow, dataset_id=dataset.getCheckedId())
231 for dimension, value in dataset.dataId.items():
232 row[dimension.name] = value
233 TimespanReprClass.update(timespan, result=row)
234 rows.append(row)
235 if dataIds is not None:  # 235 ↛ 229: line 235 didn't jump to line 229, because the condition on line 235 was never false
236 dataIds.add(dataset.dataId)
237 # Update the summary tables for this collection in case this is the
238 # first time this dataset type or these governor values will be
239 # inserted there.
240 self._summaries.update(collection, [self._dataset_type_id], summary)
241 # Update the association table itself.
242 if TimespanReprClass.hasExclusionConstraint():  # 242 ↛ 245: line 242 didn't jump to line 245, because the condition on line 242 was never true
243 # Rely on database constraint to enforce invariants; we just
244 # reraise the exception for consistency across DB engines.
245 try:
246 self._db.insert(self._calibs, *rows)
247 except sqlalchemy.exc.IntegrityError as err:
248 raise ConflictingDefinitionError(
249 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
250 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
251 ) from err
252 else:
253 # Have to implement exclusion constraint ourselves.
254 # Start by building a SELECT query for any rows that would overlap
255 # this one.
256 query = self._buildCalibOverlapQuery(
257 collection,
258 DataCoordinateSet(dataIds, graph=self.datasetType.dimensions), # type: ignore
259 timespan,
260 )
261 query.columns.append(sqlalchemy.sql.func.count())
262 sql = query.combine()
263 # Acquire a table lock to ensure there are no concurrent writes
264 # that could invalidate our checking before we finish the inserts. We
265 # use a SAVEPOINT in case there is an outer transaction that a
266 # failure here should not roll back.
267 with self._db.transaction(lock=[self._calibs], savepoint=True):
268 # Run the check SELECT query.
269 conflicting = self._db.query(sql).scalar()
270 if conflicting > 0:
271 raise ConflictingDefinitionError(
272 f"{conflicting} validity range conflicts certifying datasets of type "
273 f"{self.datasetType.name} into {collection.name} for range "
274 f"[{timespan.begin}, {timespan.end})."
275 )
276 # Proceed with the insert.
277 self._db.insert(self._calibs, *rows)
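# Illustrative sketch (hypothetical `storage`, `calib_record`, `ref`, and
# instants t0 < t1 < t2 < t3): certifying the same dataset into the same
# CALIBRATION collection with overlapping validity ranges raises
# ConflictingDefinitionError, preserving the disjoint-ranges invariant:
#
#     storage.certify(calib_record, [ref], Timespan(t0, t2))
#     storage.certify(calib_record, [ref], Timespan(t1, t3))  # overlaps: raises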
279 def decertify(
280 self,
281 collection: CollectionRecord,
282 timespan: Timespan,
283 *,
284 dataIds: Iterable[DataCoordinate] | None = None,
285 ) -> None:
286 # Docstring inherited from DatasetRecordStorage.
287 if self._calibs is None:  # 287 ↛ 288: line 287 didn't jump to line 288, because the condition on line 287 was never true
288 raise CollectionTypeError(
289 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
290 f"DatasetType.isCalibration() is False."
291 )
292 if collection.type is not CollectionType.CALIBRATION:  # 292 ↛ 293: line 292 didn't jump to line 293, because the condition on line 292 was never true
293 raise CollectionTypeError(
294 f"Cannot decertify from collection '{collection.name}' "
295 f"of type {collection.type.name}; must be CALIBRATION."
296 )
297 TimespanReprClass = self._db.getTimespanRepresentation()
298 # Construct a SELECT query to find all rows that overlap our inputs.
299 dataIdSet: DataCoordinateSet | None
300 if dataIds is not None:
301 dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions)
302 else:
303 dataIdSet = None
304 query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan)
305 query.columns.extend(self._calibs.columns)
306 sql = query.combine()
307 # Set up collections to populate with the rows we'll want to modify.
308 # The insert rows will have the same values for collection and
309 # dataset type.
310 protoInsertRow = {
311 self._collections.getCollectionForeignKeyName(): collection.key,
312 "dataset_type_id": self._dataset_type_id,
313 }
314 rowsToDelete = []
315 rowsToInsert = []
316 # Acquire a table lock to ensure there are no concurrent writes
317 # between the SELECT and the DELETE and INSERT queries based on it.
318 with self._db.transaction(lock=[self._calibs], savepoint=True):
319 for row in self._db.query(sql).mappings():
320 rowsToDelete.append({"id": row["id"]})
321 # Construct the insert row(s) by copying the prototype row,
322 # then adding the dimension column values, then adding what's
323 # left of the timespan from that row after we subtract the
324 # given timespan.
325 newInsertRow = protoInsertRow.copy()
326 newInsertRow["dataset_id"] = row["dataset_id"]
327 for name in self.datasetType.dimensions.required.names:
328 newInsertRow[name] = row[name]
329 rowTimespan = TimespanReprClass.extract(row)
330 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
331 for diffTimespan in rowTimespan.difference(timespan):
332 rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()))
333 # Run the DELETE and INSERT queries.
334 self._db.delete(self._calibs, ["id"], *rowsToDelete)
335 self._db.insert(self._calibs, *rowsToInsert)
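# Illustrative sketch (hypothetical objects): decertifying the middle of an
# existing validity range splits it via Timespan.difference, e.g. a row
# certified over [t0, t3) and decertified over [t1, t2) is deleted and
# re-inserted as two rows covering [t0, t1) and [t2, t3):
#
#     storage.decertify(calib_record, Timespan(t1, t2), dataIds=[data_id])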
337 def select(
338 self,
339 *collections: CollectionRecord,
340 dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
341 id: SimpleQuery.Select.Or[int | None] = SimpleQuery.Select,
342 run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
343 timespan: SimpleQuery.Select.Or[Timespan | None] = SimpleQuery.Select,
344 ingestDate: SimpleQuery.Select.Or[Timespan | None] = None,
345 rank: SimpleQuery.Select.Or[None] = None,
346 ) -> sqlalchemy.sql.Selectable:
347 # Docstring inherited from DatasetRecordStorage.
348 collection_types = {collection.type for collection in collections}
349 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
350 TimespanReprClass = self._db.getTimespanRepresentation()
351 #
352 # There are two kinds of table in play here:
353 #
354 # - the static dataset table (with the dataset ID, dataset type ID,
355 # run ID/name, and ingest date);
356 #
357 # - the dynamic tags/calibs table (with the dataset ID, dataset type
358 # ID, collection ID/name, data ID, and possibly validity
359 # range).
360 #
361 # That means that we might want to return a query against either table
362 # or a JOIN of both, depending on which quantities the caller wants.
363 # But this method is documented/typed such that ``dataId`` is never
364 # `None` - i.e. we always constrain or retrieve the data ID. That
365 # means we'll always include the tags/calibs table and join in the
366 # static dataset table only if we need things from it that we can't get
367 # from the tags/calibs table.
368 #
369 # Note that it's important that we include a WHERE constraint on both
370 # tables for any column (e.g. dataset_type_id) that is in both when
371 # it's given explicitly; not doing so can prevent the query planner from
372 # using very important indexes. At present, we don't include those
373 # redundant columns in the JOIN ON expression, however, because the
374 # FOREIGN KEY (and its index) are defined only on dataset_id.
375 #
376 # We'll start by accumulating kwargs to pass to SimpleQuery.join when
377 # we bring in the tags/calibs table. We get the data ID or constrain
378 # it in the tags/calibs table(s), but that's multiple columns, not one,
379 # so we need to transform the one Select.Or argument into a dictionary
380 # of them.
381 kwargs: dict[str, Any]
382 if dataId is SimpleQuery.Select:
383 kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
384 else:
385 kwargs = dict(dataId.byName())
386 # We always constrain (never retrieve) the dataset type in at least the
387 # tags/calibs table.
388 kwargs["dataset_type_id"] = self._dataset_type_id
389 # Join in the tags and/or calibs tables, turning those 'kwargs' entries
390 # into WHERE constraints or SELECT columns as appropriate.
391 if collection_types != {CollectionType.CALIBRATION}:
392 # We'll need a subquery for the tags table if any of the given
393 # collections are not a CALIBRATION collection. This intentionally
394 # also fires when the list of collections is empty as a way to
395 # create a dummy subquery that we know will fail.
396 tags_query = SimpleQuery()
397 tags_query.join(self._tags, **kwargs)
398 # If the timespan is requested, simulate a potentially compound
399 # column whose values are the maximum and minimum timespan
400 # bounds.
401 # If the timespan is constrained, ignore the constraint, since
402 # it'd be guaranteed to evaluate to True.
403 if timespan is SimpleQuery.Select:
404 tags_query.columns.extend(TimespanReprClass.fromLiteral(Timespan(None, None)).flatten())
405 self._finish_single_select(
406 tags_query,
407 self._tags,
408 collections,
409 id=id,
410 run=run,
411 ingestDate=ingestDate,
412 rank=rank,
413 )
414 else:
415 tags_query = None
416 if CollectionType.CALIBRATION in collection_types:
417 # If at least one collection is a CALIBRATION collection, we'll
418 # need a subquery for the calibs table, and could include the
419 # timespan as a result or constraint.
420 calibs_query = SimpleQuery()
421 assert (
422 self._calibs is not None
423 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
424 calibs_query.join(self._calibs, **kwargs)
425 # Add the timespan column(s) to the result columns, or constrain
426 # the timespan via an overlap condition.
427 if timespan is SimpleQuery.Select:
428 calibs_query.columns.extend(TimespanReprClass.from_columns(self._calibs.columns).flatten())
429 elif timespan is not None:
430 calibs_query.where.append(
431 TimespanReprClass.from_columns(self._calibs.columns).overlaps(
432 TimespanReprClass.fromLiteral(timespan)
433 )
434 )
435 self._finish_single_select(
436 calibs_query,
437 self._calibs,
438 collections,
439 id=id,
440 run=run,
441 ingestDate=ingestDate,
442 rank=rank,
443 )
444 else:
445 calibs_query = None
446 if calibs_query is not None:
447 if tags_query is not None:
448 return tags_query.combine().union(calibs_query.combine())
449 else:
450 return calibs_query.combine()
451 else:
452 assert tags_query is not None, "Earlier logic should guarantee at least one is not None."
453 return tags_query.combine()
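# Shape of the returned selectable, roughly: a tags subquery, a calibs
# subquery, or the UNION of both when the given collections include both
# CALIBRATION and non-CALIBRATION types:
#
#     SELECT <requested columns> FROM <tags>   WHERE <collection/data ID constraints>
#     UNION
#     SELECT <requested columns> FROM <calibs> WHERE <collection/timespan constraints>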
455 def _finish_single_select(
456 self,
457 query: SimpleQuery,
458 table: sqlalchemy.schema.Table,
459 collections: Sequence[CollectionRecord],
460 id: SimpleQuery.Select.Or[int | None],
461 run: SimpleQuery.Select.Or[None],
462 ingestDate: SimpleQuery.Select.Or[Timespan | None],
463 rank: SimpleQuery.Select.Or[None],
464 ) -> None:
465 dataset_id_col = table.columns.dataset_id
466 collection_col = table.columns[self._collections.getCollectionForeignKeyName()]
467 # We always constrain (never retrieve) the collection(s) in the
468 # tags/calibs table.
469 if len(collections) == 1:
470 query.where.append(collection_col == collections[0].key)
471 elif len(collections) == 0:
472 # We support the case where there are no collections as a way to
473 # generate a valid SQL query that can't yield results. This should
474 # never get executed, but lots of downstream code will still try
475 # to access the SQLAlchemy objects representing the columns in the
476 # subquery. That's not ideal, but it'd take a lot of refactoring
477 # to fix it (DM-31725).
478 query.where.append(sqlalchemy.sql.literal(False))
479 else:
480 query.where.append(collection_col.in_([collection.key for collection in collections]))
481 # Add rank, if requested, as a CASE-based calculation on the collection
482 # column.
483 if rank is not None:
484 assert rank is SimpleQuery.Select, "Cannot constrain rank, only select it."
485 query.columns.append(
486 sqlalchemy.sql.case(
487 {record.key: n for n, record in enumerate(collections)},
488 value=collection_col,
489 ).label("rank")
490 )
491 # We can always get the dataset_id from the tags/calibs table or
492 # constrain it there. Can't use kwargs for that because we need to
493 # alias it to 'id'.
494 if id is SimpleQuery.Select:
495 query.columns.append(dataset_id_col.label("id"))
496 elif id is not None:  # 496 ↛ 497: line 496 didn't jump to line 497, because the condition on line 496 was never true
497 query.where.append(dataset_id_col == id)
498 # It's possible we now have everything we need, from just the
499 # tags/calibs table. The things we might need to get from the static
500 # dataset table are the run key and the ingest date.
501 need_static_table = False
502 static_kwargs: dict[str, Any] = {}
503 if run is not None:
504 assert run is SimpleQuery.Select, "To constrain the run name, pass a RunRecord as a collection."
505 if len(collections) == 1 and collections[0].type is CollectionType.RUN:
506 # If we are searching exactly one RUN collection, we
507 # know that if we find the dataset in that collection,
508 # then that's the dataset's run; we don't need to
509 # query for it.
510 query.columns.append(sqlalchemy.sql.literal(collections[0].key).label(self._runKeyColumn))
511 else:
512 static_kwargs[self._runKeyColumn] = SimpleQuery.Select
513 need_static_table = True
514 # Ingest date can only come from the static table.
515 if ingestDate is not None:
516 need_static_table = True
517 if ingestDate is SimpleQuery.Select:  # 517 ↛ 520: line 517 didn't jump to line 520, because the condition on line 517 was never false
518 static_kwargs["ingest_date"] = SimpleQuery.Select
519 else:
520 assert isinstance(ingestDate, Timespan)
521 # Timespan is astropy Time (usually in TAI) and ingest_date is
522 # TIMESTAMP; convert values to Python datetime for sqlalchemy.
523 if ingestDate.isEmpty():
524 raise RuntimeError("Empty timespan constraint provided for ingest_date.")
525 if ingestDate.begin is not None:
526 begin = ingestDate.begin.utc.datetime # type: ignore
527 query.where.append(self._static.dataset.columns.ingest_date >= begin)
528 if ingestDate.end is not None:
529 end = ingestDate.end.utc.datetime # type: ignore
530 query.where.append(self._static.dataset.columns.ingest_date < end)
531 # If we need the static table, join it in via dataset_id and
532 # dataset_type_id
533 if need_static_table:
534 query.join(
535 self._static.dataset,
536 onclause=(dataset_id_col == self._static.dataset.columns.id),
537 **static_kwargs,
538 )
539 # Also constrain dataset_type_id in static table in case that helps
540 # generate a better plan.
541 # We could also include this in the JOIN ON clause, but my guess is
542 # that that's a good idea IFF it's in the foreign key, and right
543 # now it isn't.
544 query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
546 def getDataId(self, id: DatasetId) -> DataCoordinate:
547 """Return DataId for a dataset.
549 Parameters
550 ----------
551 id : `DatasetId`
552 Unique dataset identifier.
554 Returns
555 -------
556 dataId : `DataCoordinate`
557 DataId for the dataset.
558 """
559 # This query could return multiple rows (one for each tagged collection
560 # the dataset is in, plus one for its run collection), and we don't
561 # care which of those we get.
562 sql = (
563 self._tags.select()
564 .where(
565 sqlalchemy.sql.and_(
566 self._tags.columns.dataset_id == id,
567 self._tags.columns.dataset_type_id == self._dataset_type_id,
568 )
569 )
570 .limit(1)
571 )
572 row = self._db.query(sql).mappings().fetchone()
573 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
574 return DataCoordinate.standardize(
575 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
576 graph=self.datasetType.dimensions,
577 )
580class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage):
581 """Implementation of ByDimensionsDatasetRecordStorage which uses integer
582 auto-incremented column for dataset IDs.
583 """
585 def insert(
586 self,
587 run: RunRecord,
588 dataIds: Iterable[DataCoordinate],
589 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
590 ) -> Iterator[DatasetRef]:
591 # Docstring inherited from DatasetRecordStorage.
593 # We only support UNIQUE mode for integer dataset IDs
594 if idMode != DatasetIdGenEnum.UNIQUE:  # 594 ↛ 595: line 594 didn't jump to line 595, because the condition on line 594 was never true
595 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
597 # Transform a possibly-single-pass iterable into a list.
598 dataIdList = list(dataIds)
599 yield from self._insert(run, dataIdList)
601 def import_(
602 self,
603 run: RunRecord,
604 datasets: Iterable[DatasetRef],
605 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
606 reuseIds: bool = False,
607 ) -> Iterator[DatasetRef]:
608 # Docstring inherited from DatasetRecordStorage.
610 # We only support UNIQUE mode for integer dataset IDs
611 if idGenerationMode != DatasetIdGenEnum.UNIQUE:  # 611 ↛ 612: line 611 didn't jump to line 612, because the condition on line 611 was never true
612 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
614 # Make a list of dataIds and optionally dataset IDs.
615 dataIdList: list[DataCoordinate] = []
616 datasetIdList: list[int] = []
617 for dataset in datasets:
618 dataIdList.append(dataset.dataId)
620 # We only accept integer dataset IDs, but also allow None.
621 datasetId = dataset.id
622 if datasetId is None:  # 622 ↛ 624: line 622 didn't jump to line 624, because the condition on line 622 was never true
623 # if reuseIds is set then all IDs must be known
624 if reuseIds:
625 raise TypeError("All dataset IDs must be known if `reuseIds` is set")
626 elif isinstance(datasetId, int):  # 626 ↛ 630: line 626 didn't jump to line 630, because the condition on line 626 was never false
627 if reuseIds:
628 datasetIdList.append(datasetId)
629 else:
630 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}")
632 yield from self._insert(run, dataIdList, datasetIdList)
634 def _insert(
635 self, run: RunRecord, dataIdList: list[DataCoordinate], datasetIdList: list[int] | None = None
636 ) -> Iterator[DatasetRef]:
637 """Common part of implementation of `insert` and `import_` methods."""
639 # Remember any governor dimension values we see.
640 summary = CollectionSummary()
641 summary.add_data_ids(self.datasetType, dataIdList)
643 staticRow = {
644 "dataset_type_id": self._dataset_type_id,
645 self._runKeyColumn: run.key,
646 }
647 with self._db.transaction():
648 # Insert into the static dataset table, generating autoincrement
649 # dataset_id values.
650 if datasetIdList:
651 # reuse existing IDs
652 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList]
653 self._db.insert(self._static.dataset, *rows)
654 else:
655 # use auto-incremented IDs
656 datasetIdList = self._db.insert(
657 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True
658 )
659 assert datasetIdList is not None
660 # Update the summary tables for this collection in case this is the
661 # first time this dataset type or these governor values will be
662 # inserted there.
663 self._summaries.update(run, [self._dataset_type_id], summary)
664 # Combine the generated dataset_id values and data ID fields to
665 # form rows to be inserted into the tags table.
666 protoTagsRow = {
667 "dataset_type_id": self._dataset_type_id,
668 self._collections.getCollectionForeignKeyName(): run.key,
669 }
670 tagsRows = [
671 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
672 for dataId, dataset_id in zip(dataIdList, datasetIdList)
673 ]
674 # Insert those rows into the tags table. This is where we'll
675 # get any unique constraint violations.
676 self._db.insert(self._tags, *tagsRows)
678 for dataId, datasetId in zip(dataIdList, datasetIdList):
679 yield DatasetRef(
680 datasetType=self.datasetType,
681 dataId=dataId,
682 id=datasetId,
683 run=run.name,
684 )
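# Illustrative usage sketch for the integer-ID implementation (hypothetical
# `storage` and `run_record`):
#
#     refs = list(storage.insert(run_record, data_ids))
#     # each ref.id is the auto-incremented integer assigned by the database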
687class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
688 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
689 dataset IDs.
690 """
692 idMaker = DatasetIdFactory()
693 """Factory for dataset IDs. In the future this factory may be shared with
694 other classes (e.g. Registry)."""
696 def insert(
697 self,
698 run: RunRecord,
699 dataIds: Iterable[DataCoordinate],
700 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
701 ) -> Iterator[DatasetRef]:
702 # Docstring inherited from DatasetRecordStorage.
704 # Iterate over data IDs, transforming a possibly-single-pass iterable
705 # into a list.
706 dataIdList = []
707 rows = []
708 summary = CollectionSummary()
709 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
710 dataIdList.append(dataId)
711 rows.append(
712 {
713 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
714 "dataset_type_id": self._dataset_type_id,
715 self._runKeyColumn: run.key,
716 }
717 )
719 with self._db.transaction():
720 # Insert into the static dataset table.
721 self._db.insert(self._static.dataset, *rows)
722 # Update the summary tables for this collection in case this is the
723 # first time this dataset type or these governor values will be
724 # inserted there.
725 self._summaries.update(run, [self._dataset_type_id], summary)
726 # Combine the generated dataset_id values and data ID fields to
727 # form rows to be inserted into the tags table.
728 protoTagsRow = {
729 "dataset_type_id": self._dataset_type_id,
730 self._collections.getCollectionForeignKeyName(): run.key,
731 }
732 tagsRows = [
733 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
734 for dataId, row in zip(dataIdList, rows)
735 ]
736 # Insert those rows into the tags table.
737 self._db.insert(self._tags, *tagsRows)
739 for dataId, row in zip(dataIdList, rows):
740 yield DatasetRef(
741 datasetType=self.datasetType,
742 dataId=dataId,
743 id=row["id"],
744 run=run.name,
745 )
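# Unlike the integer implementation above, dataset IDs here are generated
# client-side by `self.idMaker` before the insert, so no returnIds round trip
# is needed. Illustrative call (hypothetical objects):
#
#     refs = list(storage.insert(run_record, data_ids, DatasetIdGenEnum.UNIQUE))
#     # each ref.id is a uuid.UUID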
747 def import_(
748 self,
749 run: RunRecord,
750 datasets: Iterable[DatasetRef],
751 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
752 reuseIds: bool = False,
753 ) -> Iterator[DatasetRef]:
754 # Docstring inherited from DatasetRecordStorage.
756 # Iterate over data IDs, transforming a possibly-single-pass iterable
757 # into a list.
758 dataIds = {}
759 summary = CollectionSummary()
760 for dataset in summary.add_datasets_generator(datasets):
762 # Ignore unknown ID types; normally all IDs have the same type, but
762 # this code supports mixed types or missing IDs.
763 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
764 if datasetId is None:
765 datasetId = self.idMaker.makeDatasetId(
766 run.name, self.datasetType, dataset.dataId, idGenerationMode
767 )
768 dataIds[datasetId] = dataset.dataId
770 with self._db.session() as session:
772 # insert all new rows into a temporary table
773 tableSpec = makeTagTableSpec(
774 self.datasetType, type(self._collections), ddl.GUID, constraints=False
775 )
776 tmp_tags = session.makeTemporaryTable(tableSpec)
778 collFkName = self._collections.getCollectionForeignKeyName()
779 protoTagsRow = {
780 "dataset_type_id": self._dataset_type_id,
781 collFkName: run.key,
782 }
783 tmpRows = [
784 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
785 for dataset_id, dataId in dataIds.items()
786 ]
788 with self._db.transaction():
790 # store all incoming data in a temporary table
791 self._db.insert(tmp_tags, *tmpRows)
793 # There are some checks that we want to make for consistency
794 # of the new datasets with existing ones.
795 self._validateImport(tmp_tags, run)
797 # Before we merge the temporary table into dataset/tags we need to
798 # drop datasets which are already there (and do not conflict).
799 self._db.deleteWhere(
800 tmp_tags,
801 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
802 )
804 # Copy it into the dataset table; we need to re-label some columns.
805 self._db.insert(
806 self._static.dataset,
807 select=sqlalchemy.sql.select(
808 tmp_tags.columns.dataset_id.label("id"),
809 tmp_tags.columns.dataset_type_id,
810 tmp_tags.columns[collFkName].label(self._runKeyColumn),
811 ),
812 )
814 # Update the summary tables for this collection in case this
815 # is the first time this dataset type or these governor values
816 # will be inserted there.
817 self._summaries.update(run, [self._dataset_type_id], summary)
819 # Copy it into tags table.
820 self._db.insert(self._tags, select=tmp_tags.select())
822 # Return refs in the same order as in the input list.
823 for dataset_id, dataId in dataIds.items():
824 yield DatasetRef(
825 datasetType=self.datasetType,
826 id=dataset_id,
827 dataId=dataId,
828 run=run.name,
829 )
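# Summary of the import flow above: incoming rows first land in a temporary
# tags-shaped table, are validated against existing datasets by
# _validateImport, rows whose dataset IDs already exist are dropped, and the
# remainder is copied into the static dataset table and the tags table within
# a single transaction.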
831 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
832 """Validate imported refs against existing datasets.
834 Parameters
835 ----------
836 tmp_tags : `sqlalchemy.schema.Table`
837 Temporary table with new datasets and the same schema as tags
838 table.
839 run : `RunRecord`
840 The record object describing the `~CollectionType.RUN` collection.
842 Raises
843 ------
844 ConflictingDefinitionError
845 Raised if new datasets conflict with existing ones.
846 """
847 dataset = self._static.dataset
848 tags = self._tags
849 collFkName = self._collections.getCollectionForeignKeyName()
851 # Check that existing datasets have the same dataset type and
852 # run.
853 query = (
854 sqlalchemy.sql.select(
855 dataset.columns.id.label("dataset_id"),
856 dataset.columns.dataset_type_id.label("dataset_type_id"),
857 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
858 dataset.columns[self._runKeyColumn].label("run"),
859 tmp_tags.columns[collFkName].label("new run"),
860 )
861 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
862 .where(
863 sqlalchemy.sql.or_(
864 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
865 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
866 )
867 )
868 )
869 result = self._db.query(query)
870 if (row := result.first()) is not None:
871 # Only include the first one in the exception message
872 raise ConflictingDefinitionError(
873 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
874 )
876 # Check that matching dataset in tags table has the same DataId.
877 query = (
878 sqlalchemy.sql.select(
879 tags.columns.dataset_id,
880 tags.columns.dataset_type_id.label("type_id"),
881 tmp_tags.columns.dataset_type_id.label("new type_id"),
882 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
883 *[
884 tmp_tags.columns[dim].label(f"new {dim}")
885 for dim in self.datasetType.dimensions.required.names
886 ],
887 )
888 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
889 .where(
890 sqlalchemy.sql.or_(
891 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
892 *[
893 tags.columns[dim] != tmp_tags.columns[dim]
894 for dim in self.datasetType.dimensions.required.names
895 ],
896 )
897 )
898 )
899 result = self._db.query(query)
900 if (row := result.first()) is not None:
901 # Only include the first one in the exception message
902 raise ConflictingDefinitionError(
903 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
904 )
906 # Check that matching run+dataId have the same dataset ID.
907 query = (
908 sqlalchemy.sql.select(
909 tags.columns.dataset_type_id.label("dataset_type_id"),
910 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
911 tags.columns.dataset_id,
912 tmp_tags.columns.dataset_id.label("new dataset_id"),
913 tags.columns[collFkName],
914 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
915 )
916 .select_from(
917 tags.join(
918 tmp_tags,
919 sqlalchemy.sql.and_(
920 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
921 tags.columns[collFkName] == tmp_tags.columns[collFkName],
922 *[
923 tags.columns[dim] == tmp_tags.columns[dim]
924 for dim in self.datasetType.dimensions.required.names
925 ],
926 ),
927 )
928 )
929 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
930 )
931 result = self._db.query(query)
932 if (row := result.first()) is not None:
933 # only include the first one in the exception message
934 raise ConflictingDefinitionError(
935 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}"
936 )
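# The three checks above enforce, in order: (1) an imported dataset whose ID
# already exists must have the same dataset type and run; (2) it must have the
# same data ID; and (3) an existing dataset with the same run, dataset type,
# and data ID must carry the same dataset ID.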