Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 89%
298 statements
coverage.py v6.4.4, created at 2022-09-22 02:04 -0700
1 # This file is part of daf_butler.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (http://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
23 from __future__ import annotations
25 __all__ = ("ByDimensionsDatasetRecordStorage",)
27 import uuid
28 from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set
30 import sqlalchemy
31 from lsst.daf.butler import (
32 CollectionType,
33 DataCoordinate,
34 DataCoordinateSet,
35 DatasetId,
36 DatasetRef,
37 DatasetType,
38 SimpleQuery,
39 Timespan,
40 ddl,
41)
42 from lsst.daf.butler.registry import (
43 CollectionSummary,
44 CollectionTypeError,
45 ConflictingDefinitionError,
46 UnsupportedIdGeneratorError,
47)
48 from lsst.daf.butler.registry.interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage
50 from .tables import makeTagTableSpec
52 if TYPE_CHECKING: 52 ↛ 53  line 52 didn't jump to line 53, because the condition on line 52 was never true
53 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
54 from .summaries import CollectionSummaryManager
55 from .tables import StaticDatasetTablesTuple
58 class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
59 """Dataset record storage implementation paired with
60 `ByDimensionsDatasetRecordStorageManager`; see that class for more
61 information.
63 Instances of this class should never be constructed directly; use
64 `DatasetRecordStorageManager.register` instead.
65 """
67 def __init__(
68 self,
69 *,
70 datasetType: DatasetType,
71 db: Database,
72 dataset_type_id: int,
73 collections: CollectionManager,
74 static: StaticDatasetTablesTuple,
75 summaries: CollectionSummaryManager,
76 tags: sqlalchemy.schema.Table,
77 calibs: Optional[sqlalchemy.schema.Table],
78 ):
79 super().__init__(datasetType=datasetType)
80 self._dataset_type_id = dataset_type_id
81 self._db = db
82 self._collections = collections
83 self._static = static
84 self._summaries = summaries
85 self._tags = tags
86 self._calibs = calibs
87 self._runKeyColumn = collections.getRunForeignKeyName()
89 def find(
90 self, collection: CollectionRecord, dataId: DataCoordinate, timespan: Optional[Timespan] = None
91 ) -> Optional[DatasetRef]:
92 # Docstring inherited from DatasetRecordStorage.
93 assert dataId.graph == self.datasetType.dimensions
94 if collection.type is CollectionType.CALIBRATION and timespan is None: 94 ↛ 95  line 94 didn't jump to line 95, because the condition on line 94 was never true
95 raise TypeError(
96 f"Cannot search for dataset in CALIBRATION collection {collection.name} "
97 f"without an input timespan."
98 )
99 sql = self.select(
100 collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan
101 )
102 results = self._db.query(sql)
103 row = results.fetchone()
104 if row is None:
105 return None
106 if collection.type is CollectionType.CALIBRATION:
107 # For temporal calibration lookups (only!) our invariants do not
108 # guarantee that the number of result rows is <= 1.
109 # They would if `select` constrained the given timespan to be
110 # _contained_ by the validity range in the self._calibs table,
111 # instead of simply _overlapping_ it, because we do guarantee that
112 # the validity ranges are disjoint for a particular dataset type,
113 # collection, and data ID. But using an overlap test and a check
114 # for multiple result rows here allows us to provide a more useful
115 # diagnostic, as well as allowing `select` to support more general
116 # queries where multiple results are not an error.
117 if results.fetchone() is not None:
118 raise RuntimeError(
119 f"Multiple matches found for calibration lookup in {collection.name} for "
120 f"{self.datasetType.name} with {dataId} overlapping {timespan}. "
121 )
122 return DatasetRef(
123 datasetType=self.datasetType,
124 dataId=dataId,
125 id=row.id,
126 run=self._collections[row._mapping[self._runKeyColumn]].name,
127 )
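
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# The calibration branch of `find` above detects ambiguous lookups by fetching
# one more row after the first, rather than counting all matches. A minimal,
# standalone version of that two-fetch pattern, using an in-memory SQLite table
# instead of the butler schema (table and column names here are made up):
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
lookup = sqlalchemy.Table(
    "lookup_sketch",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("key", sqlalchemy.String),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(lookup.insert(), [{"key": "a"}, {"key": "a"}, {"key": "b"}])

def find_one(conn: sqlalchemy.engine.Connection, key: str):
    results = conn.execute(sqlalchemy.select(lookup).where(lookup.c.key == key))
    row = results.fetchone()
    if row is None:
        return None
    # A second successful fetch means the lookup was ambiguous; report that
    # instead of silently returning an arbitrary row.
    if results.fetchone() is not None:
        raise RuntimeError(f"Multiple matches found for key {key!r}.")
    return row

with engine.connect() as conn:
    print(find_one(conn, "b"))  # exactly one match
    # find_one(conn, "a") would raise RuntimeError here.
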
129 def delete(self, datasets: Iterable[DatasetRef]) -> None:
130 # Docstring inherited from DatasetRecordStorage.
131 # Only delete from common dataset table; ON DELETE foreign key clauses
132 # will handle the rest.
133 self._db.delete(
134 self._static.dataset,
135 ["id"],
136 *[{"id": dataset.getCheckedId()} for dataset in datasets],
137 )
139 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
140 # Docstring inherited from DatasetRecordStorage.
141 if collection.type is not CollectionType.TAGGED: 141 ↛ 142  line 141 didn't jump to line 142, because the condition on line 141 was never true
142 raise TypeError(
143 f"Cannot associate into collection '{collection.name}' "
144 f"of type {collection.type.name}; must be TAGGED."
145 )
146 protoRow = {
147 self._collections.getCollectionForeignKeyName(): collection.key,
148 "dataset_type_id": self._dataset_type_id,
149 }
150 rows = []
151 summary = CollectionSummary()
152 for dataset in summary.add_datasets_generator(datasets):
153 row = dict(protoRow, dataset_id=dataset.getCheckedId())
154 for dimension, value in dataset.dataId.items():
155 row[dimension.name] = value
156 rows.append(row)
157 # Update the summary tables for this collection in case this is the
158 # first time this dataset type or these governor values will be
159 # inserted there.
160 self._summaries.update(collection, [self._dataset_type_id], summary)
161 # Update the tag table itself.
162 self._db.replace(self._tags, *rows)
164 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
165 # Docstring inherited from DatasetRecordStorage.
166 if collection.type is not CollectionType.TAGGED: 166 ↛ 167  line 166 didn't jump to line 167, because the condition on line 166 was never true
167 raise TypeError(
168 f"Cannot disassociate from collection '{collection.name}' "
169 f"of type {collection.type.name}; must be TAGGED."
170 )
171 rows = [
172 {
173 "dataset_id": dataset.getCheckedId(),
174 self._collections.getCollectionForeignKeyName(): collection.key,
175 }
176 for dataset in datasets
177 ]
178 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows)
180 def _buildCalibOverlapQuery(
181 self, collection: CollectionRecord, dataIds: Optional[DataCoordinateSet], timespan: Timespan
182 ) -> SimpleQuery:
183 assert self._calibs is not None
184 # Start by building a SELECT query for any rows that would overlap
185 # this one.
186 query = SimpleQuery()
187 query.join(self._calibs)
188 # Add a WHERE clause matching the dataset type and collection.
189 query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id)
190 query.where.append(
191 self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key
192 )
193 # Add a WHERE clause matching any of the given data IDs.
194 if dataIds is not None:
195 dataIds.constrain(
196 query,
197 lambda name: self._calibs.columns[name], # type: ignore
198 )
199 # Add WHERE clause for timespan overlaps.
200 TimespanReprClass = self._db.getTimespanRepresentation()
201 query.where.append(
202 TimespanReprClass.from_columns(self._calibs.columns).overlaps(
203 TimespanReprClass.fromLiteral(timespan)
204 )
205 )
206 return query
208 def certify(
209 self, collection: CollectionRecord, datasets: Iterable[DatasetRef], timespan: Timespan
210 ) -> None:
211 # Docstring inherited from DatasetRecordStorage.
212 if self._calibs is None: 212 ↛ 213  line 212 didn't jump to line 213, because the condition on line 212 was never true
213 raise CollectionTypeError(
214 f"Cannot certify datasets of type {self.datasetType.name}, for which "
215 f"DatasetType.isCalibration() is False."
216 )
217 if collection.type is not CollectionType.CALIBRATION: 217 ↛ 218  line 217 didn't jump to line 218, because the condition on line 217 was never true
218 raise CollectionTypeError(
219 f"Cannot certify into collection '{collection.name}' "
220 f"of type {collection.type.name}; must be CALIBRATION."
221 )
222 TimespanReprClass = self._db.getTimespanRepresentation()
223 protoRow = {
224 self._collections.getCollectionForeignKeyName(): collection.key,
225 "dataset_type_id": self._dataset_type_id,
226 }
227 rows = []
228 dataIds: Optional[Set[DataCoordinate]] = (
229 set() if not TimespanReprClass.hasExclusionConstraint() else None
230 )
231 summary = CollectionSummary()
232 for dataset in summary.add_datasets_generator(datasets):
233 row = dict(protoRow, dataset_id=dataset.getCheckedId())
234 for dimension, value in dataset.dataId.items():
235 row[dimension.name] = value
236 TimespanReprClass.update(timespan, result=row)
237 rows.append(row)
238 if dataIds is not None: 238 ↛ 232  line 238 didn't jump to line 232, because the condition on line 238 was never false
239 dataIds.add(dataset.dataId)
240 # Update the summary tables for this collection in case this is the
241 # first time this dataset type or these governor values will be
242 # inserted there.
243 self._summaries.update(collection, [self._dataset_type_id], summary)
244 # Update the association table itself.
245 if TimespanReprClass.hasExclusionConstraint(): 245 ↛ 248  line 245 didn't jump to line 248, because the condition on line 245 was never true
246 # Rely on database constraint to enforce invariants; we just
247 # reraise the exception for consistency across DB engines.
248 try:
249 self._db.insert(self._calibs, *rows)
250 except sqlalchemy.exc.IntegrityError as err:
251 raise ConflictingDefinitionError(
252 f"Validity range conflict certifying datasets of type {self.datasetType.name} "
253 f"into {collection.name} for range [{timespan.begin}, {timespan.end})."
254 ) from err
255 else:
256 # Have to implement exclusion constraint ourselves.
257 # Start by building a SELECT query for any rows that would overlap
258 # this one.
259 query = self._buildCalibOverlapQuery(
260 collection,
261 DataCoordinateSet(dataIds, graph=self.datasetType.dimensions), # type: ignore
262 timespan,
263 )
264 query.columns.append(sqlalchemy.sql.func.count())
265 sql = query.combine()
266 # Acquire a table lock to ensure there are no concurrent writes that
267 # could invalidate our checking before we finish the inserts. We
268 # use a SAVEPOINT in case there is an outer transaction that a
269 # failure here should not roll back.
270 with self._db.transaction(lock=[self._calibs], savepoint=True):
271 # Run the check SELECT query.
272 conflicting = self._db.query(sql).scalar()
273 if conflicting > 0:
274 raise ConflictingDefinitionError(
275 f"{conflicting} validity range conflicts certifying datasets of type "
276 f"{self.datasetType.name} into {collection.name} for range "
277 f"[{timespan.begin}, {timespan.end})."
278 )
279 # Proceed with the insert.
280 self._db.insert(self._calibs, *rows)
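
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# When the database has no native temporal exclusion constraint, `certify`
# emulates one: inside a locked transaction it counts rows whose validity range
# overlaps the new one and inserts only if that count is zero. A standalone
# sketch of that check-then-insert shape with integer bounds and an in-memory
# SQLite table (no real table lock here; names are made up):
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
validity = sqlalchemy.Table(
    "validity_sketch",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("begin", sqlalchemy.Integer),
    sqlalchemy.Column("end", sqlalchemy.Integer),
)
metadata.create_all(engine)

def certify_sketch(conn, dataset_id, begin, end):
    # Two half-open ranges [a, b) and [c, d) overlap iff a < d and c < b.
    overlap = sqlalchemy.and_(validity.c.begin < end, begin < validity.c.end)
    conflicting = conn.execute(
        sqlalchemy.select(sqlalchemy.func.count()).select_from(validity).where(overlap)
    ).scalar()
    if conflicting > 0:
        raise RuntimeError(f"{conflicting} validity range conflicts for [{begin}, {end}).")
    conn.execute(validity.insert().values(dataset_id=dataset_id, begin=begin, end=end))

with engine.begin() as conn:
    certify_sketch(conn, 1, 0, 10)
    certify_sketch(conn, 1, 10, 20)  # adjacent ranges do not overlap
    # certify_sketch(conn, 1, 5, 15) would raise RuntimeError here.
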
282 def decertify(
283 self,
284 collection: CollectionRecord,
285 timespan: Timespan,
286 *,
287 dataIds: Optional[Iterable[DataCoordinate]] = None,
288 ) -> None:
289 # Docstring inherited from DatasetRecordStorage.
290 if self._calibs is None: 290 ↛ 291  line 290 didn't jump to line 291, because the condition on line 290 was never true
291 raise CollectionTypeError(
292 f"Cannot decertify datasets of type {self.datasetType.name}, for which "
293 f"DatasetType.isCalibration() is False."
294 )
295 if collection.type is not CollectionType.CALIBRATION: 295 ↛ 296  line 295 didn't jump to line 296, because the condition on line 295 was never true
296 raise CollectionTypeError(
297 f"Cannot decertify from collection '{collection.name}' "
298 f"of type {collection.type.name}; must be CALIBRATION."
299 )
300 TimespanReprClass = self._db.getTimespanRepresentation()
301 # Construct a SELECT query to find all rows that overlap our inputs.
302 dataIdSet: Optional[DataCoordinateSet]
303 if dataIds is not None:
304 dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions)
305 else:
306 dataIdSet = None
307 query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan)
308 query.columns.extend(self._calibs.columns)
309 sql = query.combine()
310 # Set up collections to populate with the rows we'll want to modify.
311 # The insert rows will have the same values for collection and
312 # dataset type.
313 protoInsertRow = {
314 self._collections.getCollectionForeignKeyName(): collection.key,
315 "dataset_type_id": self._dataset_type_id,
316 }
317 rowsToDelete = []
318 rowsToInsert = []
319 # Acquire a table lock to ensure there are no concurrent writes
320 # between the SELECT and the DELETE and INSERT queries based on it.
321 with self._db.transaction(lock=[self._calibs], savepoint=True):
322 for row in self._db.query(sql).mappings():
323 rowsToDelete.append({"id": row["id"]})
324 # Construct the insert row(s) by copying the prototype row,
325 # then adding the dimension column values, then adding what's
326 # left of the timespan from that row after we subtract the
327 # given timespan.
328 newInsertRow = protoInsertRow.copy()
329 newInsertRow["dataset_id"] = row["dataset_id"]
330 for name in self.datasetType.dimensions.required.names:
331 newInsertRow[name] = row[name]
332 rowTimespan = TimespanReprClass.extract(row)
333 assert rowTimespan is not None, "Field should have a NOT NULL constraint."
334 for diffTimespan in rowTimespan.difference(timespan):
335 rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()))
336 # Run the DELETE and INSERT queries.
337 self._db.delete(self._calibs, ["id"], *rowsToDelete)
338 self._db.insert(self._calibs, *rowsToInsert)
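
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# `decertify` rewrites each overlapping row as whatever is left of its validity
# range after subtracting the decertified timespan; `Timespan.difference` can
# yield zero, one, or two surviving pieces, each of which becomes an insert row.
# A pure-Python sketch of that subtraction on half-open integer intervals (the
# real code works with the astropy-backed `Timespan` class):
def difference(row, cut):
    """Yield the pieces of half-open interval ``row`` not covered by ``cut``."""
    begin, end = row
    cut_begin, cut_end = cut
    if cut_begin > begin:
        # Something survives before the cut; clip it to the original range.
        yield (begin, min(end, cut_begin))
    if cut_end < end:
        # Something survives after the cut.
        yield (max(begin, cut_end), end)

print(list(difference((0, 10), (3, 6))))    # [(0, 3), (6, 10)] -> two insert rows
print(list(difference((0, 10), (5, 20))))   # [(0, 5)]          -> one insert row
print(list(difference((0, 10), (-5, 20))))  # []                -> row is only deleted
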
340 def select(
341 self,
342 *collections: CollectionRecord,
343 dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
344 id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select,
345 run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
346 timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select,
347 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None,
348 rank: SimpleQuery.Select.Or[None] = None,
349 ) -> sqlalchemy.sql.Selectable:
350 # Docstring inherited from DatasetRecordStorage.
351 collection_types = {collection.type for collection in collections}
352 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened."
353 TimespanReprClass = self._db.getTimespanRepresentation()
354 #
355 # There are two kinds of table in play here:
356 #
357 # - the static dataset table (with the dataset ID, dataset type ID,
358 # run ID/name, and ingest date);
359 #
360 # - the dynamic tags/calibs table (with the dataset ID, dataset type
361 # ID, collection ID/name, data ID, and possibly validity
362 # range).
363 #
364 # That means that we might want to return a query against either table
365 # or a JOIN of both, depending on which quantities the caller wants.
366 # But this method is documented/typed such that ``dataId`` is never
367 # `None` - i.e. we always constrain or retrieve the data ID. That
368 # means we'll always include the tags/calibs table and join in the
369 # static dataset table only if we need things from it that we can't get
370 # from the tags/calibs table.
371 #
372 # Note that it's important that we include a WHERE constraint on both
373 # tables for any column (e.g. dataset_type_id) that is in both when
374 # it's given explicitly; not doing so can prevent the query planner from
375 # using very important indexes. At present, we don't include those
376 # redundant columns in the JOIN ON expression, however, because the
377 # FOREIGN KEY (and its index) are defined only on dataset_id.
378 #
379 # We'll start by accumulating kwargs to pass to SimpleQuery.join when
380 # we bring in the tags/calibs table. We get the data ID or constrain
381 # it in the tags/calibs table(s), but that's multiple columns, not one,
382 # so we need to transform the one Select.Or argument into a dictionary
383 # of them.
384 kwargs: Dict[str, Any]
385 if dataId is SimpleQuery.Select:
386 kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
387 else:
388 kwargs = dict(dataId.byName())
389 # We always constrain (never retrieve) the dataset type in at least the
390 # tags/calibs table.
391 kwargs["dataset_type_id"] = self._dataset_type_id
392 # Join in the tags and/or calibs tables, turning those 'kwargs' entries
393 # into WHERE constraints or SELECT columns as appropriate.
394 if collection_types != {CollectionType.CALIBRATION}:
395 # We'll need a subquery for the tags table if any of the given
396 # collections are not a CALIBRATION collection. This intentionally
397 # also fires when the list of collections is empty as a way to
398 # create a dummy subquery that we know will fail.
399 tags_query = SimpleQuery()
400 tags_query.join(self._tags, **kwargs)
401 # If the timespan is requested, simulate a potentially compound
402 # column whose values are the maximum and minimum timespan
403 # bounds.
404 # If the timespan is constrained, ignore the constraint, since
405 # it'd be guaranteed to evaluate to True.
406 if timespan is SimpleQuery.Select:
407 tags_query.columns.extend(TimespanReprClass.fromLiteral(Timespan(None, None)).flatten())
408 self._finish_single_select(
409 tags_query,
410 self._tags,
411 collections,
412 id=id,
413 run=run,
414 ingestDate=ingestDate,
415 rank=rank,
416 )
417 else:
418 tags_query = None
419 if CollectionType.CALIBRATION in collection_types:
420 # If at least one collection is a CALIBRATION collection, we'll
421 # need a subquery for the calibs table, and could include the
422 # timespan as a result or constraint.
423 calibs_query = SimpleQuery()
424 assert (
425 self._calibs is not None
426 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection."
427 calibs_query.join(self._calibs, **kwargs)
428 # Add the timespan column(s) to the result columns, or constrain
429 # the timespan via an overlap condition.
430 if timespan is SimpleQuery.Select:
431 calibs_query.columns.extend(TimespanReprClass.from_columns(self._calibs.columns).flatten())
432 elif timespan is not None:
433 calibs_query.where.append(
434 TimespanReprClass.from_columns(self._calibs.columns).overlaps(
435 TimespanReprClass.fromLiteral(timespan)
436 )
437 )
438 self._finish_single_select(
439 calibs_query,
440 self._calibs,
441 collections,
442 id=id,
443 run=run,
444 ingestDate=ingestDate,
445 rank=rank,
446 )
447 else:
448 calibs_query = None
449 if calibs_query is not None:
450 if tags_query is not None:
451 return tags_query.combine().union(calibs_query.combine())
452 else:
453 return calibs_query.combine()
454 else:
455 assert tags_query is not None, "Earlier logic should guarantee at least one is not None."
456 return tags_query.combine()
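
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# When the search path mixes TAGGED/RUN and CALIBRATION collections, `select`
# returns the UNION of a tags subquery and a calibs subquery that expose the
# same column list. A generic SQLAlchemy sketch of that shape on toy tables
# (the butler schema and the SimpleQuery wrapper are not used here):
import sqlalchemy

metadata = sqlalchemy.MetaData()
tags_sketch = sqlalchemy.Table(
    "tags_sketch",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)
calibs_sketch = sqlalchemy.Table(
    "calibs_sketch",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)

# Both branches label their columns identically so the UNION is well formed.
tags_part = sqlalchemy.select(
    tags_sketch.c.dataset_id.label("id"), tags_sketch.c.collection_id
).where(tags_sketch.c.collection_id == 1)
calibs_part = sqlalchemy.select(
    calibs_sketch.c.dataset_id.label("id"), calibs_sketch.c.collection_id
).where(calibs_sketch.c.collection_id == 2)

print(tags_part.union(calibs_part))  # SELECT ... FROM tags_sketch ... UNION SELECT ... FROM calibs_sketch ...
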
458 def _finish_single_select(
459 self,
460 query: SimpleQuery,
461 table: sqlalchemy.schema.Table,
462 collections: Sequence[CollectionRecord],
463 id: SimpleQuery.Select.Or[Optional[int]],
464 run: SimpleQuery.Select.Or[None],
465 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]],
466 rank: SimpleQuery.Select.Or[None],
467 ) -> None:
468 dataset_id_col = table.columns.dataset_id
469 collection_col = table.columns[self._collections.getCollectionForeignKeyName()]
470 # We always constrain (never retrieve) the collection(s) in the
471 # tags/calibs table.
472 if len(collections) == 1:
473 query.where.append(collection_col == collections[0].key)
474 elif len(collections) == 0:
475 # We support the case where there are no collections as a way to
476 # generate a valid SQL query that can't yield results. This should
477 # never get executed, but lots of downstream code will still try
478 # to access the SQLAlchemy objects representing the columns in the
479 # subquery. That's not ideal, but it'd take a lot of refactoring
480 # to fix it (DM-31725).
481 query.where.append(sqlalchemy.sql.literal(False))
482 else:
483 query.where.append(collection_col.in_([collection.key for collection in collections]))
484 # Add rank if requested as a CASE-based calculation on the collection
485 # column.
486 if rank is not None:
487 assert rank is SimpleQuery.Select, "Cannot constrain rank, only select it."
488 query.columns.append(
489 sqlalchemy.sql.case(
490 {record.key: n for n, record in enumerate(collections)},
491 value=collection_col,
492 ).label("rank")
493 )
494 # We can always get the dataset_id from the tags/calibs table or
495 # constrain it there. Can't use kwargs for that because we need to
496 # alias it to 'id'.
497 if id is SimpleQuery.Select:
498 query.columns.append(dataset_id_col.label("id"))
499 elif id is not None: 499 ↛ 500  line 499 didn't jump to line 500, because the condition on line 499 was never true
500 query.where.append(dataset_id_col == id)
501 # It's possible we now have everything we need, from just the
502 # tags/calibs table. The things we might need to get from the static
503 # dataset table are the run key and the ingest date.
504 need_static_table = False
505 static_kwargs: Dict[str, Any] = {}
506 if run is not None:
507 assert run is SimpleQuery.Select, "To constrain the run name, pass a RunRecord as a collection."
508 if len(collections) == 1 and collections[0].type is CollectionType.RUN:
509 # If we are searching exactly one RUN collection, we
510 # know that if we find the dataset in that collection,
511 # then that's the dataset's run; we don't need to
512 # query for it.
513 query.columns.append(sqlalchemy.sql.literal(collections[0].key).label(self._runKeyColumn))
514 else:
515 static_kwargs[self._runKeyColumn] = SimpleQuery.Select
516 need_static_table = True
517 # Ingest date can only come from the static table.
518 if ingestDate is not None:
519 need_static_table = True
520 if ingestDate is SimpleQuery.Select: 520 ↛ 523  line 520 didn't jump to line 523, because the condition on line 520 was never false
521 static_kwargs["ingest_date"] = SimpleQuery.Select
522 else:
523 assert isinstance(ingestDate, Timespan)
524 # Timespan bounds are astropy Time (usually TAI), while ingest_date is a
525 # TIMESTAMP column, so convert the values to Python datetime for sqlalchemy.
526 if ingestDate.isEmpty():
527 raise RuntimeError("Empty timespan constraint provided for ingest_date.")
528 if ingestDate.begin is not None:
529 begin = ingestDate.begin.utc.datetime # type: ignore
530 query.where.append(self._static.dataset.columns.ingest_date >= begin)
531 if ingestDate.end is not None:
532 end = ingestDate.end.utc.datetime # type: ignore
533 query.where.append(self._static.dataset.columns.ingest_date < end)
534 # If we need the static table, join it in via dataset_id and
535 # dataset_type_id
536 if need_static_table:
537 query.join(
538 self._static.dataset,
539 onclause=(dataset_id_col == self._static.dataset.columns.id),
540 **static_kwargs,
541 )
542 # Also constrain dataset_type_id in static table in case that helps
543 # generate a better plan.
544 # We could also include this in the JOIN ON clause, but my guess is
545 # that that's a good idea IFF it's in the foreign key, and right
546 # now it isn't.
547 query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id)
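
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# The optional "rank" column above encodes each collection's position in the
# search path as a CASE expression keyed on the collection column, so callers
# can order results by search-path priority. A standalone sketch of that
# construct (toy table and collection keys, not the butler schema):
import sqlalchemy

metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table(
    "tags_rank_sketch",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)

search_path = [7, 3, 12]  # collection keys, in search order
rank = sqlalchemy.case(
    {key: n for n, key in enumerate(search_path)},
    value=tags.c.collection_id,
).label("rank")

query = sqlalchemy.select(tags.c.dataset_id, rank).where(
    tags.c.collection_id.in_(search_path)
)
print(query)  # compiled SQL contains a CASE ... WHEN ... END expression labelled "rank"
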
549 def getDataId(self, id: DatasetId) -> DataCoordinate:
550 """Return DataId for a dataset.
552 Parameters
553 ----------
554 id : `DatasetId`
555 Unique dataset identifier.
557 Returns
558 -------
559 dataId : `DataCoordinate`
560 DataId for the dataset.
561 """
562 # This query could return multiple rows (one for each tagged collection
563 # the dataset is in, plus one for its run collection), and we don't
564 # care which of those we get.
565 sql = (
566 self._tags.select()
567 .where(
568 sqlalchemy.sql.and_(
569 self._tags.columns.dataset_id == id,
570 self._tags.columns.dataset_type_id == self._dataset_type_id,
571 )
572 )
573 .limit(1)
574 )
575 row = self._db.query(sql).mappings().fetchone()
576 assert row is not None, "Should be guaranteed by caller and foreign key constraints."
577 return DataCoordinate.standardize(
578 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
579 graph=self.datasetType.dimensions,
580 )
583 class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage):
584 """Implementation of ByDimensionsDatasetRecordStorage which uses integer
585 auto-incremented column for dataset IDs.
586 """
588 def insert(
589 self,
590 run: RunRecord,
591 dataIds: Iterable[DataCoordinate],
592 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
593 ) -> Iterator[DatasetRef]:
594 # Docstring inherited from DatasetRecordStorage.
596 # We only support UNIQUE mode for integer dataset IDs
597 if idMode != DatasetIdGenEnum.UNIQUE: 597 ↛ 598  line 597 didn't jump to line 598, because the condition on line 597 was never true
598 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
600 # Transform a possibly-single-pass iterable into a list.
601 dataIdList = list(dataIds)
602 yield from self._insert(run, dataIdList)
604 def import_(
605 self,
606 run: RunRecord,
607 datasets: Iterable[DatasetRef],
608 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
609 reuseIds: bool = False,
610 ) -> Iterator[DatasetRef]:
611 # Docstring inherited from DatasetRecordStorage.
613 # We only support UNIQUE mode for integer dataset IDs
614 if idGenerationMode != DatasetIdGenEnum.UNIQUE: 614 ↛ 615  line 614 didn't jump to line 615, because the condition on line 614 was never true
615 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.")
617 # Make a list of dataIds and optionally dataset IDs.
618 dataIdList: List[DataCoordinate] = []
619 datasetIdList: List[int] = []
620 for dataset in datasets:
621 dataIdList.append(dataset.dataId)
623 # We only accept integer dataset IDs, but also allow None.
624 datasetId = dataset.id
625 if datasetId is None: 625 ↛ 627  line 625 didn't jump to line 627, because the condition on line 625 was never true
626 # if reuseIds is set then all IDs must be known
627 if reuseIds:
628 raise TypeError("All dataset IDs must be known if `reuseIds` is set")
629 elif isinstance(datasetId, int): 629 ↛ 633  line 629 didn't jump to line 633, because the condition on line 629 was never false
630 if reuseIds:
631 datasetIdList.append(datasetId)
632 else:
633 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}")
635 yield from self._insert(run, dataIdList, datasetIdList)
637 def _insert(
638 self, run: RunRecord, dataIdList: List[DataCoordinate], datasetIdList: Optional[List[int]] = None
639 ) -> Iterator[DatasetRef]:
640 """Common part of implementation of `insert` and `import_` methods."""
642 # Remember any governor dimension values we see.
643 summary = CollectionSummary()
644 summary.add_data_ids(self.datasetType, dataIdList)
646 staticRow = {
647 "dataset_type_id": self._dataset_type_id,
648 self._runKeyColumn: run.key,
649 }
650 with self._db.transaction():
651 # Insert into the static dataset table, generating autoincrement
652 # dataset_id values.
653 if datasetIdList:
654 # reuse existing IDs
655 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList]
656 self._db.insert(self._static.dataset, *rows)
657 else:
658 # use auto-incremented IDs
659 datasetIdList = self._db.insert(
660 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True
661 )
662 assert datasetIdList is not None
663 # Update the summary tables for this collection in case this is the
664 # first time this dataset type or these governor values will be
665 # inserted there.
666 self._summaries.update(run, [self._dataset_type_id], summary)
667 # Combine the generated dataset_id values and data ID fields to
668 # form rows to be inserted into the tags table.
669 protoTagsRow = {
670 "dataset_type_id": self._dataset_type_id,
671 self._collections.getCollectionForeignKeyName(): run.key,
672 }
673 tagsRows = [
674 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
675 for dataId, dataset_id in zip(dataIdList, datasetIdList)
676 ]
677 # Insert those rows into the tags table. This is where we'll
678 # get any unique constraint violations.
679 self._db.insert(self._tags, *tagsRows)
681 for dataId, datasetId in zip(dataIdList, datasetIdList):
682 yield DatasetRef(
683 datasetType=self.datasetType,
684 dataId=dataId,
685 id=datasetId,
686 run=run.name,
687 )
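
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# The tags rows above are built by overlaying per-dataset values on a prototype
# dict that already fixes the dataset type and collection keys. A small
# pure-Python illustration of that `dict(proto, **extra)` merge (field names
# and values are made up):
proto_row = {"dataset_type_id": 17, "run_id": 3}
data_ids = [{"instrument": "HSC", "detector": 10}, {"instrument": "HSC", "detector": 11}]
dataset_ids = [101, 102]

tags_rows = [
    # dict(proto, key=value, **mapping) copies the prototype, then applies the
    # per-row overrides without mutating the prototype itself.
    dict(proto_row, dataset_id=dataset_id, **data_id)
    for data_id, dataset_id in zip(data_ids, dataset_ids)
]
print(tags_rows[0])
# {'dataset_type_id': 17, 'run_id': 3, 'dataset_id': 101, 'instrument': 'HSC', 'detector': 10}
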
690 class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage):
691 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for
692 dataset IDs.
693 """
695 idMaker = DatasetIdFactory()
696 """Factory for dataset IDs. In the future this factory may be shared with
697 other classes (e.g. Registry)."""
699 def insert(
700 self,
701 run: RunRecord,
702 dataIds: Iterable[DataCoordinate],
703 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
704 ) -> Iterator[DatasetRef]:
705 # Docstring inherited from DatasetRecordStorage.
707 # Iterate over data IDs, transforming a possibly-single-pass iterable
708 # into a list.
709 dataIdList = []
710 rows = []
711 summary = CollectionSummary()
712 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds):
713 dataIdList.append(dataId)
714 rows.append(
715 {
716 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode),
717 "dataset_type_id": self._dataset_type_id,
718 self._runKeyColumn: run.key,
719 }
720 )
722 with self._db.transaction():
723 # Insert into the static dataset table.
724 self._db.insert(self._static.dataset, *rows)
725 # Update the summary tables for this collection in case this is the
726 # first time this dataset type or these governor values will be
727 # inserted there.
728 self._summaries.update(run, [self._dataset_type_id], summary)
729 # Combine the generated dataset_id values and data ID fields to
730 # form rows to be inserted into the tags table.
731 protoTagsRow = {
732 "dataset_type_id": self._dataset_type_id,
733 self._collections.getCollectionForeignKeyName(): run.key,
734 }
735 tagsRows = [
736 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName())
737 for dataId, row in zip(dataIdList, rows)
738 ]
739 # Insert those rows into the tags table.
740 self._db.insert(self._tags, *tagsRows)
742 for dataId, row in zip(dataIdList, rows):
743 yield DatasetRef(
744 datasetType=self.datasetType,
745 dataId=dataId,
746 id=row["id"],
747 run=run.name,
748 )
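
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# `idMaker.makeDatasetId` returns a random UUID in UNIQUE mode and a
# reproducible, content-derived UUID in the deterministic generation modes, so
# independent repositories can agree on dataset IDs. A simplified standalone
# illustration of that distinction; the namespace and string encoding below are
# made up, so these are not the values the real DatasetIdFactory would produce:
import uuid

SKETCH_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_DNS, "dataset-id-sketch.example.org")

def make_dataset_id(run: str, dataset_type_name: str, data_id: dict, deterministic: bool) -> uuid.UUID:
    if not deterministic:
        # UNIQUE mode: every call yields a brand-new random ID.
        return uuid.uuid4()
    # Deterministic modes: the same inputs always hash to the same UUID.
    payload = ",".join(
        [run, dataset_type_name] + [f"{k}={v}" for k, v in sorted(data_id.items())]
    )
    return uuid.uuid5(SKETCH_NAMESPACE, payload)

data_id = {"instrument": "HSC", "detector": 42}
print(make_dataset_id("run/a", "calexp", data_id, deterministic=True))
print(make_dataset_id("run/a", "calexp", data_id, deterministic=True))   # same as above
print(make_dataset_id("run/a", "calexp", data_id, deterministic=False))  # random each call
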
750 def import_(
751 self,
752 run: RunRecord,
753 datasets: Iterable[DatasetRef],
754 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE,
755 reuseIds: bool = False,
756 ) -> Iterator[DatasetRef]:
757 # Docstring inherited from DatasetRecordStorage.
759 # Iterate over data IDs, transforming a possibly-single-pass iterable
760 # into a list.
761 dataIds = {}
762 summary = CollectionSummary()
763 for dataset in summary.add_datasets_generator(datasets):
764 # Ignore unknown ID types; normally all IDs have the same type, but
765 # this code supports mixed types or missing IDs.
766 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None
767 if datasetId is None:
768 datasetId = self.idMaker.makeDatasetId(
769 run.name, self.datasetType, dataset.dataId, idGenerationMode
770 )
771 dataIds[datasetId] = dataset.dataId
773 with self._db.session() as session:
775 # insert all new rows into a temporary table
776 tableSpec = makeTagTableSpec(
777 self.datasetType, type(self._collections), ddl.GUID, constraints=False
778 )
779 tmp_tags = session.makeTemporaryTable(tableSpec)
781 collFkName = self._collections.getCollectionForeignKeyName()
782 protoTagsRow = {
783 "dataset_type_id": self._dataset_type_id,
784 collFkName: run.key,
785 }
786 tmpRows = [
787 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName())
788 for dataset_id, dataId in dataIds.items()
789 ]
791 with self._db.transaction():
793 # store all incoming data in a temporary table
794 self._db.insert(tmp_tags, *tmpRows)
796 # There are some checks that we want to make for consistency
797 # of the new datasets with existing ones.
798 self._validateImport(tmp_tags, run)
800 # Before we merge the temporary table into dataset/tags we need to
801 # drop datasets which are already there (and do not conflict).
802 self._db.deleteWhere(
803 tmp_tags,
804 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)),
805 )
807 # Copy it into the dataset table; we need to re-label some columns.
808 self._db.insert(
809 self._static.dataset,
810 select=sqlalchemy.sql.select(
811 tmp_tags.columns.dataset_id.label("id"),
812 tmp_tags.columns.dataset_type_id,
813 tmp_tags.columns[collFkName].label(self._runKeyColumn),
814 ),
815 )
817 # Update the summary tables for this collection in case this
818 # is the first time this dataset type or these governor values
819 # will be inserted there.
820 self._summaries.update(run, [self._dataset_type_id], summary)
822 # Copy it into tags table.
823 self._db.insert(self._tags, select=tmp_tags.select())
825 # Return refs in the same order as in the input list.
826 for dataset_id, dataId in dataIds.items():
827 yield DatasetRef(
828 datasetType=self.datasetType,
829 id=dataset_id,
830 dataId=dataId,
831 run=run.name,
832 )
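
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# The import above stages incoming rows in a temporary table, validates them,
# deletes the ones whose dataset_id already exists, and then copies the rest
# into the real tables with INSERT ... SELECT. A generic sketch of that staging
# pattern with plain SQLAlchemy and an in-memory SQLite database (the butler
# `Database` layer wraps equivalent operations):
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
dataset_sketch = sqlalchemy.Table(
    "dataset_sketch",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
)
staging_sketch = sqlalchemy.Table(
    "staging_sketch",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(dataset_sketch.insert(), [{"id": 1}, {"id": 2}])
    # Stage the incoming rows, one of which is already present.
    conn.execute(staging_sketch.insert(), [{"id": 2}, {"id": 3}, {"id": 4}])
    # Drop staged rows that already exist in the destination table.
    conn.execute(
        staging_sketch.delete().where(
            staging_sketch.c.id.in_(sqlalchemy.select(dataset_sketch.c.id))
        )
    )
    # Copy whatever is left in a single INSERT ... SELECT statement.
    conn.execute(
        dataset_sketch.insert().from_select(["id"], sqlalchemy.select(staging_sketch.c.id))
    )
    print(conn.execute(sqlalchemy.select(dataset_sketch.c.id).order_by(dataset_sketch.c.id)).all())
    # -> [(1,), (2,), (3,), (4,)]
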
834 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None:
835 """Validate imported refs against existing datasets.
837 Parameters
838 ----------
839 tmp_tags : `sqlalchemy.schema.Table`
840 Temporary table with new datasets and the same schema as tags
841 table.
842 run : `RunRecord`
843 The record object describing the `~CollectionType.RUN` collection.
845 Raises
846 ------
847 ConflictingDefinitionError
848 Raise if new datasets conflict with existing ones.
849 """
850 dataset = self._static.dataset
851 tags = self._tags
852 collFkName = self._collections.getCollectionForeignKeyName()
854 # Check that existing datasets have the same dataset type and
855 # run.
856 query = (
857 sqlalchemy.sql.select(
858 dataset.columns.id.label("dataset_id"),
859 dataset.columns.dataset_type_id.label("dataset_type_id"),
860 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"),
861 dataset.columns[self._runKeyColumn].label("run"),
862 tmp_tags.columns[collFkName].label("new run"),
863 )
864 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id))
865 .where(
866 sqlalchemy.sql.or_(
867 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
868 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName],
869 )
870 )
871 )
872 result = self._db.query(query)
873 if (row := result.first()) is not None:
874 # Only include the first one in the exception message
875 raise ConflictingDefinitionError(
876 f"Existing dataset type or run do not match new dataset: {row._asdict()}"
877 )
879 # Check that matching dataset in tags table has the same DataId.
880 query = (
881 sqlalchemy.sql.select(
882 tags.columns.dataset_id,
883 tags.columns.dataset_type_id.label("type_id"),
884 tmp_tags.columns.dataset_type_id.label("new type_id"),
885 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
886 *[
887 tmp_tags.columns[dim].label(f"new {dim}")
888 for dim in self.datasetType.dimensions.required.names
889 ],
890 )
891 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id))
892 .where(
893 sqlalchemy.sql.or_(
894 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id,
895 *[
896 tags.columns[dim] != tmp_tags.columns[dim]
897 for dim in self.datasetType.dimensions.required.names
898 ],
899 )
900 )
901 )
902 result = self._db.query(query)
903 if (row := result.first()) is not None:
904 # Only include the first one in the exception message
905 raise ConflictingDefinitionError(
906 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}"
907 )
909 # Check that matching run+dataId have the same dataset ID.
910 query = (
911 sqlalchemy.sql.select(
912 tags.columns.dataset_type_id.label("dataset_type_id"),
913 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names],
914 tags.columns.dataset_id,
915 tmp_tags.columns.dataset_id.label("new dataset_id"),
916 tags.columns[collFkName],
917 tmp_tags.columns[collFkName].label(f"new {collFkName}"),
918 )
919 .select_from(
920 tags.join(
921 tmp_tags,
922 sqlalchemy.sql.and_(
923 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id,
924 tags.columns[collFkName] == tmp_tags.columns[collFkName],
925 *[
926 tags.columns[dim] == tmp_tags.columns[dim]
927 for dim in self.datasetType.dimensions.required.names
928 ],
929 ),
930 )
931 )
932 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id)
933 )
934 result = self._db.query(query)
935 if (row := result.first()) is not None:
936 # Only include the first one in the exception message
937 raise ConflictingDefinitionError(
938 f"Existing dataset type and dataId do not match the new dataset: {row._asdict()}"
939 )
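
# --- Illustrative sketch (not part of _storage.py) ---------------------------
# `_validateImport` runs diagnostic joins between the temporary table and the
# existing tables and raises ConflictingDefinitionError on the first row that
# disagrees. The shape of such a check, reduced to a single attribute on toy
# tables in an in-memory SQLite database (names are made up):
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
existing = sqlalchemy.Table(
    "existing_sketch",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("run", sqlalchemy.String),
)
incoming = sqlalchemy.Table(
    "incoming_sketch",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("run", sqlalchemy.String),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(existing.insert(), [{"dataset_id": 1, "run": "run/a"}])
    conn.execute(incoming.insert(), [{"dataset_id": 1, "run": "run/b"}])  # conflicting run
    query = (
        sqlalchemy.select(
            existing.c.dataset_id,
            existing.c.run,
            incoming.c.run.label("new run"),
        )
        .select_from(
            existing.join(incoming, existing.c.dataset_id == incoming.c.dataset_id)
        )
        .where(existing.c.run != incoming.c.run)
    )
    if (row := conn.execute(query).first()) is not None:
        # Only the first conflict is reported, mirroring the behaviour above.
        print(f"conflict: {row._asdict()}")
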