Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%

232 statements  

coverage.py v6.5.0, created at 2023-04-04 02:05 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22 

23from __future__ import annotations 

24 

25__all__ = ("ByDimensionsDatasetRecordStorage",) 

26 

27import uuid 

28from collections.abc import Iterable, Iterator, Sequence, Set 

29from typing import TYPE_CHECKING 

30 

31import sqlalchemy 

32from lsst.daf.relation import Relation, sql 

33 

34from ....core import ( 

35 DataCoordinate, 

36 DatasetColumnTag, 

37 DatasetId, 

38 DatasetRef, 

39 DatasetType, 

40 DimensionKeyColumnTag, 

41 LogicalColumn, 

42 Timespan, 

43 ddl, 

44) 

45from ..._collection_summary import CollectionSummary 

46from ..._collectionType import CollectionType 

47from ..._exceptions import CollectionTypeError, ConflictingDefinitionError 

48from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage 

49from ...queries import SqlQueryContext 

50from .tables import makeTagTableSpec 

51 

52if TYPE_CHECKING: 

53 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

54 from .summaries import CollectionSummaryManager 

55 from .tables import StaticDatasetTablesTuple 

56 

57 

58class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

59 """Dataset record storage implementation paired with 

60 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more 

61 information. 

62 

63 Instances of this class should never be constructed directly; use 

64 `DatasetRecordStorageManager.register` instead. 

65 """ 

66 

67 def __init__( 

68 self, 

69 *, 

70 datasetType: DatasetType, 

71 db: Database, 

72 dataset_type_id: int, 

73 collections: CollectionManager, 

74 static: StaticDatasetTablesTuple, 

75 summaries: CollectionSummaryManager, 

76 tags: sqlalchemy.schema.Table, 

77 calibs: sqlalchemy.schema.Table | None, 

78 ): 

79 super().__init__(datasetType=datasetType) 

80 self._dataset_type_id = dataset_type_id 

81 self._db = db 

82 self._collections = collections 

83 self._static = static 

84 self._summaries = summaries 

85 self._tags = tags 

86 self._calibs = calibs 

87 self._runKeyColumn = collections.getRunForeignKeyName() 

88 

89 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

90 # Docstring inherited from DatasetRecordStorage. 

91 # Only delete from the common dataset table; ON DELETE foreign key clauses 

92 # will handle the rest. 
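# In SQL terms this is roughly (sketch): DELETE FROM dataset WHERE id = :id,
# executed for each given dataset ID; rows referencing those IDs in the
# tags/calibs tables are then removed by their ON DELETE foreign key clauses.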

93 self._db.delete( 

94 self._static.dataset, 

95 ["id"], 

96 *[{"id": dataset.getCheckedId()} for dataset in datasets], 

97 ) 

98 

99 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

100 # Docstring inherited from DatasetRecordStorage. 

101 if collection.type is not CollectionType.TAGGED:    101 ↛ 102 (line 101 didn't jump to line 102, because the condition on line 101 was never true)

102 raise TypeError( 

103 f"Cannot associate into collection '{collection.name}' " 

104 f"of type {collection.type.name}; must be TAGGED." 

105 ) 

106 protoRow = { 

107 self._collections.getCollectionForeignKeyName(): collection.key, 

108 "dataset_type_id": self._dataset_type_id, 

109 } 

110 rows = [] 

111 summary = CollectionSummary() 

112 for dataset in summary.add_datasets_generator(datasets): 

113 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

114 for dimension, value in dataset.dataId.items(): 

115 row[dimension.name] = value 

116 rows.append(row) 

117 # Update the summary tables for this collection in case this is the 

118 # first time this dataset type or these governor values will be 

119 # inserted there. 

120 self._summaries.update(collection, [self._dataset_type_id], summary) 

121 # Update the tag table itself. 

122 self._db.replace(self._tags, *rows) 

123 

124 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

125 # Docstring inherited from DatasetRecordStorage. 

126 if collection.type is not CollectionType.TAGGED:    126 ↛ 127 (line 126 didn't jump to line 127, because the condition on line 126 was never true)

127 raise TypeError( 

128 f"Cannot disassociate from collection '{collection.name}' " 

129 f"of type {collection.type.name}; must be TAGGED." 

130 ) 

131 rows = [ 

132 { 

133 "dataset_id": dataset.getCheckedId(), 

134 self._collections.getCollectionForeignKeyName(): collection.key, 

135 } 

136 for dataset in datasets 

137 ] 

138 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

139 

140 def _buildCalibOverlapQuery( 

141 self, 

142 collection: CollectionRecord, 

143 data_ids: set[DataCoordinate] | None, 

144 timespan: Timespan, 

145 context: SqlQueryContext, 

146 ) -> Relation: 

147 relation = self.make_relation( 

148 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context 

149 ).with_rows_satisfying( 

150 context.make_timespan_overlap_predicate( 

151 DatasetColumnTag(self.datasetType.name, "timespan"), timespan 

152 ), 

153 ) 

154 if data_ids is not None: 

155 relation = relation.join( 

156 context.make_data_id_relation( 

157 data_ids, self.datasetType.dimensions.required.names 

158 ).transferred_to(context.sql_engine), 

159 ) 

160 return relation 

161 

162 def certify( 

163 self, 

164 collection: CollectionRecord, 

165 datasets: Iterable[DatasetRef], 

166 timespan: Timespan, 

167 context: SqlQueryContext, 

168 ) -> None: 

169 # Docstring inherited from DatasetRecordStorage. 

170 if self._calibs is None:    170 ↛ 171 (line 170 didn't jump to line 171, because the condition on line 170 was never true)

171 raise CollectionTypeError( 

172 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

173 "DatasetType.isCalibration() is False." 

174 ) 

175 if collection.type is not CollectionType.CALIBRATION:    175 ↛ 176 (line 175 didn't jump to line 176, because the condition on line 175 was never true)

176 raise CollectionTypeError( 

177 f"Cannot certify into collection '{collection.name}' " 

178 f"of type {collection.type.name}; must be CALIBRATION." 

179 ) 

180 TimespanReprClass = self._db.getTimespanRepresentation() 

181 protoRow = { 

182 self._collections.getCollectionForeignKeyName(): collection.key, 

183 "dataset_type_id": self._dataset_type_id, 

184 } 

185 rows = [] 

186 dataIds: set[DataCoordinate] | None = ( 

187 set() if not TimespanReprClass.hasExclusionConstraint() else None 

188 ) 

189 summary = CollectionSummary() 

190 for dataset in summary.add_datasets_generator(datasets): 

191 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

192 for dimension, value in dataset.dataId.items(): 

193 row[dimension.name] = value 

194 TimespanReprClass.update(timespan, result=row) 

195 rows.append(row) 

196 if dataIds is not None:    196 ↛ 190 (line 196 didn't jump to line 190, because the condition on line 196 was never false)

197 dataIds.add(dataset.dataId) 

198 # Update the summary tables for this collection in case this is the 

199 # first time this dataset type or these governor values will be 

200 # inserted there. 

201 self._summaries.update(collection, [self._dataset_type_id], summary) 

202 # Update the association table itself. 

203 if TimespanReprClass.hasExclusionConstraint():    203 ↛ 206 (line 203 didn't jump to line 206, because the condition on line 203 was never true)

204 # Rely on database constraint to enforce invariants; we just 

205 # reraise the exception for consistency across DB engines. 

206 try: 

207 self._db.insert(self._calibs, *rows) 

208 except sqlalchemy.exc.IntegrityError as err: 

209 raise ConflictingDefinitionError( 

210 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

211 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

212 ) from err 

213 else: 

214 # Have to implement exclusion constraint ourselves. 

215 # Start by building a SELECT query for any rows that would overlap 

216 # this one. 
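# Conceptually the check is (a sketch, not the SQL actually emitted):
#   SELECT COUNT(*) FROM <type>_calibs
#   WHERE <collection_fk> = :collection AND dataset_type_id = :dtype
#     AND <stored timespan overlaps :timespan>
#     AND <data ID is one of the given data IDs, if any were given>
# where half-open timespans [a, b) and [c, d) overlap iff a < d and c < b.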

217 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context) 

218 # Acquire a table lock to ensure there are no concurrent writes 

219 # that could invalidate our checking before we finish the inserts. We 

220 # use a SAVEPOINT in case there is an outer transaction that a 

221 # failure here should not roll back. 

222 with self._db.transaction(lock=[self._calibs], savepoint=True): 

223 # Enter SqlQueryContext in case we need to use a temporary 

224 # table to include the given data IDs in the query. Note that 

225 # by doing this inside the transaction, we make sure it doesn't 

226 # attempt to close the session when it's done, since it just 

227 # sees an already-open session that it knows it shouldn't 

228 # manage. 

229 with context: 

230 # Run the check SELECT query. 

231 conflicting = context.count(context.process(relation)) 

232 if conflicting > 0: 

233 raise ConflictingDefinitionError( 

234 f"{conflicting} validity range conflicts certifying datasets of type " 

235 f"{self.datasetType.name} into {collection.name} for range " 

236 f"[{timespan.begin}, {timespan.end})." 

237 ) 

238 # Proceed with the insert. 

239 self._db.insert(self._calibs, *rows) 

240 

241 def decertify( 

242 self, 

243 collection: CollectionRecord, 

244 timespan: Timespan, 

245 *, 

246 dataIds: Iterable[DataCoordinate] | None = None, 

247 context: SqlQueryContext, 

248 ) -> None: 

249 # Docstring inherited from DatasetRecordStorage. 

250 if self._calibs is None:    250 ↛ 251 (line 250 didn't jump to line 251, because the condition on line 250 was never true)

251 raise CollectionTypeError( 

252 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

253 "DatasetType.isCalibration() is False." 

254 ) 

255 if collection.type is not CollectionType.CALIBRATION:    255 ↛ 256 (line 255 didn't jump to line 256, because the condition on line 255 was never true)

256 raise CollectionTypeError( 

257 f"Cannot decertify from collection '{collection.name}' " 

258 f"of type {collection.type.name}; must be CALIBRATION." 

259 ) 

260 TimespanReprClass = self._db.getTimespanRepresentation() 

261 # Construct a SELECT query to find all rows that overlap our inputs. 

262 dataIdSet: set[DataCoordinate] | None 

263 if dataIds is not None: 

264 dataIdSet = set(dataIds) 

265 else: 

266 dataIdSet = None 

267 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context) 

268 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey") 

269 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id") 

270 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan") 

271 data_id_tags = [ 

272 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names 

273 ] 

274 # Set up collections to populate with the rows we'll want to modify. 

275 # The insert rows will have the same values for collection and 

276 # dataset type. 

277 protoInsertRow = { 

278 self._collections.getCollectionForeignKeyName(): collection.key, 

279 "dataset_type_id": self._dataset_type_id, 

280 } 

281 rowsToDelete = [] 

282 rowsToInsert = [] 

283 # Acquire a table lock to ensure there are no concurrent writes 

284 # between the SELECT and the DELETE and INSERT queries based on it. 

285 with self._db.transaction(lock=[self._calibs], savepoint=True): 

286 # Enter SqlQueryContext in case we need to use a temporary table to 

287 # include the given data IDs in the query (see the similar block in 

288 # certify for details). 

289 with context: 

290 for row in context.fetch_iterable(relation): 

291 rowsToDelete.append({"id": row[calib_pkey_tag]}) 

292 # Construct the insert row(s) by copying the prototype row, 

293 # then adding the dimension column values, then adding 

294 # what's left of the timespan from that row after we 

295 # subtract the given timespan. 
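# For example (illustrative values only): if a stored row covers
# [t0, t4) and we decertify [t1, t3), Timespan.difference yields the
# remainders [t0, t1) and [t3, t4), each of which becomes a new row
# with the same dataset_id and data ID values.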

296 newInsertRow = protoInsertRow.copy() 

297 newInsertRow["dataset_id"] = row[dataset_id_tag] 

298 for name, tag in data_id_tags: 

299 newInsertRow[name] = row[tag] 

300 rowTimespan = row[timespan_tag] 

301 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

302 for diffTimespan in rowTimespan.difference(timespan): 

303 rowsToInsert.append( 

304 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()) 

305 ) 

306 # Run the DELETE and INSERT queries. 

307 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

308 self._db.insert(self._calibs, *rowsToInsert) 

309 

310 def make_relation( 

311 self, 

312 *collections: CollectionRecord, 

313 columns: Set[str], 

314 context: SqlQueryContext, 

315 ) -> Relation: 

316 # Docstring inherited from DatasetRecordStorage. 

317 collection_types = {collection.type for collection in collections} 

318 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

319 TimespanReprClass = self._db.getTimespanRepresentation() 

320 # 

321 # There are two kinds of table in play here: 

322 # 

323 # - the static dataset table (with the dataset ID, dataset type ID, 

324 # run ID/name, and ingest date); 

325 # 

326 # - the dynamic tags/calibs table (with the dataset ID, dataset type 

327 # type ID, collection ID/name, data ID, and possibly validity 

328 # range). 

329 # 

330 # That means that we might want to return a query against either table 

331 # or a JOIN of both, depending on which quantities the caller wants. 

332 # But the data ID is always included, which means we'll always include 

333 # the tags/calibs table and join in the static dataset table only if we 

334 # need things from it that we can't get from the tags/calibs table. 

335 # 

336 # Note that it's important that we include a WHERE constraint on both 

337 # tables for any column (e.g. dataset_type_id) that is in both when 

338 # it's given explicitly; not doing so can prevent the query planner from 

339 # using very important indexes. At present, we don't include those 

340 # redundant columns in the JOIN ON expression, however, because the 

341 # FOREIGN KEY (and its index) are defined only on dataset_id. 
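#
# As a rough sketch (not the exact SQL generated), a query that also needs
# the run or ingest_date ends up shaped like
#
#   SELECT ... FROM <type>_tags AS t
#     JOIN dataset AS d ON d.id = t.dataset_id
#   WHERE t.dataset_type_id = :dtype AND d.dataset_type_id = :dtype
#     AND t.<collection_fk> IN (:collection keys)
#
# while a query satisfiable from the tags/calibs table alone skips the JOIN.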

342 tag_relation: Relation | None = None 

343 calib_relation: Relation | None = None 

344 if collection_types != {CollectionType.CALIBRATION}: 

345 # We'll need a subquery for the tags table if any of the given 

346 # collections are not a CALIBRATION collection. This intentionally 

347 # also fires when the list of collections is empty as a way to 

348 # create a dummy subquery that we know will fail. 

349 # We give the table an alias because it might appear multiple times 

350 # in the same query, for different dataset types. 

351 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags")) 

352 if "timespan" in columns: 

353 tags_parts.columns_available[ 

354 DatasetColumnTag(self.datasetType.name, "timespan") 

355 ] = TimespanReprClass.fromLiteral(Timespan(None, None)) 

356 tag_relation = self._finish_single_relation( 

357 tags_parts, 

358 columns, 

359 [ 

360 (record, rank) 

361 for rank, record in enumerate(collections) 

362 if record.type is not CollectionType.CALIBRATION 

363 ], 

364 context, 

365 ) 

366 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries." 

367 if CollectionType.CALIBRATION in collection_types: 

368 # If at least one collection is a CALIBRATION collection, we'll 

369 # need a subquery for the calibs table, and could include the 

370 # timespan as a result or constraint. 

371 assert ( 

372 self._calibs is not None 

373 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

374 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs")) 

375 if "timespan" in columns: 

376 calibs_parts.columns_available[ 

377 DatasetColumnTag(self.datasetType.name, "timespan") 

378 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns) 

379 if "calib_pkey" in columns: 

380 # This is a private extension not included in the base class 

381 # interface, for internal use only in _buildCalibOverlapQuery, 

382 # which needs access to the autoincrement primary key for the 

383 # calib association table. 

384 calibs_parts.columns_available[ 

385 DatasetColumnTag(self.datasetType.name, "calib_pkey") 

386 ] = calibs_parts.from_clause.columns.id 

387 calib_relation = self._finish_single_relation( 

388 calibs_parts, 

389 columns, 

390 [ 

391 (record, rank) 

392 for rank, record in enumerate(collections) 

393 if record.type is CollectionType.CALIBRATION 

394 ], 

395 context, 

396 ) 

397 if tag_relation is not None: 

398 if calib_relation is not None: 

399 # daf_relation's chain operation does not automatically 

400 # deduplicate; it's more like SQL's UNION ALL. To get UNION 

401 # in SQL here, we add an explicit deduplication. 

402 return tag_relation.chain(calib_relation).without_duplicates() 

403 else: 

404 return tag_relation 

405 elif calib_relation is not None: 

406 return calib_relation 

407 else: 

408 raise AssertionError("Branch should be unreachable.") 

409 

410 def _finish_single_relation( 

411 self, 

412 payload: sql.Payload[LogicalColumn], 

413 requested_columns: Set[str], 

414 collections: Sequence[tuple[CollectionRecord, int]], 

415 context: SqlQueryContext, 

416 ) -> Relation: 

417 """Helper method for `make_relation`. 

418 

419 This handles adding columns and WHERE terms that are not specific to 

420 either the tags or calibs tables. 

421 

422 Parameters 

423 ---------- 

424 payload : `lsst.daf.relation.sql.Payload` 

425 SQL query parts under construction, to be modified in-place and 

426 used to construct the new relation. 

427 requested_columns : `~collections.abc.Set` [ `str` ] 

428 Columns the relation should include. 

429 collections : `Sequence` [ `tuple` [ `CollectionRecord`, `int` ] ] 

430 Collections to search for the dataset and their ranks. 

431 context : `SqlQueryContext` 

432 Context that manages engines and state for the query. 

433 

434 Returns 

435 ------- 

436 relation : `lsst.daf.relation.Relation` 

437 New dataset query relation. 

438 """ 

439 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id) 

440 dataset_id_col = payload.from_clause.columns.dataset_id 

441 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()] 

442 # We always constrain and optionally retrieve the collection(s) via the 

443 # tags/calibs table. 

444 if len(collections) == 1: 

445 payload.where.append(collection_col == collections[0][0].key) 

446 if "collection" in requested_columns: 

447 payload.columns_available[ 

448 DatasetColumnTag(self.datasetType.name, "collection") 

449 ] = sqlalchemy.sql.literal(collections[0][0].key) 

450 else: 

451 assert collections, "The no-collections case should be handled in calling code for better diagnostics." 

452 payload.where.append(collection_col.in_([collection.key for collection, _ in collections])) 

453 if "collection" in requested_columns: 

454 payload.columns_available[ 

455 DatasetColumnTag(self.datasetType.name, "collection") 

456 ] = collection_col 

457 # Add rank, if requested, as a CASE-based calculation on the collection 

458 # column. 
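# e.g. (sketch): CASE <collection_fk> WHEN :key_0 THEN 0 WHEN :key_1 THEN 1 END,
# so a lower rank means the collection appeared earlier in the search order.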

459 if "rank" in requested_columns: 

460 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case( 

461 {record.key: rank for record, rank in collections}, 

462 value=collection_col, 

463 ) 

464 # Add more column definitions, starting with the data ID. 

465 for dimension_name in self.datasetType.dimensions.required.names: 

466 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[ 

467 dimension_name 

468 ] 

469 # We can always get the dataset_id from the tags/calibs table. 

470 if "dataset_id" in requested_columns: 

471 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col 

472 # It's possible we now have everything we need, from just the 

473 # tags/calibs table. The things we might need to get from the static 

474 # dataset table are the run key and the ingest date. 

475 need_static_table = False 

476 if "run" in requested_columns: 

477 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN: 

478 # If we are searching exactly one RUN collection, we 

479 # know that if we find the dataset in that collection, 

480 # then that's the dataset's run; we don't need to 

481 # query for it. 

482 payload.columns_available[ 

483 DatasetColumnTag(self.datasetType.name, "run") 

484 ] = sqlalchemy.sql.literal(collections[0][0].key) 

485 else: 

486 payload.columns_available[ 

487 DatasetColumnTag(self.datasetType.name, "run") 

488 ] = self._static.dataset.columns[self._runKeyColumn] 

489 need_static_table = True 

490 # Ingest date can only come from the static table. 

491 if "ingest_date" in requested_columns: 

492 need_static_table = True 

493 payload.columns_available[ 

494 DatasetColumnTag(self.datasetType.name, "ingest_date") 

495 ] = self._static.dataset.columns.ingest_date 

496 # If we need the static table, join it in via dataset_id and 

497 # dataset_type_id 

498 if need_static_table: 

499 payload.from_clause = payload.from_clause.join( 

500 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id) 

501 ) 

502 # Also constrain dataset_type_id in static table in case that helps 

503 # generate a better plan. 

504 # We could also include this in the JOIN ON clause, but my guess is 

505 # that that's a good idea IFF it's in the foreign key, and right 

506 # now it isn't. 

507 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

508 leaf = context.sql_engine.make_leaf( 

509 payload.columns_available.keys(), 

510 payload=payload, 

511 name=self.datasetType.name, 

512 parameters={record.name: rank for record, rank in collections}, 

513 ) 

514 return leaf 

515 

516 def getDataId(self, id: DatasetId) -> DataCoordinate: 

517 """Return DataId for a dataset. 

518 

519 Parameters 

520 ---------- 

521 id : `DatasetId` 

522 Unique dataset identifier. 

523 

524 Returns 

525 ------- 

526 dataId : `DataCoordinate` 

527 DataId for the dataset. 

528 """ 

529 # This query could return multiple rows (one for each tagged collection 

530 # the dataset is in, plus one for its run collection), and we don't 

531 # care which of those we get. 
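# The query built below is roughly (sketch):
#   SELECT * FROM <type>_tags
#   WHERE dataset_id = :id AND dataset_type_id = :dtype LIMIT 1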

532 sql = ( 

533 self._tags.select() 

534 .where( 

535 sqlalchemy.sql.and_( 

536 self._tags.columns.dataset_id == id, 

537 self._tags.columns.dataset_type_id == self._dataset_type_id, 

538 ) 

539 ) 

540 .limit(1) 

541 ) 

542 with self._db.query(sql) as sql_result: 

543 row = sql_result.mappings().fetchone() 

544 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

545 return DataCoordinate.standardize( 

546 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required}, 

547 graph=self.datasetType.dimensions, 

548 ) 

549 

550 

551class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

552 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

553 dataset IDs. 

554 """ 

555 

556 idMaker = DatasetIdFactory() 

557 """Factory for dataset IDs. In the future this factory may be shared with 

558 other classes (e.g. Registry).""" 
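# How the factory is used below (a sketch based on the calls in this module):
#
#     dataset_id = self.idMaker.makeDatasetId(
#         run.name, self.datasetType, dataId, DatasetIdGenEnum.UNIQUE
#     )
#
# UNIQUE mode is assumed to yield a random UUID, while the deterministic
# modes derive the UUID from the dataset type and data ID (and, for some
# modes, the run).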

559 

560 def insert( 

561 self, 

562 run: RunRecord, 

563 dataIds: Iterable[DataCoordinate], 

564 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

565 ) -> Iterator[DatasetRef]: 

566 # Docstring inherited from DatasetRecordStorage. 

567 

568 # Iterate over data IDs, transforming a possibly-single-pass iterable 

569 # into a list. 

570 dataIdList = [] 

571 rows = [] 

572 summary = CollectionSummary() 

573 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

574 dataIdList.append(dataId) 

575 rows.append( 

576 { 

577 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

578 "dataset_type_id": self._dataset_type_id, 

579 self._runKeyColumn: run.key, 

580 } 

581 ) 

582 

583 with self._db.transaction(): 

584 # Insert into the static dataset table. 

585 self._db.insert(self._static.dataset, *rows) 

586 # Update the summary tables for this collection in case this is the 

587 # first time this dataset type or these governor values will be 

588 # inserted there. 

589 self._summaries.update(run, [self._dataset_type_id], summary) 

590 # Combine the generated dataset_id values and data ID fields to 

591 # form rows to be inserted into the tags table. 

592 protoTagsRow = { 

593 "dataset_type_id": self._dataset_type_id, 

594 self._collections.getCollectionForeignKeyName(): run.key, 

595 } 

596 tagsRows = [ 

597 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName()) 

598 for dataId, row in zip(dataIdList, rows) 

599 ] 

600 # Insert those rows into the tags table. 

601 self._db.insert(self._tags, *tagsRows) 

602 

603 for dataId, row in zip(dataIdList, rows): 

604 yield DatasetRef( 

605 datasetType=self.datasetType, 

606 dataId=dataId, 

607 id=row["id"], 

608 run=run.name, 

609 ) 

610 

611 def import_( 

612 self, 

613 run: RunRecord, 

614 datasets: Iterable[DatasetRef], 

615 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

616 reuseIds: bool = False, 

617 ) -> Iterator[DatasetRef]: 

618 # Docstring inherited from DatasetRecordStorage. 

619 

620 # Iterate over the datasets, transforming a possibly-single-pass iterable 

621 # into a mapping from dataset ID to data ID. 

622 dataIds = {} 

623 summary = CollectionSummary() 

624 for dataset in summary.add_datasets_generator(datasets): 

625 # Ignore unknown ID types; normally all IDs have the same type, but 

626 # this code supports mixed types or missing IDs. 

627 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None 

628 if datasetId is None: 

629 datasetId = self.idMaker.makeDatasetId( 

630 run.name, self.datasetType, dataset.dataId, idGenerationMode 

631 ) 

632 dataIds[datasetId] = dataset.dataId 

633 

634 # We'll insert all new rows into a temporary table. 

635 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False) 

636 collFkName = self._collections.getCollectionForeignKeyName() 

637 protoTagsRow = { 

638 "dataset_type_id": self._dataset_type_id, 

639 collFkName: run.key, 

640 } 

641 tmpRows = [ 

642 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

643 for dataset_id, dataId in dataIds.items() 

644 ] 

645 with self._db.transaction(for_temp_tables=True): 

646 with self._db.temporary_table(tableSpec) as tmp_tags: 

647 # Store all incoming data in a temporary table. 

648 self._db.insert(tmp_tags, *tmpRows) 

649 

650 # There are some checks that we want to make for consistency 

651 # of the new datasets with existing ones. 

652 self._validateImport(tmp_tags, run) 

653 

654 # Before we merge the temporary table into dataset/tags we need to 

655 # drop datasets which are already there (and do not conflict). 

656 self._db.deleteWhere( 

657 tmp_tags, 

658 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

659 ) 

660 

661 # Copy it into the dataset table; we need to re-label some columns. 

662 self._db.insert( 

663 self._static.dataset, 

664 select=sqlalchemy.sql.select( 

665 tmp_tags.columns.dataset_id.label("id"), 

666 tmp_tags.columns.dataset_type_id, 

667 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

668 ), 

669 ) 

670 

671 # Update the summary tables for this collection in case this 

672 # is the first time this dataset type or these governor values 

673 # will be inserted there. 

674 self._summaries.update(run, [self._dataset_type_id], summary) 

675 

676 # Copy it into the tags table. 

677 self._db.insert(self._tags, select=tmp_tags.select()) 

678 

679 # Return refs in the same order as in the input list. 

680 for dataset_id, dataId in dataIds.items(): 

681 yield DatasetRef( 

682 datasetType=self.datasetType, 

683 id=dataset_id, 

684 dataId=dataId, 

685 run=run.name, 

686 ) 

687 

688 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

689 """Validate imported refs against existing datasets. 

690 

691 Parameters 

692 ---------- 

693 tmp_tags : `sqlalchemy.schema.Table` 

694 Temporary table with the new datasets and the same schema as the tags 

695 table. 

696 run : `RunRecord` 

697 The record object describing the `~CollectionType.RUN` collection. 

698 

699 Raises 

700 ------ 

701 ConflictingDefinitionError 

702 Raised if the new datasets conflict with existing ones. 

703 """ 

704 dataset = self._static.dataset 

705 tags = self._tags 

706 collFkName = self._collections.getCollectionForeignKeyName() 

707 

708 # Check that existing datasets have the same dataset type and 

709 # run. 

710 query = ( 

711 sqlalchemy.sql.select( 

712 dataset.columns.id.label("dataset_id"), 

713 dataset.columns.dataset_type_id.label("dataset_type_id"), 

714 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), 

715 dataset.columns[self._runKeyColumn].label("run"), 

716 tmp_tags.columns[collFkName].label("new run"), 

717 ) 

718 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

719 .where( 

720 sqlalchemy.sql.or_( 

721 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

722 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

723 ) 

724 ) 

725 .limit(1) 

726 ) 

727 with self._db.query(query) as result: 

728 if (row := result.first()) is not None: 

729 # Only include the first one in the exception message 

730 raise ConflictingDefinitionError( 

731 f"Existing dataset type or run do not match new dataset: {row._asdict()}" 

732 ) 

733 

734 # Check that matching dataset in tags table has the same DataId. 

735 query = ( 

736 sqlalchemy.sql.select( 

737 tags.columns.dataset_id, 

738 tags.columns.dataset_type_id.label("type_id"), 

739 tmp_tags.columns.dataset_type_id.label("new type_id"), 

740 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

741 *[ 

742 tmp_tags.columns[dim].label(f"new {dim}") 

743 for dim in self.datasetType.dimensions.required.names 

744 ], 

745 ) 

746 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

747 .where( 

748 sqlalchemy.sql.or_( 

749 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

750 *[ 

751 tags.columns[dim] != tmp_tags.columns[dim] 

752 for dim in self.datasetType.dimensions.required.names 

753 ], 

754 ) 

755 ) 

756 .limit(1) 

757 ) 

758 

759 with self._db.query(query) as result: 

760 if (row := result.first()) is not None: 

761 # Only include the first one in the exception message 

762 raise ConflictingDefinitionError( 

763 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

764 ) 

765 

766 # Check that matching run+dataId have the same dataset ID. 

767 query = ( 

768 sqlalchemy.sql.select( 

769 tags.columns.dataset_type_id.label("dataset_type_id"), 

770 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

771 tags.columns.dataset_id, 

772 tmp_tags.columns.dataset_id.label("new dataset_id"), 

773 tags.columns[collFkName], 

774 tmp_tags.columns[collFkName].label(f"new {collFkName}"), 

775 ) 

776 .select_from( 

777 tags.join( 

778 tmp_tags, 

779 sqlalchemy.sql.and_( 

780 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

781 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

782 *[ 

783 tags.columns[dim] == tmp_tags.columns[dim] 

784 for dim in self.datasetType.dimensions.required.names 

785 ], 

786 ), 

787 ) 

788 ) 

789 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

790 .limit(1) 

791 ) 

792 with self._db.query(query) as result: 

793 if (row := result.first()) is not None: 

794 # Only include the first one in the exception message 

795 raise ConflictingDefinitionError( 

796 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}" 

797 )