Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 94%

237 statements  

coverage.py v6.5.0, created at 2023-03-23 02:06 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22 

23from __future__ import annotations 

24 

25__all__ = ("ByDimensionsDatasetRecordStorage",) 

26 

27import uuid 

28from collections.abc import Iterable, Iterator, Sequence, Set 

29from typing import TYPE_CHECKING 

30 

31import sqlalchemy 

32from lsst.daf.relation import Relation, sql 

33 

34from ....core import ( 

35 DataCoordinate, 

36 DatasetColumnTag, 

37 DatasetId, 

38 DatasetRef, 

39 DatasetType, 

40 DimensionKeyColumnTag, 

41 LogicalColumn, 

42 Timespan, 

43 ddl, 

44) 

45from ..._collection_summary import CollectionSummary 

46from ..._collectionType import CollectionType 

47from ..._exceptions import CollectionTypeError, ConflictingDefinitionError 

48from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage 

49from ...queries import SqlQueryContext 

50from .tables import makeTagTableSpec 

51 

52if TYPE_CHECKING:    [52 ↛ 53] line 52 didn't jump to line 53, because the condition on line 52 was never true

53 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

54 from .summaries import CollectionSummaryManager 

55 from .tables import StaticDatasetTablesTuple 

56 

57 

58class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

59 """Dataset record storage implementation paired with 

60 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more 

61 information. 

62 

63 Instances of this class should never be constructed directly; use 

64 `DatasetRecordStorageManager.register` instead. 

65 """ 

66 

67 def __init__( 

68 self, 

69 *, 

70 datasetType: DatasetType, 

71 db: Database, 

72 dataset_type_id: int, 

73 collections: CollectionManager, 

74 static: StaticDatasetTablesTuple, 

75 summaries: CollectionSummaryManager, 

76 tags: sqlalchemy.schema.Table, 

77 calibs: sqlalchemy.schema.Table | None, 

78 ): 

79 super().__init__(datasetType=datasetType) 

80 self._dataset_type_id = dataset_type_id 

81 self._db = db 

82 self._collections = collections 

83 self._static = static 

84 self._summaries = summaries 

85 self._tags = tags 

86 self._calibs = calibs 

87 self._runKeyColumn = collections.getRunForeignKeyName() 

88 

89 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

90 # Docstring inherited from DatasetRecordStorage. 

91 # Only delete from the common dataset table; ON DELETE foreign key clauses

92 # will handle the rest. 

93 self._db.delete( 

94 self._static.dataset, 

95 ["id"], 

96 *[{"id": dataset.getCheckedId()} for dataset in datasets], 

97 ) 

98 

99 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

100 # Docstring inherited from DatasetRecordStorage. 

101 if collection.type is not CollectionType.TAGGED:    [101 ↛ 102] line 101 didn't jump to line 102, because the condition on line 101 was never true

102 raise TypeError( 

103 f"Cannot associate into collection '{collection.name}' " 

104 f"of type {collection.type.name}; must be TAGGED." 

105 ) 

106 protoRow = { 

107 self._collections.getCollectionForeignKeyName(): collection.key, 

108 "dataset_type_id": self._dataset_type_id, 

109 } 

110 rows = [] 

111 summary = CollectionSummary() 

112 for dataset in summary.add_datasets_generator(datasets): 

113 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

114 for dimension, value in dataset.dataId.items(): 

115 row[dimension.name] = value 

116 rows.append(row) 

117 # Update the summary tables for this collection in case this is the 

118 # first time this dataset type or these governor values will be 

119 # inserted there. 

120 self._summaries.update(collection, [self._dataset_type_id], summary) 

121 # Update the tag table itself. 

122 self._db.replace(self._tags, *rows) 

123 

124 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

125 # Docstring inherited from DatasetRecordStorage. 

126 if collection.type is not CollectionType.TAGGED:    [126 ↛ 127] line 126 didn't jump to line 127, because the condition on line 126 was never true

127 raise TypeError( 

128 f"Cannot disassociate from collection '{collection.name}' " 

129 f"of type {collection.type.name}; must be TAGGED." 

130 ) 

131 rows = [ 

132 { 

133 "dataset_id": dataset.getCheckedId(), 

134 self._collections.getCollectionForeignKeyName(): collection.key, 

135 } 

136 for dataset in datasets 

137 ] 

138 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

139 

140 def _buildCalibOverlapQuery( 

141 self, 

142 collection: CollectionRecord, 

143 data_ids: set[DataCoordinate] | None, 

144 timespan: Timespan, 

145 context: SqlQueryContext, 

146 ) -> Relation: 
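"""Build a relation over this dataset type's calib-association rows in
``collection`` whose validity ranges overlap ``timespan``, optionally
restricted to the given data IDs; used by `certify` and `decertify` to
find rows that would conflict with or be affected by an edit.
"""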

147 relation = self.make_relation( 

148 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context 

149 ).with_rows_satisfying( 

150 context.make_timespan_overlap_predicate( 

151 DatasetColumnTag(self.datasetType.name, "timespan"), timespan 

152 ), 

153 ) 

154 if data_ids is not None: 

155 relation = relation.join( 

156 context.make_data_id_relation( 

157 data_ids, self.datasetType.dimensions.required.names 

158 ).transferred_to(context.sql_engine), 

159 ) 

160 return relation 

161 

162 def certify( 

163 self, 

164 collection: CollectionRecord, 

165 datasets: Iterable[DatasetRef], 

166 timespan: Timespan, 

167 context: SqlQueryContext, 

168 ) -> None: 

169 # Docstring inherited from DatasetRecordStorage. 

170 if self._calibs is None:    [170 ↛ 171] line 170 didn't jump to line 171, because the condition on line 170 was never true

171 raise CollectionTypeError( 

172 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

173 "DatasetType.isCalibration() is False." 

174 ) 

175 if collection.type is not CollectionType.CALIBRATION:    [175 ↛ 176] line 175 didn't jump to line 176, because the condition on line 175 was never true

176 raise CollectionTypeError( 

177 f"Cannot certify into collection '{collection.name}' " 

178 f"of type {collection.type.name}; must be CALIBRATION." 

179 ) 

180 TimespanReprClass = self._db.getTimespanRepresentation() 

181 protoRow = { 

182 self._collections.getCollectionForeignKeyName(): collection.key, 

183 "dataset_type_id": self._dataset_type_id, 

184 } 

185 rows = [] 

186 dataIds: set[DataCoordinate] | None = ( 

187 set() if not TimespanReprClass.hasExclusionConstraint() else None 

188 ) 

189 summary = CollectionSummary() 

190 for dataset in summary.add_datasets_generator(datasets): 

191 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

192 for dimension, value in dataset.dataId.items(): 

193 row[dimension.name] = value 

194 TimespanReprClass.update(timespan, result=row) 

195 rows.append(row) 

196 if dataIds is not None:    [196 ↛ 190] line 196 didn't jump to line 190, because the condition on line 196 was never false

197 dataIds.add(dataset.dataId) 

198 # Update the summary tables for this collection in case this is the 

199 # first time this dataset type or these governor values will be 

200 # inserted there. 

201 self._summaries.update(collection, [self._dataset_type_id], summary) 

202 # Update the association table itself. 

203 if TimespanReprClass.hasExclusionConstraint():    [203 ↛ 206] line 203 didn't jump to line 206, because the condition on line 203 was never true

204 # Rely on database constraint to enforce invariants; we just 

205 # reraise the exception for consistency across DB engines. 

206 try: 

207 self._db.insert(self._calibs, *rows) 

208 except sqlalchemy.exc.IntegrityError as err: 

209 raise ConflictingDefinitionError( 

210 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

211 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

212 ) from err 

213 else: 

214 # Have to implement exclusion constraint ourselves. 

215 # Start by building a SELECT query for any rows that would overlap 

216 # this one. 

217 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context) 

218 # Acquire a table lock to ensure that no concurrent writes

219 # could invalidate our checking before we finish the inserts. We 

220 # use a SAVEPOINT in case there is an outer transaction that a 

221 # failure here should not roll back. 

222 with self._db.transaction(lock=[self._calibs], savepoint=True): 

223 # Enter SqlQueryContext in case we need to use a temporary 

224 # table to include the given data IDs in the query. Note that

225 # by doing this inside the transaction, we make sure it doesn't 

226 # attempt to close the session when it's done, since it just

227 # sees an already-open session that it knows it shouldn't 

228 # manage. 

229 with context: 

230 # Run the check SELECT query. 

231 conflicting = context.count(context.process(relation)) 

232 if conflicting > 0: 

233 raise ConflictingDefinitionError( 

234 f"{conflicting} validity range conflicts certifying datasets of type " 

235 f"{self.datasetType.name} into {collection.name} for range " 

236 f"[{timespan.begin}, {timespan.end})." 

237 ) 

238 # Proceed with the insert. 

239 self._db.insert(self._calibs, *rows) 

240 

241 def decertify( 

242 self, 

243 collection: CollectionRecord, 

244 timespan: Timespan, 

245 *, 

246 dataIds: Iterable[DataCoordinate] | None = None, 

247 context: SqlQueryContext, 

248 ) -> None: 

249 # Docstring inherited from DatasetRecordStorage. 

250 if self._calibs is None:    [250 ↛ 251] line 250 didn't jump to line 251, because the condition on line 250 was never true

251 raise CollectionTypeError( 

252 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

253 "DatasetType.isCalibration() is False." 

254 ) 

255 if collection.type is not CollectionType.CALIBRATION:    [255 ↛ 256] line 255 didn't jump to line 256, because the condition on line 255 was never true

256 raise CollectionTypeError( 

257 f"Cannot decertify from collection '{collection.name}' " 

258 f"of type {collection.type.name}; must be CALIBRATION." 

259 ) 

260 TimespanReprClass = self._db.getTimespanRepresentation() 

261 # Construct a SELECT query to find all rows that overlap our inputs. 

262 dataIdSet: set[DataCoordinate] | None 

263 if dataIds is not None: 

264 dataIdSet = set(dataIds) 

265 else: 

266 dataIdSet = None 

267 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context) 

268 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey") 

269 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id") 

270 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan") 

271 data_id_tags = [ 

272 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names 

273 ] 

274 # Set up collections to populate with the rows we'll want to modify. 

275 # The insert rows will have the same values for collection and 

276 # dataset type. 

277 protoInsertRow = { 

278 self._collections.getCollectionForeignKeyName(): collection.key, 

279 "dataset_type_id": self._dataset_type_id, 

280 } 

281 rowsToDelete = [] 

282 rowsToInsert = [] 

283 # Acquire a table lock to ensure there are no concurrent writes 

284 # between the SELECT and the DELETE and INSERT queries based on it. 

285 with self._db.transaction(lock=[self._calibs], savepoint=True): 

286 # Enter SqlQueryContext in case we need to use a temporary table to 

287 # include the given data IDs in the query (see similar block in

288 # certify for details). 

289 with context: 

290 for row in context.fetch_iterable(relation): 

291 rowsToDelete.append({"id": row[calib_pkey_tag]}) 

292 # Construct the insert row(s) by copying the prototype row, 

293 # then adding the dimension column values, then adding 

294 # what's left of the timespan from that row after we 

295 # subtract the given timespan. 
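# As an illustrative sketch (made-up bounds, not from the data): if a row
# was certified over [t1, t4) and we decertify [t2, t3), then
# rowTimespan.difference(timespan) yields [t1, t2) and [t3, t4), and one
# replacement row is inserted for each remaining piece.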

296 newInsertRow = protoInsertRow.copy() 

297 newInsertRow["dataset_id"] = row[dataset_id_tag] 

298 for name, tag in data_id_tags: 

299 newInsertRow[name] = row[tag] 

300 rowTimespan = row[timespan_tag] 

301 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

302 for diffTimespan in rowTimespan.difference(timespan): 

303 rowsToInsert.append( 

304 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()) 

305 ) 

306 # Run the DELETE and INSERT queries. 

307 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

308 self._db.insert(self._calibs, *rowsToInsert) 

309 

310 def make_relation( 

311 self, 

312 *collections: CollectionRecord, 

313 columns: Set[str], 

314 context: SqlQueryContext, 

315 ) -> Relation: 

316 # Docstring inherited from DatasetRecordStorage. 

317 collection_types = {collection.type for collection in collections} 

318 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

319 TimespanReprClass = self._db.getTimespanRepresentation() 

320 # 

321 # There are two kinds of table in play here: 

322 # 

323 # - the static dataset table (with the dataset ID, dataset type ID, 

324 # run ID/name, and ingest date); 

325 # 

326 # - the dynamic tags/calibs table (with the dataset ID, dataset

327 # type ID, collection ID/name, data ID, and possibly validity 

328 # range). 

329 # 

330 # That means that we might want to return a query against either table 

331 # or a JOIN of both, depending on which quantities the caller wants. 

332 # But the data ID is always included, which means we'll always include 

333 # the tags/calibs table and join in the static dataset table only if we 

334 # need things from it that we can't get from the tags/calibs table. 

335 # 

336 # Note that it's important that we include a WHERE constraint on both 

337 # tables for any column (e.g. dataset_type_id) that is in both when 

338 # it's given explicitly; not doing so can prevent the query planner from

339 # using very important indexes. At present, we don't include those 

340 # redundant columns in the JOIN ON expression, however, because the 

341 # FOREIGN KEY (and its index) are defined only on dataset_id. 
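# As a rough sketch (not the literal SQL emitted), a single-collection
# query against the tags table therefore has the shape:
#
#   SELECT <required dimension columns>, dataset_id, ...
#   FROM <dataset_type>_tags
#   [JOIN dataset ON dataset.id = <dataset_type>_tags.dataset_id]
#   WHERE <dataset_type>_tags.dataset_type_id = :dataset_type_id
#     AND <dataset_type>_tags.<collection FK> = :collection_key
#     [AND dataset.dataset_type_id = :dataset_type_id]
#
# with the bracketed pieces added only when the static dataset table is
# needed (see _finish_single_relation below).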

342 tag_relation: Relation | None = None 

343 calib_relation: Relation | None = None 

344 if collection_types != {CollectionType.CALIBRATION}: 

345 # We'll need a subquery for the tags table if any of the given 

346 # collections are not a CALIBRATION collection. This intentionally 

347 # also fires when the list of collections is empty as a way to 

348 # create a dummy subquery that we know will fail. 

349 # We give the table an alias because it might appear multiple times 

350 # in the same query, for different dataset types. 

351 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags")) 

352 if "timespan" in columns: 

353 tags_parts.columns_available[ 

354 DatasetColumnTag(self.datasetType.name, "timespan") 

355 ] = TimespanReprClass.fromLiteral(Timespan(None, None)) 

356 tag_relation = self._finish_single_relation( 

357 tags_parts, 

358 columns, 

359 [ 

360 (record, rank) 

361 for rank, record in enumerate(collections) 

362 if record.type is not CollectionType.CALIBRATION 

363 ], 

364 context, 

365 ) 

366 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries." 

367 if CollectionType.CALIBRATION in collection_types: 

368 # If at least one collection is a CALIBRATION collection, we'll 

369 # need a subquery for the calibs table, and could include the 

370 # timespan as a result or constraint. 

371 assert ( 

372 self._calibs is not None 

373 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

374 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs")) 

375 if "timespan" in columns: 

376 calibs_parts.columns_available[ 

377 DatasetColumnTag(self.datasetType.name, "timespan") 

378 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns) 

379 if "calib_pkey" in columns: 

380 # This is a private extension not included in the base class 

381 # interface, for internal use only in _buildCalibOverlapQuery, 

382 # which needs access to the autoincrement primary key for the 

383 # calib association table. 

384 calibs_parts.columns_available[ 

385 DatasetColumnTag(self.datasetType.name, "calib_pkey") 

386 ] = calibs_parts.from_clause.columns.id 

387 calib_relation = self._finish_single_relation( 

388 calibs_parts, 

389 columns, 

390 [ 

391 (record, rank) 

392 for rank, record in enumerate(collections) 

393 if record.type is CollectionType.CALIBRATION 

394 ], 

395 context, 

396 ) 

397 if tag_relation is not None: 

398 if calib_relation is not None: 

399 # daf_relation's chain operation does not automatically 

400 # deduplicate; it's more like SQL's UNION ALL. To get UNION 

401 # in SQL here, we add an explicit deduplication. 
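# In other words, the chained relation renders roughly as
#   SELECT ... FROM <tags subquery> UNION ALL SELECT ... FROM <calibs subquery>
# and the explicit deduplication makes it behave like a plain UNION.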

402 return tag_relation.chain(calib_relation).without_duplicates() 

403 else: 

404 return tag_relation 

405 elif calib_relation is not None:    [405 ↛ 408] line 405 didn't jump to line 408, because the condition on line 405 was never false

406 return calib_relation 

407 else: 

408 raise AssertionError("Branch should be unreachable.") 

409 

410 def _finish_single_relation( 

411 self, 

412 payload: sql.Payload[LogicalColumn], 

413 requested_columns: Set[str], 

414 collections: Sequence[tuple[CollectionRecord, int]], 

415 context: SqlQueryContext, 

416 ) -> Relation: 

417 """Helper method for `make_relation`. 

418 

419 This handles adding columns and WHERE terms that are not specific to 

420 either the tags or calibs tables. 

421 

422 Parameters 

423 ---------- 

424 payload : `lsst.daf.relation.sql.Payload` 

425 SQL query parts under construction, to be modified in-place and 

426 used to construct the new relation. 

427 requested_columns : `~collections.abc.Set` [ `str` ] 

428 Columns the relation should include. 

429 collections : `Sequence` [ `tuple` [ `CollectionRecord`, `int` ] ] 

430 Collections to search for the dataset and their ranks. 

431 context : `SqlQueryContext` 

432 Context that manages engines and state for the query. 

433 

434 Returns 

435 ------- 

436 relation : `lsst.daf.relation.Relation` 

437 New dataset query relation. 

438 """ 

439 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id) 

440 dataset_id_col = payload.from_clause.columns.dataset_id 

441 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()] 

442 # We always constrain and optionally retrieve the collection(s) via the 

443 # tags/calibs table. 

444 if len(collections) == 1: 

445 payload.where.append(collection_col == collections[0][0].key) 

446 if "collection" in requested_columns: 

447 payload.columns_available[ 

448 DatasetColumnTag(self.datasetType.name, "collection") 

449 ] = sqlalchemy.sql.literal(collections[0][0].key) 

450 else: 

451 assert collections, "The no-collections case should be in calling code for better diagnostics." 

452 payload.where.append(collection_col.in_([collection.key for collection, _ in collections])) 

453 if "collection" in requested_columns: 

454 payload.columns_available[ 

455 DatasetColumnTag(self.datasetType.name, "collection") 

456 ] = collection_col 

457 # Add rank, if requested, as a CASE-based calculation on the collection

458 # column. 

459 if "rank" in requested_columns: 

460 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case( 

461 {record.key: rank for record, rank in collections}, 

462 value=collection_col, 

463 ) 
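# For example, for two collections with ranks [(A, 0), (B, 1)] the
# expression above renders roughly as
#   CASE <collection FK> WHEN :key_A THEN 0 WHEN :key_B THEN 1 END
# (a sketch; the bind values are the collections' primary keys).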

464 # Add more column definitions, starting with the data ID. 

465 for dimension_name in self.datasetType.dimensions.required.names: 

466 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[ 

467 dimension_name 

468 ] 

469 # We can always get the dataset_id from the tags/calibs table. 

470 if "dataset_id" in requested_columns: 

471 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col 

472 # It's possible we now have everything we need, from just the 

473 # tags/calibs table. The things we might need to get from the static 

474 # dataset table are the run key and the ingest date. 

475 need_static_table = False 

476 if "run" in requested_columns: 

477 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN: 

478 # If we are searching exactly one RUN collection, we 

479 # know that if we find the dataset in that collection, 

480 # then that's the dataset's run; we don't need to

481 # query for it. 

482 payload.columns_available[ 

483 DatasetColumnTag(self.datasetType.name, "run") 

484 ] = sqlalchemy.sql.literal(collections[0][0].key) 

485 else: 

486 payload.columns_available[ 

487 DatasetColumnTag(self.datasetType.name, "run") 

488 ] = self._static.dataset.columns[self._runKeyColumn] 

489 need_static_table = True 

490 # Ingest date can only come from the static table. 

491 if "ingest_date" in requested_columns: 

492 need_static_table = True 

493 payload.columns_available[ 

494 DatasetColumnTag(self.datasetType.name, "ingest_date") 

495 ] = self._static.dataset.columns.ingest_date 

496 # If we need the static table, join it in via dataset_id and 

497 # dataset_type_id.

498 if need_static_table: 

499 payload.from_clause = payload.from_clause.join( 

500 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id) 

501 ) 

502 # Also constrain dataset_type_id in static table in case that helps 

503 # generate a better plan. 

504 # We could also include this in the JOIN ON clause, but my guess is 

505 # that that's a good idea IFF it's in the foreign key, and right 

506 # now it isn't. 

507 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

508 leaf = context.sql_engine.make_leaf( 

509 payload.columns_available.keys(), 

510 payload=payload, 

511 name=self.datasetType.name, 

512 parameters={record.name: rank for record, rank in collections}, 

513 ) 

514 return leaf 

515 

516 def getDataId(self, id: DatasetId) -> DataCoordinate: 

517 """Return DataId for a dataset. 

518 

519 Parameters 

520 ---------- 

521 id : `DatasetId` 

522 Unique dataset identifier. 

523 

524 Returns 

525 ------- 

526 dataId : `DataCoordinate` 

527 DataId for the dataset. 

528 """ 

529 # This query could return multiple rows (one for each tagged collection 

530 # the dataset is in, plus one for its run collection), and we don't 

531 # care which of those we get. 

532 sql = ( 

533 self._tags.select() 

534 .where( 

535 sqlalchemy.sql.and_( 

536 self._tags.columns.dataset_id == id, 

537 self._tags.columns.dataset_type_id == self._dataset_type_id, 

538 ) 

539 ) 

540 .limit(1) 

541 ) 

542 with self._db.query(sql) as sql_result: 

543 row = sql_result.mappings().fetchone() 

544 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

545 return DataCoordinate.standardize( 

546 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required}, 

547 graph=self.datasetType.dimensions, 

548 ) 

549 

550 

551class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

552 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

553 dataset IDs. 

554 """ 

555 

556 idMaker = DatasetIdFactory() 

557 """Factory for dataset IDs. In the future this factory may be shared with 

558 other classes (e.g. Registry).""" 

559 

560 def insert( 

561 self, 

562 run: RunRecord, 

563 dataIds: Iterable[DataCoordinate], 

564 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

565 ) -> Iterator[DatasetRef]: 

566 # Docstring inherited from DatasetRecordStorage. 

567 

568 # Iterate over data IDs, transforming a possibly-single-pass iterable 

569 # into a list. 

570 dataIdList = [] 

571 rows = [] 

572 summary = CollectionSummary() 

573 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

574 dataIdList.append(dataId) 

575 rows.append( 

576 { 

577 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

578 "dataset_type_id": self._dataset_type_id, 

579 self._runKeyColumn: run.key, 

580 } 

581 ) 

582 

583 with self._db.transaction(): 

584 # Insert into the static dataset table. 

585 self._db.insert(self._static.dataset, *rows) 

586 # Update the summary tables for this collection in case this is the 

587 # first time this dataset type or these governor values will be 

588 # inserted there. 

589 self._summaries.update(run, [self._dataset_type_id], summary) 

590 # Combine the generated dataset_id values and data ID fields to 

591 # form rows to be inserted into the tags table. 

592 protoTagsRow = { 

593 "dataset_type_id": self._dataset_type_id, 

594 self._collections.getCollectionForeignKeyName(): run.key, 

595 } 

596 tagsRows = [ 

597 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName()) 

598 for dataId, row in zip(dataIdList, rows) 

599 ] 
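# Each tags row therefore looks roughly like (illustrative values only):
#   {"dataset_type_id": <id>, <collection FK name>: <run key>,
#    "dataset_id": <uuid>, "<dimension name>": <dimension value>, ...}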

600 # Insert those rows into the tags table. 

601 self._db.insert(self._tags, *tagsRows) 

602 

603 for dataId, row in zip(dataIdList, rows): 

604 yield DatasetRef( 

605 datasetType=self.datasetType, 

606 dataId=dataId, 

607 id=row["id"], 

608 run=run.name, 

609 ) 

610 

611 def import_( 

612 self, 

613 run: RunRecord, 

614 datasets: Iterable[DatasetRef], 

615 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

616 reuseIds: bool = False, 

617 ) -> Iterator[DatasetRef]: 

618 # Docstring inherited from DatasetRecordStorage. 

619 

620 # Iterate over data IDs, transforming a possibly-single-pass iterable 

621 # into a list. 

622 dataIds = {} 

623 summary = CollectionSummary() 

624 for dataset in summary.add_datasets_generator(datasets): 

625 # Ignore unknown ID types; normally all IDs have the same type, but

626 # this code supports mixed types or missing IDs. 

627 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None 

628 if datasetId is None: 

629 datasetId = self.idMaker.makeDatasetId( 

630 run.name, self.datasetType, dataset.dataId, idGenerationMode 

631 ) 

632 dataIds[datasetId] = dataset.dataId 

633 

634 # We'll insert all new rows into a temporary table.

635 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False) 

636 collFkName = self._collections.getCollectionForeignKeyName() 

637 protoTagsRow = { 

638 "dataset_type_id": self._dataset_type_id, 

639 collFkName: run.key, 

640 } 

641 tmpRows = [ 

642 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

643 for dataset_id, dataId in dataIds.items() 

644 ] 

645 with self._db.transaction(for_temp_tables=True): 

646 with self._db.temporary_table(tableSpec) as tmp_tags: 

647 # store all incoming data in a temporary table 

648 self._db.insert(tmp_tags, *tmpRows) 

649 

650 # There are some checks that we want to make for consistency 

651 # of the new datasets with existing ones. 

652 self._validateImport(tmp_tags, run) 

653 

654 # Before we merge the temporary table into dataset/tags we need to

655 # drop datasets that are already there (and do not conflict).

656 self._db.deleteWhere( 

657 tmp_tags, 

658 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

659 ) 

660 

661 # Copy it into the dataset table; we need to re-label some columns.

662 self._db.insert( 

663 self._static.dataset, 

664 select=sqlalchemy.sql.select( 

665 tmp_tags.columns.dataset_id.label("id"), 

666 tmp_tags.columns.dataset_type_id, 

667 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

668 ), 

669 ) 

670 

671 # Update the summary tables for this collection in case this 

672 # is the first time this dataset type or these governor values 

673 # will be inserted there. 

674 self._summaries.update(run, [self._dataset_type_id], summary) 

675 

676 # Copy it into the tags table.

677 self._db.insert(self._tags, select=tmp_tags.select()) 

678 

679 # Return refs in the same order as in the input list. 

680 for dataset_id, dataId in dataIds.items(): 

681 yield DatasetRef( 

682 datasetType=self.datasetType, 

683 id=dataset_id, 

684 dataId=dataId, 

685 run=run.name, 

686 ) 

687 

688 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

689 """Validate imported refs against existing datasets. 

690 

691 Parameters 

692 ---------- 

693 tmp_tags : `sqlalchemy.schema.Table` 

694 Temporary table with new datasets and the same schema as the tags

695 table. 

696 run : `RunRecord` 

697 The record object describing the `~CollectionType.RUN` collection. 

698 

699 Raises 

700 ------ 

701 ConflictingDefinitionError 

702 Raised if new datasets conflict with existing ones.

703 """ 

704 dataset = self._static.dataset 

705 tags = self._tags 

706 collFkName = self._collections.getCollectionForeignKeyName() 

707 

708 # Check that existing datasets have the same dataset type and 

709 # run. 

710 query = ( 

711 sqlalchemy.sql.select( 

712 dataset.columns.id.label("dataset_id"), 

713 dataset.columns.dataset_type_id.label("dataset_type_id"), 

714 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), 

715 dataset.columns[self._runKeyColumn].label("run"), 

716 tmp_tags.columns[collFkName].label("new run"), 

717 ) 

718 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

719 .where( 

720 sqlalchemy.sql.or_( 

721 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

722 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

723 ) 

724 ) 

725 .limit(1) 

726 ) 

727 with self._db.query(query) as result: 

728 if (row := result.first()) is not None: 

729 # Only include the first one in the exception message 

730 raise ConflictingDefinitionError( 

731 f"Existing dataset type or run do not match new dataset: {row._asdict()}" 

732 ) 

733 

734 # Check that matching dataset in tags table has the same DataId. 

735 query = ( 

736 sqlalchemy.sql.select( 

737 tags.columns.dataset_id, 

738 tags.columns.dataset_type_id.label("type_id"), 

739 tmp_tags.columns.dataset_type_id.label("new type_id"), 

740 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

741 *[ 

742 tmp_tags.columns[dim].label(f"new {dim}") 

743 for dim in self.datasetType.dimensions.required.names 

744 ], 

745 ) 

746 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

747 .where( 

748 sqlalchemy.sql.or_( 

749 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

750 *[ 

751 tags.columns[dim] != tmp_tags.columns[dim] 

752 for dim in self.datasetType.dimensions.required.names 

753 ], 

754 ) 

755 ) 

756 .limit(1) 

757 ) 

758 

759 with self._db.query(query) as result: 

760 if (row := result.first()) is not None: 

761 # Only include the first one in the exception message 

762 raise ConflictingDefinitionError( 

763 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

764 ) 

765 

766 # Check that matching run+dataId have the same dataset ID. 

767 query = ( 

768 sqlalchemy.sql.select( 

769 tags.columns.dataset_type_id.label("dataset_type_id"), 

770 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

771 tags.columns.dataset_id, 

772 tmp_tags.columns.dataset_id.label("new dataset_id"), 

773 tags.columns[collFkName], 

774 tmp_tags.columns[collFkName].label(f"new {collFkName}"), 

775 ) 

776 .select_from( 

777 tags.join( 

778 tmp_tags, 

779 sqlalchemy.sql.and_( 

780 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

781 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

782 *[ 

783 tags.columns[dim] == tmp_tags.columns[dim] 

784 for dim in self.datasetType.dimensions.required.names 

785 ], 

786 ), 

787 ) 

788 ) 

789 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

790 .limit(1) 

791 ) 

792 with self._db.query(query) as result: 

793 if (row := result.first()) is not None: 

794 # Only include the first one in the exception message

795 raise ConflictingDefinitionError( 

796 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}" 

797 )