Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%

260 statements  

coverage.py v7.4.3, created at 2024-03-07 11:02 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28 

29from __future__ import annotations 

30 

31from .... import ddl 

32 

33__all__ = ("ByDimensionsDatasetRecordStorage",) 

34 

35import datetime 

36from collections.abc import Callable, Iterable, Iterator, Sequence, Set 

37from typing import TYPE_CHECKING 

38 

39import astropy.time 

40import sqlalchemy 

41from lsst.daf.relation import Relation, sql 

42 

43from ...._column_tags import DatasetColumnTag, DimensionKeyColumnTag 

44from ...._column_type_info import LogicalColumn 

45from ...._dataset_ref import DatasetId, DatasetIdFactory, DatasetIdGenEnum, DatasetRef 

46from ...._dataset_type import DatasetType 

47from ...._timespan import Timespan 

48from ....dimensions import DataCoordinate 

49from ..._collection_summary import CollectionSummary 

50from ..._collection_type import CollectionType 

51from ..._exceptions import CollectionTypeError, ConflictingDefinitionError 

52from ...interfaces import DatasetRecordStorage 

53from ...queries import SqlQueryContext 

54from .tables import makeTagTableSpec 

55 

56if TYPE_CHECKING: 

57 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

58 from .summaries import CollectionSummaryManager 

59 from .tables import StaticDatasetTablesTuple 

60 

61 

62class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

63 """Dataset record storage implementation paired with 

64 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more 

65 information. 

66 

67 Instances of this class should never be constructed directly; use 

68 `DatasetRecordStorageManager.register` instead. 

69 

70 Parameters 

71 ---------- 

72 datasetType : `DatasetType` 

73 The dataset type to use. 

74 db : `Database` 

75 Database connection. 

76 dataset_type_id : `int` 

77 Dataset type identifier. 

78 collections : `CollectionManager` 

79 The collection manager. 

80 static : `StaticDatasetTablesTuple` 

81 Named tuple containing the static dataset tables. 

82 summaries : `CollectionSummaryManager` 

83 Collection summary manager. 

84 tags_table_factory : `~collections.abc.Callable` 

85 Factory for creating tags tables. 

86 use_astropy_ingest_date : `bool` 

87 Whether to use Astropy for ingest date. 

88 calibs_table_factory : `~collections.abc.Callable` 

89 Factory for creating calibration tables. 

90 """ 

91 

92 def __init__( 

93 self, 

94 *, 

95 datasetType: DatasetType, 

96 db: Database, 

97 dataset_type_id: int, 

98 collections: CollectionManager, 

99 static: StaticDatasetTablesTuple, 

100 summaries: CollectionSummaryManager, 

101 tags_table_factory: Callable[[], sqlalchemy.schema.Table], 

102 use_astropy_ingest_date: bool, 

103 calibs_table_factory: Callable[[], sqlalchemy.schema.Table] | None, 

104 ): 

105 super().__init__(datasetType=datasetType) 

106 self._dataset_type_id = dataset_type_id 

107 self._db = db 

108 self._collections = collections 

109 self._static = static 

110 self._summaries = summaries 

111 self._tags_table_factory = tags_table_factory 

112 self._calibs_table_factory = calibs_table_factory 

113 self._runKeyColumn = collections.getRunForeignKeyName() 

114 self._use_astropy = use_astropy_ingest_date 

115 self._tags_table: sqlalchemy.schema.Table | None = None 

116 self._calibs_table: sqlalchemy.schema.Table | None = None 

117 

118 @property 

119 def _tags(self) -> sqlalchemy.schema.Table: 

120 if self._tags_table is None: 

121 self._tags_table = self._tags_table_factory() 

122 return self._tags_table 

123 

124 @property 

125 def _calibs(self) -> sqlalchemy.schema.Table | None: 

126 if self._calibs_table is None: 

127 if self._calibs_table_factory is None:  127 ↛ 128: line 127 didn't jump to line 128, because the condition on line 127 was never true

128 return None 

129 self._calibs_table = self._calibs_table_factory() 

130 return self._calibs_table 

131 

132 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

133 # Docstring inherited from DatasetRecordStorage. 

134 # Only delete from common dataset table; ON DELETE foreign key clauses 

135 # will handle the rest. 

136 self._db.delete( 

137 self._static.dataset, 

138 ["id"], 

139 *[{"id": dataset.id} for dataset in datasets], 

140 ) 

141 

142 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

143 # Docstring inherited from DatasetRecordStorage. 

144 if collection.type is not CollectionType.TAGGED:  144 ↛ 145: line 144 didn't jump to line 145, because the condition on line 144 was never true

145 raise TypeError( 

146 f"Cannot associate into collection '{collection.name}' " 

147 f"of type {collection.type.name}; must be TAGGED." 

148 ) 

149 protoRow = { 

150 self._collections.getCollectionForeignKeyName(): collection.key, 

151 "dataset_type_id": self._dataset_type_id, 

152 } 

153 rows = [] 

154 summary = CollectionSummary() 

155 for dataset in summary.add_datasets_generator(datasets): 

156 rows.append(dict(protoRow, dataset_id=dataset.id, **dataset.dataId.required)) 

157 # Update the summary tables for this collection in case this is the 

158 # first time this dataset type or these governor values will be 

159 # inserted there. 

160 self._summaries.update(collection, [self._dataset_type_id], summary) 

161 # Update the tag table itself. 

162 self._db.replace(self._tags, *rows) 

163 

164 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

165 # Docstring inherited from DatasetRecordStorage. 

166 if collection.type is not CollectionType.TAGGED:  166 ↛ 167: line 166 didn't jump to line 167, because the condition on line 166 was never true

167 raise TypeError( 

168 f"Cannot disassociate from collection '{collection.name}' " 

169 f"of type {collection.type.name}; must be TAGGED." 

170 ) 

171 rows = [ 

172 { 

173 "dataset_id": dataset.id, 

174 self._collections.getCollectionForeignKeyName(): collection.key, 

175 } 

176 for dataset in datasets 

177 ] 

178 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

179 

180 def _buildCalibOverlapQuery( 

181 self, 

182 collection: CollectionRecord, 

183 data_ids: set[DataCoordinate] | None, 

184 timespan: Timespan, 

185 context: SqlQueryContext, 

186 ) -> Relation: 

187 relation = self.make_relation( 

188 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context 

189 ).with_rows_satisfying( 

190 context.make_timespan_overlap_predicate( 

191 DatasetColumnTag(self.datasetType.name, "timespan"), timespan 

192 ), 

193 ) 

194 if data_ids is not None: 

195 relation = relation.join( 

196 context.make_data_id_relation( 

197 data_ids, self.datasetType.dimensions.required.names 

198 ).transferred_to(context.sql_engine), 

199 ) 

200 return relation 

201 

202 def certify( 

203 self, 

204 collection: CollectionRecord, 

205 datasets: Iterable[DatasetRef], 

206 timespan: Timespan, 

207 context: SqlQueryContext, 

208 ) -> None: 

209 # Docstring inherited from DatasetRecordStorage. 

210 if self._calibs is None:  210 ↛ 211: line 210 didn't jump to line 211, because the condition on line 210 was never true

211 raise CollectionTypeError( 

212 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

213 "DatasetType.isCalibration() is False." 

214 ) 

215 if collection.type is not CollectionType.CALIBRATION:  215 ↛ 216: line 215 didn't jump to line 216, because the condition on line 215 was never true

216 raise CollectionTypeError( 

217 f"Cannot certify into collection '{collection.name}' " 

218 f"of type {collection.type.name}; must be CALIBRATION." 

219 ) 

220 TimespanReprClass = self._db.getTimespanRepresentation() 

221 protoRow = { 

222 self._collections.getCollectionForeignKeyName(): collection.key, 

223 "dataset_type_id": self._dataset_type_id, 

224 } 

225 rows = [] 

226 dataIds: set[DataCoordinate] | None = ( 

227 set() if not TimespanReprClass.hasExclusionConstraint() else None 

228 ) 

229 summary = CollectionSummary() 

230 for dataset in summary.add_datasets_generator(datasets): 

231 row = dict(protoRow, dataset_id=dataset.id, **dataset.dataId.required) 

232 TimespanReprClass.update(timespan, result=row) 

233 rows.append(row) 

234 if dataIds is not None:  234 ↛ 230: line 234 didn't jump to line 230, because the condition on line 234 was never false

235 dataIds.add(dataset.dataId) 

236 # Update the summary tables for this collection in case this is the 

237 # first time this dataset type or these governor values will be 

238 # inserted there. 

239 self._summaries.update(collection, [self._dataset_type_id], summary) 

240 # Update the association table itself. 

241 if TimespanReprClass.hasExclusionConstraint():  241 ↛ 244: line 241 didn't jump to line 244, because the condition on line 241 was never true

242 # Rely on database constraint to enforce invariants; we just 

243 # reraise the exception for consistency across DB engines. 

244 try: 

245 self._db.insert(self._calibs, *rows) 

246 except sqlalchemy.exc.IntegrityError as err: 

247 raise ConflictingDefinitionError( 

248 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

249 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

250 ) from err 

251 else: 

252 # Have to implement exclusion constraint ourselves. 

253 # Start by building a SELECT query for any rows that would overlap 

254 # this one. 

255 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context) 

256 # Acquire a table lock to ensure there are no concurrent writes 

257 # that could invalidate our checking before we finish the inserts. We 

258 # use a SAVEPOINT in case there is an outer transaction that a 

259 # failure here should not roll back. 

260 with self._db.transaction(lock=[self._calibs], savepoint=True): 

261 # Enter SqlQueryContext in case we need to use a temporary 

262 # table to include the given data IDs in the query. Note that 

263 # by doing this inside the transaction, we make sure it doesn't 

264 # attempt to close the session when it's done, since it just 

265 # sees an already-open session that it knows it shouldn't 

266 # manage. 

267 with context: 

268 # Run the check SELECT query. 

269 conflicting = context.count(context.process(relation)) 

270 if conflicting > 0: 

271 raise ConflictingDefinitionError( 

272 f"{conflicting} validity range conflicts certifying datasets of type " 

273 f"{self.datasetType.name} into {collection.name} for range " 

274 f"[{timespan.begin}, {timespan.end})." 

275 ) 

276 # Proceed with the insert. 

277 self._db.insert(self._calibs, *rows) 

278 
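# Editorial note (not in the original source): when the database's timespan
# representation has no native exclusion constraint, ``certify`` above emulates
# one. Roughly, with table names simplified for illustration only:
#
#     BEGIN (SAVEPOINT) and LOCK <calibs table>
#     SELECT COUNT(*) FROM the overlap query built by _buildCalibOverlapQuery
#     -- if the count is non-zero: raise ConflictingDefinitionError
#     -- otherwise: INSERT the new calibration rows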

279 def decertify( 

280 self, 

281 collection: CollectionRecord, 

282 timespan: Timespan, 

283 *, 

284 dataIds: Iterable[DataCoordinate] | None = None, 

285 context: SqlQueryContext, 

286 ) -> None: 

287 # Docstring inherited from DatasetRecordStorage. 

288 if self._calibs is None:  288 ↛ 289: line 288 didn't jump to line 289, because the condition on line 288 was never true

289 raise CollectionTypeError( 

290 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

291 "DatasetType.isCalibration() is False." 

292 ) 

293 if collection.type is not CollectionType.CALIBRATION:  293 ↛ 294: line 293 didn't jump to line 294, because the condition on line 293 was never true

294 raise CollectionTypeError( 

295 f"Cannot decertify from collection '{collection.name}' " 

296 f"of type {collection.type.name}; must be CALIBRATION." 

297 ) 

298 TimespanReprClass = self._db.getTimespanRepresentation() 

299 # Construct a SELECT query to find all rows that overlap our inputs. 

300 dataIdSet: set[DataCoordinate] | None 

301 if dataIds is not None: 

302 dataIdSet = set(dataIds) 

303 else: 

304 dataIdSet = None 

305 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context) 

306 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey") 

307 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id") 

308 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan") 

309 data_id_tags = [ 

310 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names 

311 ] 

312 # Set up collections to populate with the rows we'll want to modify. 

313 # The insert rows will have the same values for collection and 

314 # dataset type. 

315 protoInsertRow = { 

316 self._collections.getCollectionForeignKeyName(): collection.key, 

317 "dataset_type_id": self._dataset_type_id, 

318 } 

319 rowsToDelete = [] 

320 rowsToInsert = [] 

321 # Acquire a table lock to ensure there are no concurrent writes 

322 # between the SELECT and the DELETE and INSERT queries based on it. 

323 with self._db.transaction(lock=[self._calibs], savepoint=True): 

324 # Enter SqlQueryContext in case we need to use a temporary table to 

325 # include the given data IDs in the query (see similar block in 

326 # certify for details). 

327 with context: 

328 for row in context.fetch_iterable(relation): 

329 rowsToDelete.append({"id": row[calib_pkey_tag]}) 

330 # Construct the insert row(s) by copying the prototype row, 

331 # then adding the dimension column values, then adding 

332 # what's left of the timespan from that row after we 

333 # subtract the given timespan. 

334 newInsertRow = protoInsertRow.copy() 

335 newInsertRow["dataset_id"] = row[dataset_id_tag] 

336 for name, tag in data_id_tags: 

337 newInsertRow[name] = row[tag] 

338 rowTimespan = row[timespan_tag] 

339 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

340 for diffTimespan in rowTimespan.difference(timespan): 

341 rowsToInsert.append( 

342 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()) 

343 ) 

344 # Run the DELETE and INSERT queries. 

345 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

346 self._db.insert(self._calibs, *rowsToInsert) 

347 
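# Editorial note (illustrative only, not in the original source): the
# ``rowTimespan.difference(timespan)`` loop above replaces each overlapping row
# with whatever is left of its validity range. For example, decertifying
# [t2, t3) from a row valid over [t1, t4), with t1 < t2 < t3 < t4, deletes that
# row and inserts two replacement rows covering [t1, t2) and [t3, t4).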

348 def make_relation( 

349 self, 

350 *collections: CollectionRecord, 

351 columns: Set[str], 

352 context: SqlQueryContext, 

353 ) -> Relation: 

354 # Docstring inherited from DatasetRecordStorage. 

355 collection_types = {collection.type for collection in collections} 

356 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

357 TimespanReprClass = self._db.getTimespanRepresentation() 

358 # 

359 # There are two kinds of table in play here: 

360 # 

361 # - the static dataset table (with the dataset ID, dataset type ID, 

362 # run ID/name, and ingest date); 

363 # 

364 # - the dynamic tags/calibs table (with the dataset ID, dataset type 

365 # ID, collection ID/name, data ID, and possibly validity 

366 # range). 

367 # 

368 # That means that we might want to return a query against either table 

369 # or a JOIN of both, depending on which quantities the caller wants. 

370 # But the data ID is always included, which means we'll always include 

371 # the tags/calibs table and join in the static dataset table only if we 

372 # need things from it that we can't get from the tags/calibs table. 

373 # 

374 # Note that it's important that we include a WHERE constraint on both 

375 # tables for any column (e.g. dataset_type_id) that is in both when 

376 # it's given explicitly; not doing so can prevent the query planner from 

377 # using very important indexes. At present, we don't include those 

378 # redundant columns in the JOIN ON expression, however, because the 

379 # FOREIGN KEY (and its index) are defined only on dataset_id. 
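# Editorial sketch (not in the original source): with simplified, assumed column
# names, the query produced when "run" or "ingest_date" is requested looks
# roughly like
#
#     SELECT t.<data ID columns>, t.dataset_id, d.<run key>, d.ingest_date
#     FROM <dataset_type>_tags AS t
#     JOIN dataset AS d ON d.id = t.dataset_id
#     WHERE t.dataset_type_id = :dtid
#       AND d.dataset_type_id = :dtid
#       AND t.<collection FK> IN (:collection keys)
#
# and only the tags/calibs table is queried when those columns are not needed.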

380 tag_relation: Relation | None = None 

381 calib_relation: Relation | None = None 

382 if collection_types != {CollectionType.CALIBRATION}: 

383 # We'll need a subquery for the tags table if any of the given 

384 # collections are not a CALIBRATION collection. This intentionally 

385 # also fires when the list of collections is empty as a way to 

386 # create a dummy subquery that we know will fail. 

387 # We give the table an alias because it might appear multiple times 

388 # in the same query, for different dataset types. 

389 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags")) 

390 if "timespan" in columns: 

391 tags_parts.columns_available[DatasetColumnTag(self.datasetType.name, "timespan")] = ( 

392 TimespanReprClass.fromLiteral(Timespan(None, None)) 

393 ) 

394 tag_relation = self._finish_single_relation( 

395 tags_parts, 

396 columns, 

397 [ 

398 (record, rank) 

399 for rank, record in enumerate(collections) 

400 if record.type is not CollectionType.CALIBRATION 

401 ], 

402 context, 

403 ) 

404 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries." 

405 if CollectionType.CALIBRATION in collection_types: 

406 # If at least one collection is a CALIBRATION collection, we'll 

407 # need a subquery for the calibs table, and could include the 

408 # timespan as a result or constraint. 

409 assert ( 

410 self._calibs is not None 

411 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

412 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs")) 

413 if "timespan" in columns: 

414 calibs_parts.columns_available[DatasetColumnTag(self.datasetType.name, "timespan")] = ( 

415 TimespanReprClass.from_columns(calibs_parts.from_clause.columns) 

416 ) 

417 if "calib_pkey" in columns: 

418 # This is a private extension not included in the base class 

419 # interface, for internal use only in _buildCalibOverlapQuery, 

420 # which needs access to the autoincrement primary key for the 

421 # calib association table. 

422 calibs_parts.columns_available[DatasetColumnTag(self.datasetType.name, "calib_pkey")] = ( 

423 calibs_parts.from_clause.columns.id 

424 ) 

425 calib_relation = self._finish_single_relation( 

426 calibs_parts, 

427 columns, 

428 [ 

429 (record, rank) 

430 for rank, record in enumerate(collections) 

431 if record.type is CollectionType.CALIBRATION 

432 ], 

433 context, 

434 ) 

435 if tag_relation is not None: 

436 if calib_relation is not None: 

437 # daf_relation's chain operation does not automatically 

438 # deduplicate; it's more like SQL's UNION ALL. To get UNION 

439 # in SQL here, we add an explicit deduplication. 

440 return tag_relation.chain(calib_relation).without_duplicates() 

441 else: 

442 return tag_relation 

443 elif calib_relation is not None: 

444 return calib_relation 

445 else: 

446 raise AssertionError("Branch should be unreachable.") 

447 

448 def _finish_single_relation( 

449 self, 

450 payload: sql.Payload[LogicalColumn], 

451 requested_columns: Set[str], 

452 collections: Sequence[tuple[CollectionRecord, int]], 

453 context: SqlQueryContext, 

454 ) -> Relation: 

455 """Handle adding columns and WHERE terms that are not specific to 

456 either the tags or calibs tables. 

457 

458 Helper method for `make_relation`. 

459 

460 Parameters 

461 ---------- 

462 payload : `lsst.daf.relation.sql.Payload` 

463 SQL query parts under construction, to be modified in-place and 

464 used to construct the new relation. 

465 requested_columns : `~collections.abc.Set` [ `str` ] 

466 Columns the relation should include. 

467 collections : `~collections.abc.Sequence` [ `tuple` \ 

468 [ `CollectionRecord`, `int` ] ] 

469 Collections to search for the dataset and their ranks. 

470 context : `SqlQueryContext` 

471 Context that manages engines and state for the query. 

472 

473 Returns 

474 ------- 

475 relation : `lsst.daf.relation.Relation` 

476 New dataset query relation. 

477 """ 

478 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id) 

479 dataset_id_col = payload.from_clause.columns.dataset_id 

480 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()] 

481 # We always constrain and optionally retrieve the collection(s) via the 

482 # tags/calibs table. 

483 if len(collections) == 1: 

484 payload.where.append(collection_col == collections[0][0].key) 

485 if "collection" in requested_columns: 

486 payload.columns_available[DatasetColumnTag(self.datasetType.name, "collection")] = ( 

487 sqlalchemy.sql.literal(collections[0][0].key) 

488 ) 

489 else: 

490 assert collections, "The no-collections case should be in calling code for better diagnostics." 

491 payload.where.append(collection_col.in_([collection.key for collection, _ in collections])) 

492 if "collection" in requested_columns: 

493 payload.columns_available[DatasetColumnTag(self.datasetType.name, "collection")] = ( 

494 collection_col 

495 ) 

496 # Add rank, if requested, as a CASE-based calculation on the collection 

497 # column. 

498 if "rank" in requested_columns: 

499 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case( 

500 {record.key: rank for record, rank in collections}, 

501 value=collection_col, 

502 ) 
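# Editorial note (illustrative only): the dict/value form of
# ``sqlalchemy.sql.case`` used above compiles to a plain CASE expression over
# the collection column, roughly
#
#     CASE <collection FK> WHEN :key_0 THEN 0 WHEN :key_1 THEN 1 ... END
#
# giving each matched row the search position (rank) of its collection.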

503 # Add more column definitions, starting with the data ID. 

504 for dimension_name in self.datasetType.dimensions.required.names: 

505 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[ 

506 dimension_name 

507 ] 

508 # We can always get the dataset_id from the tags/calibs table. 

509 if "dataset_id" in requested_columns: 

510 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col 

511 # It's possible we now have everything we need, from just the 

512 # tags/calibs table. The things we might need to get from the static 

513 # dataset table are the run key and the ingest date. 

514 need_static_table = False 

515 if "run" in requested_columns: 

516 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN: 

517 # If we are searching exactly one RUN collection, we 

518 # know that if we find the dataset in that collection, 

519 # then that's the dataset's run; we don't need to 

520 # query for it. 

521 payload.columns_available[DatasetColumnTag(self.datasetType.name, "run")] = ( 

522 sqlalchemy.sql.literal(collections[0][0].key) 

523 ) 

524 else: 

525 payload.columns_available[DatasetColumnTag(self.datasetType.name, "run")] = ( 

526 self._static.dataset.columns[self._runKeyColumn] 

527 ) 

528 need_static_table = True 

529 # Ingest date can only come from the static table. 

530 if "ingest_date" in requested_columns: 

531 need_static_table = True 

532 payload.columns_available[DatasetColumnTag(self.datasetType.name, "ingest_date")] = ( 

533 self._static.dataset.columns.ingest_date 

534 ) 

535 # If we need the static table, join it in via dataset_id and 

536 # dataset_type_id 

537 if need_static_table: 

538 payload.from_clause = payload.from_clause.join( 

539 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id) 

540 ) 

541 # Also constrain dataset_type_id in static table in case that helps 

542 # generate a better plan. 

543 # We could also include this in the JOIN ON clause, but my guess is 

544 # that that's a good idea IFF it's in the foreign key, and right 

545 # now it isn't. 

546 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

547 leaf = context.sql_engine.make_leaf( 

548 payload.columns_available.keys(), 

549 payload=payload, 

550 name=self.datasetType.name, 

551 parameters={record.name: rank for record, rank in collections}, 

552 ) 

553 return leaf 

554 

555 def getDataId(self, id: DatasetId) -> DataCoordinate: 

556 """Return DataId for a dataset. 

557 

558 Parameters 

559 ---------- 

560 id : `DatasetId` 

561 Unique dataset identifier. 

562 

563 Returns 

564 ------- 

565 dataId : `DataCoordinate` 

566 DataId for the dataset. 

567 """ 

568 # This query could return multiple rows (one for each tagged collection 

569 # the dataset is in, plus one for its run collection), and we don't 

570 # care which of those we get. 

571 sql = ( 

572 self._tags.select() 

573 .where( 

574 sqlalchemy.sql.and_( 

575 self._tags.columns.dataset_id == id, 

576 self._tags.columns.dataset_type_id == self._dataset_type_id, 

577 ) 

578 ) 

579 .limit(1) 

580 ) 

581 with self._db.query(sql) as sql_result: 

582 row = sql_result.mappings().fetchone() 

583 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

584 return DataCoordinate.from_required_values( 

585 self.datasetType.dimensions.as_group(), 

586 tuple(row[dimension] for dimension in self.datasetType.dimensions.required.names), 

587 ) 

588 

589 

590class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

591 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

592 dataset IDs. 

593 """ 

594 

595 idMaker = DatasetIdFactory() 

596 """Factory for dataset IDs. In the future this factory may be shared with 

597 other classes (e.g. Registry).""" 

598 

599 def insert( 

600 self, 

601 run: RunRecord, 

602 dataIds: Iterable[DataCoordinate], 

603 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

604 ) -> Iterator[DatasetRef]: 

605 # Docstring inherited from DatasetRecordStorage. 

606 

607 # Current timestamp, type depends on schema version. Use microsecond 

608 # precision for astropy time to keep things consistent with 

609 # TIMESTAMP(6) SQL type. 

610 timestamp: datetime.datetime | astropy.time.Time 

611 if self._use_astropy: 

612 # Astropy `now()` precision should be the same as `datetime.now()`, which 

613 # should mean microseconds. 

614 timestamp = astropy.time.Time.now() 

615 else: 

616 timestamp = datetime.datetime.now(datetime.UTC) 

617 

618 # Iterate over data IDs, transforming a possibly-single-pass iterable 

619 # into a list. 

620 dataIdList: list[DataCoordinate] = [] 

621 rows = [] 

622 summary = CollectionSummary() 

623 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

624 dataIdList.append(dataId) 

625 rows.append( 

626 { 

627 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

628 "dataset_type_id": self._dataset_type_id, 

629 self._runKeyColumn: run.key, 

630 "ingest_date": timestamp, 

631 } 

632 ) 

633 

634 with self._db.transaction(): 

635 # Insert into the static dataset table. 

636 self._db.insert(self._static.dataset, *rows) 

637 # Update the summary tables for this collection in case this is the 

638 # first time this dataset type or these governor values will be 

639 # inserted there. 

640 self._summaries.update(run, [self._dataset_type_id], summary) 

641 # Combine the generated dataset_id values and data ID fields to 

642 # form rows to be inserted into the tags table. 

643 protoTagsRow = { 

644 "dataset_type_id": self._dataset_type_id, 

645 self._collections.getCollectionForeignKeyName(): run.key, 

646 } 

647 tagsRows = [ 

648 dict(protoTagsRow, dataset_id=row["id"], **dataId.required) 

649 for dataId, row in zip(dataIdList, rows, strict=True) 

650 ] 

651 # Insert those rows into the tags table. 

652 self._db.insert(self._tags, *tagsRows) 

653 

654 for dataId, row in zip(dataIdList, rows, strict=True): 

655 yield DatasetRef( 

656 datasetType=self.datasetType, 

657 dataId=dataId, 

658 id=row["id"], 

659 run=run.name, 

660 ) 

661 

662 def import_( 

663 self, 

664 run: RunRecord, 

665 datasets: Iterable[DatasetRef], 

666 ) -> Iterator[DatasetRef]: 

667 # Docstring inherited from DatasetRecordStorage. 

668 

669 # Current timestamp, type depends on schema version. 

670 if self._use_astropy: 

671 # Astropy `now()` precision should be the same as `datetime.now()`, which 

672 # should mean microseconds. 

673 timestamp = sqlalchemy.sql.literal(astropy.time.Time.now(), type_=ddl.AstropyTimeNsecTai) 

674 else: 

675 timestamp = sqlalchemy.sql.literal(datetime.datetime.now(datetime.UTC)) 

676 

677 # Iterate over data IDs, transforming a possibly-single-pass iterable 

678 # into a list. 

679 dataIds: dict[DatasetId, DataCoordinate] = {} 

680 summary = CollectionSummary() 

681 for dataset in summary.add_datasets_generator(datasets): 

682 dataIds[dataset.id] = dataset.dataId 

683 

684 # We'll insert all new rows into a temporary table 

685 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False) 

686 collFkName = self._collections.getCollectionForeignKeyName() 

687 protoTagsRow = { 

688 "dataset_type_id": self._dataset_type_id, 

689 collFkName: run.key, 

690 } 

691 tmpRows = [ 

692 dict(protoTagsRow, dataset_id=dataset_id, **dataId.required) 

693 for dataset_id, dataId in dataIds.items() 

694 ] 

695 with self._db.transaction(for_temp_tables=True), self._db.temporary_table(tableSpec) as tmp_tags: 

696 # Store all incoming data in a temporary table. 

697 self._db.insert(tmp_tags, *tmpRows) 

698 

699 # There are some checks that we want to make for consistency 

700 # of the new datasets with existing ones. 

701 self._validateImport(tmp_tags, run) 

702 

703 # Before we merge temporary table into dataset/tags we need to 

704 # drop datasets which are already there (and do not conflict). 

705 self._db.deleteWhere( 

706 tmp_tags, 

707 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

708 ) 

709 

710 # Copy it into the dataset table; we need to re-label some columns. 

711 self._db.insert( 

712 self._static.dataset, 

713 select=sqlalchemy.sql.select( 

714 tmp_tags.columns.dataset_id.label("id"), 

715 tmp_tags.columns.dataset_type_id, 

716 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

717 timestamp.label("ingest_date"), 

718 ), 

719 ) 

720 

721 # Update the summary tables for this collection in case this 

722 # is the first time this dataset type or these governor values 

723 # will be inserted there. 

724 self._summaries.update(run, [self._dataset_type_id], summary) 

725 

726 # Copy it into tags table. 

727 self._db.insert(self._tags, select=tmp_tags.select()) 

728 

729 # Return refs in the same order as in the input list. 

730 for dataset_id, dataId in dataIds.items(): 

731 yield DatasetRef( 

732 datasetType=self.datasetType, 

733 id=dataset_id, 

734 dataId=dataId, 

735 run=run.name, 

736 ) 

737 
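# Editorial sketch (not in the original source; table names simplified): the
# temporary-table import in ``import_`` above amounts to
#
#     INSERT INTO tmp_tags ...                       -- all incoming refs
#     -- consistency checks against dataset/tags (_validateImport below)
#     DELETE FROM tmp_tags
#         WHERE dataset_id IN (SELECT id FROM dataset)
#     INSERT INTO dataset (id, dataset_type_id, <run key>, ingest_date)
#         SELECT dataset_id, dataset_type_id, <collection FK>, <timestamp>
#         FROM tmp_tags
#     INSERT INTO tags SELECT * FROM tmp_tags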

738 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

739 """Validate imported refs against existing datasets. 

740 

741 Parameters 

742 ---------- 

743 tmp_tags : `sqlalchemy.schema.Table` 

744 Temporary table with new datasets and the same schema as tags 

745 table. 

746 run : `RunRecord` 

747 The record object describing the `~CollectionType.RUN` collection. 

748 

749 Raises 

750 ------ 

751 ConflictingDefinitionError 

752 Raised if new datasets conflict with existing ones. 

753 """ 

754 dataset = self._static.dataset 

755 tags = self._tags 

756 collFkName = self._collections.getCollectionForeignKeyName() 

757 

758 # Check that existing datasets have the same dataset type and 

759 # run. 

760 query = ( 

761 sqlalchemy.sql.select( 

762 dataset.columns.id.label("dataset_id"), 

763 dataset.columns.dataset_type_id.label("dataset_type_id"), 

764 tmp_tags.columns.dataset_type_id.label("new_dataset_type_id"), 

765 dataset.columns[self._runKeyColumn].label("run"), 

766 tmp_tags.columns[collFkName].label("new_run"), 

767 ) 

768 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

769 .where( 

770 sqlalchemy.sql.or_( 

771 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

772 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

773 ) 

774 ) 

775 .limit(1) 

776 ) 

777 with self._db.query(query) as result: 

778 # Only include the first one in the exception message 

779 if (row := result.first()) is not None: 

780 existing_run = self._collections[row.run].name 

781 new_run = self._collections[row.new_run].name 

782 if row.dataset_type_id == self._dataset_type_id: 

783 if row.new_dataset_type_id == self._dataset_type_id:  783 ↛ 789: line 783 didn't jump to line 789, because the condition on line 783 was never false

784 raise ConflictingDefinitionError( 

785 f"Current run {existing_run!r} and new run {new_run!r} do not agree for " 

786 f"dataset {row.dataset_id}." 

787 ) 

788 else: 

789 raise ConflictingDefinitionError( 

790 f"Dataset {row.dataset_id} was provided with type {self.datasetType.name!r} " 

791 f"in run {new_run!r}, but was already defined with type ID {row.dataset_type_id} " 

792 f"in run {run!r}." 

793 ) 

794 else: 

795 raise ConflictingDefinitionError( 

796 f"Dataset {row.dataset_id} was provided with type ID {row.new_dataset_type_id} " 

797 f"in run {new_run!r}, but was already defined with type {self.datasetType.name!r} " 

798 f"in run {run!r}." 

799 ) 

800 

801 # Check that matching dataset in tags table has the same DataId. 

802 query = ( 

803 sqlalchemy.sql.select( 

804 tags.columns.dataset_id, 

805 tags.columns.dataset_type_id.label("type_id"), 

806 tmp_tags.columns.dataset_type_id.label("new_type_id"), 

807 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

808 *[ 

809 tmp_tags.columns[dim].label(f"new_{dim}") 

810 for dim in self.datasetType.dimensions.required.names 

811 ], 

812 ) 

813 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

814 .where( 

815 sqlalchemy.sql.or_( 

816 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

817 *[ 

818 tags.columns[dim] != tmp_tags.columns[dim] 

819 for dim in self.datasetType.dimensions.required.names 

820 ], 

821 ) 

822 ) 

823 .limit(1) 

824 ) 

825 

826 with self._db.query(query) as result: 

827 if (row := result.first()) is not None: 

828 # Only include the first one in the exception message 

829 raise ConflictingDefinitionError( 

830 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

831 ) 

832 

833 # Check that matching run+dataId have the same dataset ID. 

834 query = ( 

835 sqlalchemy.sql.select( 

836 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

837 tags.columns.dataset_id, 

838 tmp_tags.columns.dataset_id.label("new_dataset_id"), 

839 tags.columns[collFkName], 

840 tmp_tags.columns[collFkName].label(f"new_{collFkName}"), 

841 ) 

842 .select_from( 

843 tags.join( 

844 tmp_tags, 

845 sqlalchemy.sql.and_( 

846 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

847 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

848 *[ 

849 tags.columns[dim] == tmp_tags.columns[dim] 

850 for dim in self.datasetType.dimensions.required.names 

851 ], 

852 ), 

853 ) 

854 ) 

855 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

856 .limit(1) 

857 ) 

858 with self._db.query(query) as result: 

859 # only include the first one in the exception message 

860 if (row := result.first()) is not None: 

861 data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names} 

862 existing_collection = self._collections[getattr(row, collFkName)].name 

863 new_collection = self._collections[getattr(row, f"new_{collFkName}")].name 

864 raise ConflictingDefinitionError( 

865 f"Dataset with type {self.datasetType.name!r} and data ID {data_id} " 

866 f"has ID {row.dataset_id} in existing collection {existing_collection!r} " 

867 f"but ID {row.new_dataset_id} in new collection {new_collection!r}." 

868 )