Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%

236 statements  

coverage.py v7.2.7, created at 2023-08-12 09:19 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22 

23from __future__ import annotations 

24 

25__all__ = ("ByDimensionsDatasetRecordStorage",) 

26 

27from collections.abc import Iterable, Iterator, Sequence, Set 

28from datetime import datetime 

29from typing import TYPE_CHECKING 

30 

31import astropy.time 

32import sqlalchemy 

33from lsst.daf.relation import Relation, sql 

34 

35from ....core import ( 

36 DataCoordinate, 

37 DatasetColumnTag, 

38 DatasetId, 

39 DatasetIdFactory, 

40 DatasetIdGenEnum, 

41 DatasetRef, 

42 DatasetType, 

43 DimensionKeyColumnTag, 

44 LogicalColumn, 

45 Timespan, 

46 ddl, 

47) 

48from ..._collection_summary import CollectionSummary 

49from ..._collectionType import CollectionType 

50from ..._exceptions import CollectionTypeError, ConflictingDefinitionError 

51from ...interfaces import DatasetRecordStorage 

52from ...queries import SqlQueryContext 

53from .tables import makeTagTableSpec 

54 

55if TYPE_CHECKING: 

56 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

57 from .summaries import CollectionSummaryManager 

58 from .tables import StaticDatasetTablesTuple 

59 

60 

61class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

62 """Dataset record storage implementation paired with 

63 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more 

64 information. 

65 

66 Instances of this class should never be constructed directly; use 

67 `DatasetRecordStorageManager.register` instead. 

68 """ 

69 

70 def __init__( 

71 self, 

72 *, 

73 datasetType: DatasetType, 

74 db: Database, 

75 dataset_type_id: int, 

76 collections: CollectionManager, 

77 static: StaticDatasetTablesTuple, 

78 summaries: CollectionSummaryManager, 

79 tags: sqlalchemy.schema.Table, 

80 use_astropy_ingest_date: bool, 

81 calibs: sqlalchemy.schema.Table | None, 

82 ): 

83 super().__init__(datasetType=datasetType) 

84 self._dataset_type_id = dataset_type_id 

85 self._db = db 

86 self._collections = collections 

87 self._static = static 

88 self._summaries = summaries 

89 self._tags = tags 

90 self._calibs = calibs 

91 self._runKeyColumn = collections.getRunForeignKeyName() 

92 self._use_astropy = use_astropy_ingest_date 

93 

94 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

95 # Docstring inherited from DatasetRecordStorage. 

96 # Only delete from common dataset table; ON DELETE foreign key clauses 

97 # will handle the rest. 

98 self._db.delete( 

99 self._static.dataset, 

100 ["id"], 

101 *[{"id": dataset.id} for dataset in datasets], 

102 ) 

103 

104 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

105 # Docstring inherited from DatasetRecordStorage. 

106 if collection.type is not CollectionType.TAGGED:  [branch 106 ↛ 107 not taken: the condition on line 106 was never true]

107 raise TypeError( 

108 f"Cannot associate into collection '{collection.name}' " 

109 f"of type {collection.type.name}; must be TAGGED." 

110 ) 

111 protoRow = { 

112 self._collections.getCollectionForeignKeyName(): collection.key, 

113 "dataset_type_id": self._dataset_type_id, 

114 } 

115 rows = [] 

116 summary = CollectionSummary() 

117 for dataset in summary.add_datasets_generator(datasets): 

118 row = dict(protoRow, dataset_id=dataset.id) 

119 for dimension, value in dataset.dataId.items(): 

120 row[dimension.name] = value 

121 rows.append(row) 

122 # Update the summary tables for this collection in case this is the 

123 # first time this dataset type or these governor values will be 

124 # inserted there. 

125 self._summaries.update(collection, [self._dataset_type_id], summary) 

126 # Update the tag table itself. 

127 self._db.replace(self._tags, *rows) 

128 
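As an aside, the row-building pattern in `associate` (and again in `certify` and `insert`) is plain dictionary merging: a prototype row holding the collection key and dataset type ID is copied and extended with each dataset's ID and data ID values. A minimal standalone sketch of that pattern, using a hypothetical `FakeRef` stand-in and invented column names rather than the real `DatasetRef` and schema:

from dataclasses import dataclass


@dataclass
class FakeRef:
    """Hypothetical stand-in for DatasetRef, carrying only what the sketch needs."""

    id: int
    dataId: dict


proto_row = {"collection_id": 7, "dataset_type_id": 3}  # illustrative column names
datasets = [FakeRef(id=1, dataId={"instrument": "HSC", "detector": 42})]
rows = [dict(proto_row, dataset_id=ref.id, **ref.dataId) for ref in datasets]
print(rows)
# [{'collection_id': 7, 'dataset_type_id': 3, 'dataset_id': 1,
#   'instrument': 'HSC', 'detector': 42}]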

129 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

130 # Docstring inherited from DatasetRecordStorage. 

131 if collection.type is not CollectionType.TAGGED:  [branch 131 ↛ 132 not taken: the condition on line 131 was never true]

132 raise TypeError( 

133 f"Cannot disassociate from collection '{collection.name}' " 

134 f"of type {collection.type.name}; must be TAGGED." 

135 ) 

136 rows = [ 

137 { 

138 "dataset_id": dataset.id, 

139 self._collections.getCollectionForeignKeyName(): collection.key, 

140 } 

141 for dataset in datasets 

142 ] 

143 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

144 

145 def _buildCalibOverlapQuery( 

146 self, 

147 collection: CollectionRecord, 

148 data_ids: set[DataCoordinate] | None, 

149 timespan: Timespan, 

150 context: SqlQueryContext, 

151 ) -> Relation: 

152 relation = self.make_relation( 

153 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context 

154 ).with_rows_satisfying( 

155 context.make_timespan_overlap_predicate( 

156 DatasetColumnTag(self.datasetType.name, "timespan"), timespan 

157 ), 

158 ) 

159 if data_ids is not None: 

160 relation = relation.join( 

161 context.make_data_id_relation( 

162 data_ids, self.datasetType.dimensions.required.names 

163 ).transferred_to(context.sql_engine), 

164 ) 

165 return relation 

166 

167 def certify( 

168 self, 

169 collection: CollectionRecord, 

170 datasets: Iterable[DatasetRef], 

171 timespan: Timespan, 

172 context: SqlQueryContext, 

173 ) -> None: 

174 # Docstring inherited from DatasetRecordStorage. 

175 if self._calibs is None:  [branch 175 ↛ 176 not taken: the condition on line 175 was never true]

176 raise CollectionTypeError( 

177 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

178 "DatasetType.isCalibration() is False." 

179 ) 

180 if collection.type is not CollectionType.CALIBRATION:  [branch 180 ↛ 181 not taken: the condition on line 180 was never true]

181 raise CollectionTypeError( 

182 f"Cannot certify into collection '{collection.name}' " 

183 f"of type {collection.type.name}; must be CALIBRATION." 

184 ) 

185 TimespanReprClass = self._db.getTimespanRepresentation() 

186 protoRow = { 

187 self._collections.getCollectionForeignKeyName(): collection.key, 

188 "dataset_type_id": self._dataset_type_id, 

189 } 

190 rows = [] 

191 dataIds: set[DataCoordinate] | None = ( 

192 set() if not TimespanReprClass.hasExclusionConstraint() else None 

193 ) 

194 summary = CollectionSummary() 

195 for dataset in summary.add_datasets_generator(datasets): 

196 row = dict(protoRow, dataset_id=dataset.id) 

197 for dimension, value in dataset.dataId.items(): 

198 row[dimension.name] = value 

199 TimespanReprClass.update(timespan, result=row) 

200 rows.append(row) 

201 if dataIds is not None:  [branch 201 ↛ 195 not taken: the condition on line 201 was never false]

202 dataIds.add(dataset.dataId) 

203 # Update the summary tables for this collection in case this is the 

204 # first time this dataset type or these governor values will be 

205 # inserted there. 

206 self._summaries.update(collection, [self._dataset_type_id], summary) 

207 # Update the association table itself. 

208 if TimespanReprClass.hasExclusionConstraint():  [branch 208 ↛ 211 not taken: the condition on line 208 was never true]

209 # Rely on database constraint to enforce invariants; we just 

210 # reraise the exception for consistency across DB engines. 

211 try: 

212 self._db.insert(self._calibs, *rows) 

213 except sqlalchemy.exc.IntegrityError as err: 

214 raise ConflictingDefinitionError( 

215 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

216 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

217 ) from err 

218 else: 

219 # Have to implement exclusion constraint ourselves. 

220 # Start by building a SELECT query for any rows that would overlap 

221 # this one. 

222 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context) 

223 # Acquire a table lock to ensure there are no concurrent writes 

224 # that could invalidate our checking before we finish the inserts. We 

225 # use a SAVEPOINT in case there is an outer transaction that a 

226 # failure here should not roll back. 

227 with self._db.transaction(lock=[self._calibs], savepoint=True): 

228 # Enter SqlQueryContext in case we need to use a temporary 

229 # table to include the given data IDs in the query. Note that 

230 # by doing this inside the transaction, we make sure it doesn't 

231 # attempt to close the session when it's done, since it just 

232 # sees an already-open session that it knows it shouldn't 

233 # manage. 

234 with context: 

235 # Run the check SELECT query. 

236 conflicting = context.count(context.process(relation)) 

237 if conflicting > 0: 

238 raise ConflictingDefinitionError( 

239 f"{conflicting} validity range conflicts certifying datasets of type " 

240 f"{self.datasetType.name} into {collection.name} for range " 

241 f"[{timespan.begin}, {timespan.end})." 

242 ) 

243 # Proceed with the insert. 

244 self._db.insert(self._calibs, *rows) 

245 
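When the database lacks exclusion constraints, the conflict detection above reduces to counting half-open interval overlaps before inserting. A minimal sketch of that test, independent of `Timespan` and the SQL machinery (plain integers stand in for the timespan bounds):

def overlaps(a_begin: int, a_end: int, b_begin: int, b_end: int) -> bool:
    # Half-open [begin, end) ranges overlap when each starts before the other ends.
    return a_begin < b_end and b_begin < a_end


existing = [(10, 20), (30, 40)]  # already-certified validity ranges
new_begin, new_end = 15, 35      # range we would like to certify
conflicting = sum(overlaps(b, e, new_begin, new_end) for b, e in existing)
print(f"{conflicting} validity range conflict(s)")  # 2 validity range conflict(s)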

246 def decertify( 

247 self, 

248 collection: CollectionRecord, 

249 timespan: Timespan, 

250 *, 

251 dataIds: Iterable[DataCoordinate] | None = None, 

252 context: SqlQueryContext, 

253 ) -> None: 

254 # Docstring inherited from DatasetRecordStorage. 

255 if self._calibs is None:  [branch 255 ↛ 256 not taken: the condition on line 255 was never true]

256 raise CollectionTypeError( 

257 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

258 "DatasetType.isCalibration() is False." 

259 ) 

260 if collection.type is not CollectionType.CALIBRATION:  [branch 260 ↛ 261 not taken: the condition on line 260 was never true]

261 raise CollectionTypeError( 

262 f"Cannot decertify from collection '{collection.name}' " 

263 f"of type {collection.type.name}; must be CALIBRATION." 

264 ) 

265 TimespanReprClass = self._db.getTimespanRepresentation() 

266 # Construct a SELECT query to find all rows that overlap our inputs. 

267 dataIdSet: set[DataCoordinate] | None 

268 if dataIds is not None: 

269 dataIdSet = set(dataIds) 

270 else: 

271 dataIdSet = None 

272 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context) 

273 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey") 

274 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id") 

275 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan") 

276 data_id_tags = [ 

277 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names 

278 ] 

279 # Set up collections to populate with the rows we'll want to modify. 

280 # The insert rows will have the same values for collection and 

281 # dataset type. 

282 protoInsertRow = { 

283 self._collections.getCollectionForeignKeyName(): collection.key, 

284 "dataset_type_id": self._dataset_type_id, 

285 } 

286 rowsToDelete = [] 

287 rowsToInsert = [] 

288 # Acquire a table lock to ensure there are no concurrent writes 

289 # between the SELECT and the DELETE and INSERT queries based on it. 

290 with self._db.transaction(lock=[self._calibs], savepoint=True): 

291 # Enter SqlQueryContext in case we need to use a temporary table to 

292 # include the given data IDs in the query (see similar block in 

293 # certify for details). 

294 with context: 

295 for row in context.fetch_iterable(relation): 

296 rowsToDelete.append({"id": row[calib_pkey_tag]}) 

297 # Construct the insert row(s) by copying the prototype row, 

298 # then adding the dimension column values, then adding 

299 # what's left of the timespan from that row after we 

300 # subtract the given timespan. 

301 newInsertRow = protoInsertRow.copy() 

302 newInsertRow["dataset_id"] = row[dataset_id_tag] 

303 for name, tag in data_id_tags: 

304 newInsertRow[name] = row[tag] 

305 rowTimespan = row[timespan_tag] 

306 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

307 for diffTimespan in rowTimespan.difference(timespan): 

308 rowsToInsert.append( 

309 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()) 

310 ) 

311 # Run the DELETE and INSERT queries. 

312 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

313 self._db.insert(self._calibs, *rowsToInsert) 

314 
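The re-insert rows built in `decertify` come from interval subtraction: whatever is left of an existing validity range after removing the decertified window gets written back. The real code relies on `Timespan.difference`; the sketch below shows the same idea for half-open integer intervals only:

def difference(begin: int, end: int, cut_begin: int, cut_end: int):
    """Yield the pieces of [begin, end) not covered by [cut_begin, cut_end)."""
    if cut_begin > begin:
        yield (begin, min(end, cut_begin))
    if cut_end < end:
        yield (max(begin, cut_end), end)


# Decertifying [20, 30) out of an existing range [10, 40) leaves two pieces,
# which would become two new calib rows with the same dataset and data ID.
print(list(difference(10, 40, 20, 30)))  # [(10, 20), (30, 40)]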

315 def make_relation( 

316 self, 

317 *collections: CollectionRecord, 

318 columns: Set[str], 

319 context: SqlQueryContext, 

320 ) -> Relation: 

321 # Docstring inherited from DatasetRecordStorage. 

322 collection_types = {collection.type for collection in collections} 

323 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

324 TimespanReprClass = self._db.getTimespanRepresentation() 

325 # 

326 # There are two kinds of table in play here: 

327 # 

328 # - the static dataset table (with the dataset ID, dataset type ID, 

329 # run ID/name, and ingest date); 

330 # 

331 # - the dynamic tags/calibs table (with the dataset ID, dataset type 

332 # ID, collection ID/name, data ID, and possibly validity 

333 # range). 

334 # 

335 # That means that we might want to return a query against either table 

336 # or a JOIN of both, depending on which quantities the caller wants. 

337 # But the data ID is always included, which means we'll always include 

338 # the tags/calibs table and join in the static dataset table only if we 

339 # need things from it that we can't get from the tags/calibs table. 

340 # 

341 # Note that it's important that we include a WHERE constraint on both 

342 # tables for any column (e.g. dataset_type_id) that is in both when 

343 # it's given explicitly; not doing so can prevent the query planner from 

344 # using very important indexes. At present, we don't include those 

345 # redundant columns in the JOIN ON expression, however, because the 

346 # FOREIGN KEY (and its index) are defined only on dataset_id. 

347 tag_relation: Relation | None = None 

348 calib_relation: Relation | None = None 

349 if collection_types != {CollectionType.CALIBRATION}: 

350 # We'll need a subquery for the tags table if any of the given 

351 # collections are not a CALIBRATION collection. This intentionally 

352 # also fires when the list of collections is empty as a way to 

353 # create a dummy subquery that we know will fail. 

354 # We give the table an alias because it might appear multiple times 

355 # in the same query, for different dataset types. 

356 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags")) 

357 if "timespan" in columns: 

358 tags_parts.columns_available[ 

359 DatasetColumnTag(self.datasetType.name, "timespan") 

360 ] = TimespanReprClass.fromLiteral(Timespan(None, None)) 

361 tag_relation = self._finish_single_relation( 

362 tags_parts, 

363 columns, 

364 [ 

365 (record, rank) 

366 for rank, record in enumerate(collections) 

367 if record.type is not CollectionType.CALIBRATION 

368 ], 

369 context, 

370 ) 

371 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries." 

372 if CollectionType.CALIBRATION in collection_types: 

373 # If at least one collection is a CALIBRATION collection, we'll 

374 # need a subquery for the calibs table, and could include the 

375 # timespan as a result or constraint. 

376 assert ( 

377 self._calibs is not None 

378 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

379 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs")) 

380 if "timespan" in columns: 

381 calibs_parts.columns_available[ 

382 DatasetColumnTag(self.datasetType.name, "timespan") 

383 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns) 

384 if "calib_pkey" in columns: 

385 # This is a private extension not included in the base class 

386 # interface, for internal use only in _buildCalibOverlapQuery, 

387 # which needs access to the autoincrement primary key for the 

388 # calib association table. 

389 calibs_parts.columns_available[ 

390 DatasetColumnTag(self.datasetType.name, "calib_pkey") 

391 ] = calibs_parts.from_clause.columns.id 

392 calib_relation = self._finish_single_relation( 

393 calibs_parts, 

394 columns, 

395 [ 

396 (record, rank) 

397 for rank, record in enumerate(collections) 

398 if record.type is CollectionType.CALIBRATION 

399 ], 

400 context, 

401 ) 

402 if tag_relation is not None: 

403 if calib_relation is not None: 

404 # daf_relation's chain operation does not automatically 

405 # deduplicate; it's more like SQL's UNION ALL. To get UNION 

406 # in SQL here, we add an explicit deduplication. 

407 return tag_relation.chain(calib_relation).without_duplicates() 

408 else: 

409 return tag_relation 

410 elif calib_relation is not None: 

411 return calib_relation 

412 else: 

413 raise AssertionError("Branch should be unreachable.") 

414 
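To make the tags-versus-calibs branching above concrete, here is a toy SQLAlchemy sketch of the final shape: one SELECT per table, each constrained by its own collection keys, combined with a deduplicating UNION (the SQL rendering of `chain(...).without_duplicates()`). The table and column names are invented for the example and are not the real butler schema:

import sqlalchemy

metadata = sqlalchemy.MetaData()
toy_tags = sqlalchemy.Table(
    "toy_tags",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
    sqlalchemy.Column("detector", sqlalchemy.Integer),
)
toy_calibs = sqlalchemy.Table(
    "toy_calibs",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
    sqlalchemy.Column("detector", sqlalchemy.Integer),
)

tag_sel = sqlalchemy.select(toy_tags.c.dataset_id, toy_tags.c.detector).where(
    toy_tags.c.collection_id == 1  # TAGGED/RUN collections
)
calib_sel = sqlalchemy.select(toy_calibs.c.dataset_id, toy_calibs.c.detector).where(
    toy_calibs.c.collection_id == 2  # CALIBRATION collections
)
# UNION (not UNION ALL) deduplicates, matching without_duplicates() above.
print(sqlalchemy.union(tag_sel, calib_sel))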

415 def _finish_single_relation( 

416 self, 

417 payload: sql.Payload[LogicalColumn], 

418 requested_columns: Set[str], 

419 collections: Sequence[tuple[CollectionRecord, int]], 

420 context: SqlQueryContext, 

421 ) -> Relation: 

422 """Handle adding columns and WHERE terms that are not specific to 

423 either the tags or calibs tables. 

424 

425 Helper method for `make_relation`. 

426 

427 Parameters 

428 ---------- 

429 payload : `lsst.daf.relation.sql.Payload` 

430 SQL query parts under construction, to be modified in-place and 

431 used to construct the new relation. 

432 requested_columns : `~collections.abc.Set` [ `str` ] 

433 Columns the relation should include. 

434 collections : `~collections.abc.Sequence` [ `tuple` \ 

435 [ `CollectionRecord`, `int` ] ] 

436 Collections to search for the dataset and their ranks. 

437 context : `SqlQueryContext` 

438 Context that manages engines and state for the query. 

439 

440 Returns 

441 ------- 

442 relation : `lsst.daf.relation.Relation` 

443 New dataset query relation. 

444 """ 

445 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id) 

446 dataset_id_col = payload.from_clause.columns.dataset_id 

447 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()] 

448 # We always constrain and optionally retrieve the collection(s) via the 

449 # tags/calibs table. 

450 if len(collections) == 1: 

451 payload.where.append(collection_col == collections[0][0].key) 

452 if "collection" in requested_columns: 

453 payload.columns_available[ 

454 DatasetColumnTag(self.datasetType.name, "collection") 

455 ] = sqlalchemy.sql.literal(collections[0][0].key) 

456 else: 

457 assert collections, "The no-collections case should be in calling code for better diagnostics." 

458 payload.where.append(collection_col.in_([collection.key for collection, _ in collections])) 

459 if "collection" in requested_columns: 

460 payload.columns_available[ 

461 DatasetColumnTag(self.datasetType.name, "collection") 

462 ] = collection_col 

463 # Add rank if requested, as a CASE-based calculation on the collection 

464 # column. 

465 if "rank" in requested_columns: 

466 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case( 

467 {record.key: rank for record, rank in collections}, 

468 value=collection_col, 

469 ) 

470 # Add more column definitions, starting with the data ID. 

471 for dimension_name in self.datasetType.dimensions.required.names: 

472 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[ 

473 dimension_name 

474 ] 

475 # We can always get the dataset_id from the tags/calibs table. 

476 if "dataset_id" in requested_columns: 

477 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col 

478 # It's possible we now have everything we need, from just the 

479 # tags/calibs table. The things we might need to get from the static 

480 # dataset table are the run key and the ingest date. 

481 need_static_table = False 

482 if "run" in requested_columns: 

483 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN: 

484 # If we are searching exactly one RUN collection, we 

485 # know that if we find the dataset in that collection, 

486 # then that's the dataset's run; we don't need to 

487 # query for it. 

488 payload.columns_available[ 

489 DatasetColumnTag(self.datasetType.name, "run") 

490 ] = sqlalchemy.sql.literal(collections[0][0].key) 

491 else: 

492 payload.columns_available[ 

493 DatasetColumnTag(self.datasetType.name, "run") 

494 ] = self._static.dataset.columns[self._runKeyColumn] 

495 need_static_table = True 

496 # Ingest date can only come from the static table. 

497 if "ingest_date" in requested_columns: 

498 need_static_table = True 

499 payload.columns_available[ 

500 DatasetColumnTag(self.datasetType.name, "ingest_date") 

501 ] = self._static.dataset.columns.ingest_date 

502 # If we need the static table, join it in via dataset_id and 

503 # dataset_type_id. 

504 if need_static_table: 

505 payload.from_clause = payload.from_clause.join( 

506 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id) 

507 ) 

508 # Also constrain dataset_type_id in static table in case that helps 

509 # generate a better plan. 

510 # We could also include this in the JOIN ON clause, but my guess is 

511 # that that's a good idea IFF it's in the foreign key, and right 

512 # now it isn't. 

513 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

514 leaf = context.sql_engine.make_leaf( 

515 payload.columns_available.keys(), 

516 payload=payload, 

517 name=self.datasetType.name, 

518 parameters={record.name: rank for record, rank in collections}, 

519 ) 

520 return leaf 

521 
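The rank column added above is an ordinary SQL CASE expression keyed on the collection column. A small SQLAlchemy sketch with made-up table, column, and collection keys, showing the `case(..., value=...)` form used here:

import sqlalchemy

metadata = sqlalchemy.MetaData()
toy_tags = sqlalchemy.Table(
    "toy_tags",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)
# Collections are searched in order; the CASE maps each collection key to its
# position in that order so callers can pick the best match per data ID.
search_order = {11: 0, 7: 1, 3: 2}  # {collection key: rank}, made-up keys
rank = sqlalchemy.sql.case(search_order, value=toy_tags.c.collection_id)
print(sqlalchemy.select(toy_tags.c.dataset_id, rank.label("rank")))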

522 def getDataId(self, id: DatasetId) -> DataCoordinate: 

523 """Return DataId for a dataset. 

524 

525 Parameters 

526 ---------- 

527 id : `DatasetId` 

528 Unique dataset identifier. 

529 

530 Returns 

531 ------- 

532 dataId : `DataCoordinate` 

533 DataId for the dataset. 

534 """ 

535 # This query could return multiple rows (one for each tagged collection 

536 # the dataset is in, plus one for its run collection), and we don't 

537 # care which of those we get. 

538 sql = ( 

539 self._tags.select() 

540 .where( 

541 sqlalchemy.sql.and_( 

542 self._tags.columns.dataset_id == id, 

543 self._tags.columns.dataset_type_id == self._dataset_type_id, 

544 ) 

545 ) 

546 .limit(1) 

547 ) 

548 with self._db.query(sql) as sql_result: 

549 row = sql_result.mappings().fetchone() 

550 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

551 return DataCoordinate.standardize( 

552 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required}, 

553 graph=self.datasetType.dimensions, 

554 ) 

555 

556 

557class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

558 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

559 dataset IDs. 

560 """ 

561 

562 idMaker = DatasetIdFactory() 

563 """Factory for dataset IDs. In the future this factory may be shared with 

564 other classes (e.g. Registry).""" 

565 

566 def insert( 

567 self, 

568 run: RunRecord, 

569 dataIds: Iterable[DataCoordinate], 

570 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

571 ) -> Iterator[DatasetRef]: 

572 # Docstring inherited from DatasetRecordStorage. 

573 

574 # Current timestamp, type depends on schema version. Use microsecond 

575 # precision for astropy time to keep things consistent with 

576 # TIMESTAMP(6) SQL type. 

577 timestamp: datetime | astropy.time.Time 

578 if self._use_astropy: 

579 # Astropy `now()` precision should be the same as `utcnow()` which 

580 # should mean microsecond. 

581 timestamp = astropy.time.Time.now() 

582 else: 

583 timestamp = datetime.utcnow() 

584 

585 # Iterate over data IDs, transforming a possibly-single-pass iterable 

586 # into a list. 

587 dataIdList = [] 

588 rows = [] 

589 summary = CollectionSummary() 

590 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

591 dataIdList.append(dataId) 

592 rows.append( 

593 { 

594 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

595 "dataset_type_id": self._dataset_type_id, 

596 self._runKeyColumn: run.key, 

597 "ingest_date": timestamp, 

598 } 

599 ) 

600 

601 with self._db.transaction(): 

602 # Insert into the static dataset table. 

603 self._db.insert(self._static.dataset, *rows) 

604 # Update the summary tables for this collection in case this is the 

605 # first time this dataset type or these governor values will be 

606 # inserted there. 

607 self._summaries.update(run, [self._dataset_type_id], summary) 

608 # Combine the generated dataset_id values and data ID fields to 

609 # form rows to be inserted into the tags table. 

610 protoTagsRow = { 

611 "dataset_type_id": self._dataset_type_id, 

612 self._collections.getCollectionForeignKeyName(): run.key, 

613 } 

614 tagsRows = [ 

615 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName()) 

616 for dataId, row in zip(dataIdList, rows, strict=True) 

617 ] 

618 # Insert those rows into the tags table. 

619 self._db.insert(self._tags, *tagsRows) 

620 

621 for dataId, row in zip(dataIdList, rows, strict=True): 

622 yield DatasetRef( 

623 datasetType=self.datasetType, 

624 dataId=dataId, 

625 id=row["id"], 

626 run=run.name, 

627 ) 

628 
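`insert` delegates ID generation to `DatasetIdFactory`. As a loose illustration only, and not that factory's actual algorithm, the difference between the UNIQUE mode and a deterministic mode can be pictured with the standard `uuid` module: a random UUID4 versus a reproducible UUID5 derived from the run name, dataset type, and data ID, so that repeating the same import yields the same ID. The namespace and name encoding below are invented for the sketch:

import uuid

_EXAMPLE_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_DNS, "example.dataset-id")  # made up


def unique_id() -> uuid.UUID:
    # UNIQUE mode: every call yields a fresh, random identifier.
    return uuid.uuid4()


def deterministic_id(run: str, dataset_type: str, data_id: dict) -> uuid.UUID:
    # Deterministic mode: the same inputs always hash to the same UUID.
    name = ",".join([run, dataset_type] + [f"{k}={v}" for k, v in sorted(data_id.items())])
    return uuid.uuid5(_EXAMPLE_NAMESPACE, name)


print(deterministic_id("HSC/runs/demo", "calexp", {"visit": 903334, "detector": 42}))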

629 def import_( 

630 self, 

631 run: RunRecord, 

632 datasets: Iterable[DatasetRef], 

633 ) -> Iterator[DatasetRef]: 

634 # Docstring inherited from DatasetRecordStorage. 

635 

636 # Current timestamp, type depends on schema version. 

637 if self._use_astropy: 

638 # Astropy `now()` precision should be the same as `utcnow()` which 

639 # should mean microsecond. 

640 timestamp = sqlalchemy.sql.literal(astropy.time.Time.now(), type_=ddl.AstropyTimeNsecTai) 

641 else: 

642 timestamp = sqlalchemy.sql.literal(datetime.utcnow()) 

643 

644 # Iterate over the incoming datasets, transforming a possibly-single-pass 

645 # iterable into a dict keyed by dataset ID. 

646 dataIds = {} 

647 summary = CollectionSummary() 

648 for dataset in summary.add_datasets_generator(datasets): 

649 dataIds[dataset.id] = dataset.dataId 

650 

651 # We'll insert all new rows into a temporary table 

652 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False) 

653 collFkName = self._collections.getCollectionForeignKeyName() 

654 protoTagsRow = { 

655 "dataset_type_id": self._dataset_type_id, 

656 collFkName: run.key, 

657 } 

658 tmpRows = [ 

659 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

660 for dataset_id, dataId in dataIds.items() 

661 ] 

662 with self._db.transaction(for_temp_tables=True), self._db.temporary_table(tableSpec) as tmp_tags: 

663 # store all incoming data in a temporary table 

664 self._db.insert(tmp_tags, *tmpRows) 

665 

666 # There are some checks that we want to make for consistency 

667 # of the new datasets with existing ones. 

668 self._validateImport(tmp_tags, run) 

669 

670 # Before we merge temporary table into dataset/tags we need to 

671 # drop datasets which are already there (and do not conflict). 

672 self._db.deleteWhere( 

673 tmp_tags, 

674 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

675 ) 

676 

677 # Copy it into the dataset table; we need to re-label some columns. 

678 self._db.insert( 

679 self._static.dataset, 

680 select=sqlalchemy.sql.select( 

681 tmp_tags.columns.dataset_id.label("id"), 

682 tmp_tags.columns.dataset_type_id, 

683 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

684 timestamp.label("ingest_date"), 

685 ), 

686 ) 

687 

688 # Update the summary tables for this collection in case this 

689 # is the first time this dataset type or these governor values 

690 # will be inserted there. 

691 self._summaries.update(run, [self._dataset_type_id], summary) 

692 

693 # Copy it into tags table. 

694 self._db.insert(self._tags, select=tmp_tags.select()) 

695 

696 # Return refs in the same order as in the input list. 

697 for dataset_id, dataId in dataIds.items(): 

698 yield DatasetRef( 

699 datasetType=self.datasetType, 

700 id=dataset_id, 

701 dataId=dataId, 

702 run=run.name, 

703 ) 

704 
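The temporary-table merge in `import_` follows a common pattern: stage everything, validate, delete staged rows whose IDs already exist in the permanent table, then copy the remainder with INSERT ... SELECT. A self-contained sketch of that pattern on an in-memory SQLite database (an ordinary table stands in for the real temporary table, and the two-column schema is a toy one):

import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
dataset = sqlalchemy.Table(
    "dataset", metadata, sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True)
)
tmp_tags = sqlalchemy.Table(
    "tmp_tags", metadata, sqlalchemy.Column("dataset_id", sqlalchemy.Integer)
)

with engine.begin() as conn:
    metadata.create_all(conn)
    conn.execute(sqlalchemy.insert(dataset), [{"id": 1}])  # dataset 1 already imported
    conn.execute(sqlalchemy.insert(tmp_tags), [{"dataset_id": 1}, {"dataset_id": 2}])
    # Drop staged rows whose IDs are already present.
    conn.execute(
        sqlalchemy.delete(tmp_tags).where(
            tmp_tags.c.dataset_id.in_(sqlalchemy.select(dataset.c.id))
        )
    )
    # Copy what is left into the permanent table, relabelling the column.
    conn.execute(
        sqlalchemy.insert(dataset).from_select(
            ["id"], sqlalchemy.select(tmp_tags.c.dataset_id.label("id"))
        )
    )
    print(conn.execute(sqlalchemy.select(dataset.c.id).order_by(dataset.c.id)).all())
    # [(1,), (2,)]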

705 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

706 """Validate imported refs against existing datasets. 

707 

708 Parameters 

709 ---------- 

710 tmp_tags : `sqlalchemy.schema.Table` 

711 Temporary table with new datasets and the same schema as tags 

712 table. 

713 run : `RunRecord` 

714 The record object describing the `~CollectionType.RUN` collection. 

715 

716 Raises 

717 ------ 

718 ConflictingDefinitionError 

719 Raised if new datasets conflict with existing ones. 

720 """ 

721 dataset = self._static.dataset 

722 tags = self._tags 

723 collFkName = self._collections.getCollectionForeignKeyName() 

724 

725 # Check that existing datasets have the same dataset type and 

726 # run. 

727 query = ( 

728 sqlalchemy.sql.select( 

729 dataset.columns.id.label("dataset_id"), 

730 dataset.columns.dataset_type_id.label("dataset_type_id"), 

731 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), 

732 dataset.columns[self._runKeyColumn].label("run"), 

733 tmp_tags.columns[collFkName].label("new run"), 

734 ) 

735 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

736 .where( 

737 sqlalchemy.sql.or_( 

738 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

739 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

740 ) 

741 ) 

742 .limit(1) 

743 ) 

744 with self._db.query(query) as result: 

745 if (row := result.first()) is not None: 

746 # Only include the first one in the exception message 

747 raise ConflictingDefinitionError( 

748 f"Existing dataset type or run do not match new dataset: {row._asdict()}" 

749 ) 

750 

751 # Check that matching dataset in tags table has the same DataId. 

752 query = ( 

753 sqlalchemy.sql.select( 

754 tags.columns.dataset_id, 

755 tags.columns.dataset_type_id.label("type_id"), 

756 tmp_tags.columns.dataset_type_id.label("new type_id"), 

757 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

758 *[ 

759 tmp_tags.columns[dim].label(f"new {dim}") 

760 for dim in self.datasetType.dimensions.required.names 

761 ], 

762 ) 

763 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

764 .where( 

765 sqlalchemy.sql.or_( 

766 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

767 *[ 

768 tags.columns[dim] != tmp_tags.columns[dim] 

769 for dim in self.datasetType.dimensions.required.names 

770 ], 

771 ) 

772 ) 

773 .limit(1) 

774 ) 

775 

776 with self._db.query(query) as result: 

777 if (row := result.first()) is not None: 

778 # Only include the first one in the exception message 

779 raise ConflictingDefinitionError( 

780 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

781 ) 

782 

783 # Check that matching run+dataId have the same dataset ID. 

784 query = ( 

785 sqlalchemy.sql.select( 

786 tags.columns.dataset_type_id.label("dataset_type_id"), 

787 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

788 tags.columns.dataset_id, 

789 tmp_tags.columns.dataset_id.label("new dataset_id"), 

790 tags.columns[collFkName], 

791 tmp_tags.columns[collFkName].label(f"new {collFkName}"), 

792 ) 

793 .select_from( 

794 tags.join( 

795 tmp_tags, 

796 sqlalchemy.sql.and_( 

797 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

798 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

799 *[ 

800 tags.columns[dim] == tmp_tags.columns[dim] 

801 for dim in self.datasetType.dimensions.required.names 

802 ], 

803 ), 

804 ) 

805 ) 

806 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

807 .limit(1) 

808 ) 

809 with self._db.query(query) as result: 

810 if (row := result.first()) is not None: 

811 # only include the first one in the exception message 

812 raise ConflictingDefinitionError( 

813 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}" 

814 )
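
Each of the three checks above follows the same recipe: join the staged rows to the existing rows on a key, keep only rows where some supposedly-identical column differs, and report the first offender. A toy SQLAlchemy sketch of that recipe, with invented table and column names:

import sqlalchemy

metadata = sqlalchemy.MetaData()
existing = sqlalchemy.Table(
    "existing",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("run", sqlalchemy.String),
)
staged = sqlalchemy.Table(
    "staged",
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("run", sqlalchemy.String),
)

conflicts = (
    sqlalchemy.select(
        existing.c.dataset_id,
        existing.c.run,
        staged.c.run.label("new run"),
    )
    .select_from(existing.join(staged, existing.c.dataset_id == staged.c.dataset_id))
    .where(existing.c.run != staged.c.run)
    .limit(1)  # one offending row is enough for the error message
)
print(conflicts)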