Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 95%

260 statements  

coverage.py v7.3.2, created at 2023-12-01 10:59 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28 

29from __future__ import annotations 

30 

31from .... import ddl 

32 

33__all__ = ("ByDimensionsDatasetRecordStorage",) 

34 

35import datetime 

36from collections.abc import Callable, Iterable, Iterator, Sequence, Set 

37from typing import TYPE_CHECKING 

38 

39import astropy.time 

40import sqlalchemy 

41from lsst.daf.relation import Relation, sql 

42 

43from ...._column_tags import DatasetColumnTag, DimensionKeyColumnTag 

44from ...._column_type_info import LogicalColumn 

45from ...._dataset_ref import DatasetId, DatasetIdFactory, DatasetIdGenEnum, DatasetRef 

46from ...._dataset_type import DatasetType 

47from ...._timespan import Timespan 

48from ....dimensions import DataCoordinate 

49from ..._collection_summary import CollectionSummary 

50from ..._collection_type import CollectionType 

51from ..._exceptions import CollectionTypeError, ConflictingDefinitionError 

52from ...interfaces import DatasetRecordStorage 

53from ...queries import SqlQueryContext 

54from .tables import makeTagTableSpec 

55 

56if TYPE_CHECKING: 

57 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

58 from .summaries import CollectionSummaryManager 

59 from .tables import StaticDatasetTablesTuple 

60 

61 

62class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

63 """Dataset record storage implementation paired with 

64 `ByDimensionsDatasetRecordStorageManagerUUID`; see that class for more 

65 information. 

66 

67 Instances of this class should never be constructed directly; use 

68 `DatasetRecordStorageManager.register` instead. 

69 """ 

70 

71 def __init__( 

72 self, 

73 *, 

74 datasetType: DatasetType, 

75 db: Database, 

76 dataset_type_id: int, 

77 collections: CollectionManager, 

78 static: StaticDatasetTablesTuple, 

79 summaries: CollectionSummaryManager, 

80 tags_table_factory: Callable[[], sqlalchemy.schema.Table], 

81 use_astropy_ingest_date: bool, 

82 calibs_table_factory: Callable[[], sqlalchemy.schema.Table] | None, 

83 ): 

84 super().__init__(datasetType=datasetType) 

85 self._dataset_type_id = dataset_type_id 

86 self._db = db 

87 self._collections = collections 

88 self._static = static 

89 self._summaries = summaries 

90 self._tags_table_factory = tags_table_factory 

91 self._calibs_table_factory = calibs_table_factory 

92 self._runKeyColumn = collections.getRunForeignKeyName() 

93 self._use_astropy = use_astropy_ingest_date 

94 self._tags_table: sqlalchemy.schema.Table | None = None 

95 self._calibs_table: sqlalchemy.schema.Table | None = None 

96 

97 @property 

98 def _tags(self) -> sqlalchemy.schema.Table: 

99 if self._tags_table is None: 

100 self._tags_table = self._tags_table_factory() 

101 return self._tags_table 

102 

103 @property 

104 def _calibs(self) -> sqlalchemy.schema.Table | None: 

105 if self._calibs_table is None: 

106 if self._calibs_table_factory is None: 106 ↛ 107 (line 106 didn't jump to line 107, because the condition on line 106 was never true)

107 return None 

108 self._calibs_table = self._calibs_table_factory() 

109 return self._calibs_table 

110 

111 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

112 # Docstring inherited from DatasetRecordStorage. 

113 # Only delete from common dataset table; ON DELETE foreign key clauses 

114 # will handle the rest. 

115 self._db.delete( 

116 self._static.dataset, 

117 ["id"], 

118 *[{"id": dataset.id} for dataset in datasets], 

119 ) 

120 

121 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

122 # Docstring inherited from DatasetRecordStorage. 

123 if collection.type is not CollectionType.TAGGED: 123 ↛ 124 (line 123 didn't jump to line 124, because the condition on line 123 was never true)

124 raise TypeError( 

125 f"Cannot associate into collection '{collection.name}' " 

126 f"of type {collection.type.name}; must be TAGGED." 

127 ) 

128 protoRow = { 

129 self._collections.getCollectionForeignKeyName(): collection.key, 

130 "dataset_type_id": self._dataset_type_id, 

131 } 

132 rows = [] 

133 summary = CollectionSummary() 

134 for dataset in summary.add_datasets_generator(datasets): 

135 rows.append(dict(protoRow, dataset_id=dataset.id, **dataset.dataId.required)) 

136 # Update the summary tables for this collection in case this is the 

137 # first time this dataset type or these governor values will be 

138 # inserted there. 

139 self._summaries.update(collection, [self._dataset_type_id], summary) 

140 # Update the tag table itself. 

141 self._db.replace(self._tags, *rows) 

142 
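    # Illustrative sketch (not part of the original module): associating refs
    # with a TAGGED collection and later removing one of them. Variable names
    # are assumptions; the signatures match associate/disassociate here.
    #
    #     storage.associate(tagged_record, [ref_a, ref_b])
    #     storage.disassociate(tagged_record, [ref_a])
    #     # A non-TAGGED record raises TypeError before any rows are written.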

143 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

144 # Docstring inherited from DatasetRecordStorage. 

145 if collection.type is not CollectionType.TAGGED: 145 ↛ 146 (line 145 didn't jump to line 146, because the condition on line 145 was never true)

146 raise TypeError( 

147 f"Cannot disassociate from collection '{collection.name}' " 

148 f"of type {collection.type.name}; must be TAGGED." 

149 ) 

150 rows = [ 

151 { 

152 "dataset_id": dataset.id, 

153 self._collections.getCollectionForeignKeyName(): collection.key, 

154 } 

155 for dataset in datasets 

156 ] 

157 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

158 

159 def _buildCalibOverlapQuery( 

160 self, 

161 collection: CollectionRecord, 

162 data_ids: set[DataCoordinate] | None, 

163 timespan: Timespan, 

164 context: SqlQueryContext, 

165 ) -> Relation: 

166 relation = self.make_relation( 

167 collection, columns={"timespan", "dataset_id", "calib_pkey"}, context=context 

168 ).with_rows_satisfying( 

169 context.make_timespan_overlap_predicate( 

170 DatasetColumnTag(self.datasetType.name, "timespan"), timespan 

171 ), 

172 ) 

173 if data_ids is not None: 

174 relation = relation.join( 

175 context.make_data_id_relation( 

176 data_ids, self.datasetType.dimensions.required.names 

177 ).transferred_to(context.sql_engine), 

178 ) 

179 return relation 

180 

181 def certify( 

182 self, 

183 collection: CollectionRecord, 

184 datasets: Iterable[DatasetRef], 

185 timespan: Timespan, 

186 context: SqlQueryContext, 

187 ) -> None: 

188 # Docstring inherited from DatasetRecordStorage. 

189 if self._calibs is None: 189 ↛ 190 (line 189 didn't jump to line 190, because the condition on line 189 was never true)

190 raise CollectionTypeError( 

191 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

192 "DatasetType.isCalibration() is False." 

193 ) 

194 if collection.type is not CollectionType.CALIBRATION: 194 ↛ 195 (line 194 didn't jump to line 195, because the condition on line 194 was never true)

195 raise CollectionTypeError( 

196 f"Cannot certify into collection '{collection.name}' " 

197 f"of type {collection.type.name}; must be CALIBRATION." 

198 ) 

199 TimespanReprClass = self._db.getTimespanRepresentation() 

200 protoRow = { 

201 self._collections.getCollectionForeignKeyName(): collection.key, 

202 "dataset_type_id": self._dataset_type_id, 

203 } 

204 rows = [] 

205 dataIds: set[DataCoordinate] | None = ( 

206 set() if not TimespanReprClass.hasExclusionConstraint() else None 

207 ) 

208 summary = CollectionSummary() 

209 for dataset in summary.add_datasets_generator(datasets): 

210 row = dict(protoRow, dataset_id=dataset.id, **dataset.dataId.required) 

211 TimespanReprClass.update(timespan, result=row) 

212 rows.append(row) 

213 if dataIds is not None: 213 ↛ 209 (line 213 didn't jump to line 209, because the condition on line 213 was never false)

214 dataIds.add(dataset.dataId) 

215 # Update the summary tables for this collection in case this is the 

216 # first time this dataset type or these governor values will be 

217 # inserted there. 

218 self._summaries.update(collection, [self._dataset_type_id], summary) 

219 # Update the association table itself. 

220 if TimespanReprClass.hasExclusionConstraint(): 220 ↛ 223 (line 220 didn't jump to line 223, because the condition on line 220 was never true)

221 # Rely on database constraint to enforce invariants; we just 

222 # reraise the exception for consistency across DB engines. 

223 try: 

224 self._db.insert(self._calibs, *rows) 

225 except sqlalchemy.exc.IntegrityError as err: 

226 raise ConflictingDefinitionError( 

227 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

228 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

229 ) from err 

230 else: 

231 # Have to implement exclusion constraint ourselves. 

232 # Start by building a SELECT query for any rows that would overlap 

233 # this one. 

234 relation = self._buildCalibOverlapQuery(collection, dataIds, timespan, context) 

235 # Acquire a table lock to ensure there are no concurrent writes 

236 # that could invalidate our checking before we finish the inserts. We 

237 # use a SAVEPOINT in case there is an outer transaction that a 

238 # failure here should not roll back. 

239 with self._db.transaction(lock=[self._calibs], savepoint=True): 

240 # Enter SqlQueryContext in case we need to use a temporary 

241 # table to include the given data IDs in the query. Note that 

242 # by doing this inside the transaction, we make sure it doesn't 

243 # attempt to close the session when it's done, since it just 

244 # sees an already-open session that it knows it shouldn't 

245 # manage. 

246 with context: 

247 # Run the check SELECT query. 

248 conflicting = context.count(context.process(relation)) 

249 if conflicting > 0: 

250 raise ConflictingDefinitionError( 

251 f"{conflicting} validity range conflicts certifying datasets of type " 

252 f"{self.datasetType.name} into {collection.name} for range " 

253 f"[{timespan.begin}, {timespan.end})." 

254 ) 

255 # Proceed with the insert. 

256 self._db.insert(self._calibs, *rows) 

257 
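    # Illustrative sketch (not part of the original module): certifying the
    # same data ID into a CALIBRATION collection with overlapping validity
    # ranges raises ConflictingDefinitionError, whether the database enforces
    # an exclusion constraint or the SELECT-then-INSERT emulation above runs.
    # Variable names and timestamps are assumptions.
    #
    #     storage.certify(calib_record, [ref], Timespan(t1, t3), context=ctx)
    #     storage.certify(calib_record, [ref], Timespan(t2, t4), context=ctx)
    #     # Second call fails when t2 < t3, i.e. the two ranges overlap.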

258 def decertify( 

259 self, 

260 collection: CollectionRecord, 

261 timespan: Timespan, 

262 *, 

263 dataIds: Iterable[DataCoordinate] | None = None, 

264 context: SqlQueryContext, 

265 ) -> None: 

266 # Docstring inherited from DatasetRecordStorage. 

267 if self._calibs is None: 267 ↛ 268 (line 267 didn't jump to line 268, because the condition on line 267 was never true)

268 raise CollectionTypeError( 

269 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

270 "DatasetType.isCalibration() is False." 

271 ) 

272 if collection.type is not CollectionType.CALIBRATION: 272 ↛ 273 (line 272 didn't jump to line 273, because the condition on line 272 was never true)

273 raise CollectionTypeError( 

274 f"Cannot decertify from collection '{collection.name}' " 

275 f"of type {collection.type.name}; must be CALIBRATION." 

276 ) 

277 TimespanReprClass = self._db.getTimespanRepresentation() 

278 # Construct a SELECT query to find all rows that overlap our inputs. 

279 dataIdSet: set[DataCoordinate] | None 

280 if dataIds is not None: 

281 dataIdSet = set(dataIds) 

282 else: 

283 dataIdSet = None 

284 relation = self._buildCalibOverlapQuery(collection, dataIdSet, timespan, context) 

285 calib_pkey_tag = DatasetColumnTag(self.datasetType.name, "calib_pkey") 

286 dataset_id_tag = DatasetColumnTag(self.datasetType.name, "dataset_id") 

287 timespan_tag = DatasetColumnTag(self.datasetType.name, "timespan") 

288 data_id_tags = [ 

289 (name, DimensionKeyColumnTag(name)) for name in self.datasetType.dimensions.required.names 

290 ] 

291 # Set up collections to populate with the rows we'll want to modify. 

292 # The insert rows will have the same values for collection and 

293 # dataset type. 

294 protoInsertRow = { 

295 self._collections.getCollectionForeignKeyName(): collection.key, 

296 "dataset_type_id": self._dataset_type_id, 

297 } 

298 rowsToDelete = [] 

299 rowsToInsert = [] 

300 # Acquire a table lock to ensure there are no concurrent writes 

301 # between the SELECT and the DELETE and INSERT queries based on it. 

302 with self._db.transaction(lock=[self._calibs], savepoint=True): 

303 # Enter SqlQueryContext in case we need to use a temporary table to 

304 # include the given data IDs in the query (see similar block in 

305 # certify for details). 

306 with context: 

307 for row in context.fetch_iterable(relation): 

308 rowsToDelete.append({"id": row[calib_pkey_tag]}) 

309 # Construct the insert row(s) by copying the prototype row, 

310 # then adding the dimension column values, then adding 

311 # what's left of the timespan from that row after we 

312 # subtract the given timespan. 

313 newInsertRow = protoInsertRow.copy() 

314 newInsertRow["dataset_id"] = row[dataset_id_tag] 

315 for name, tag in data_id_tags: 

316 newInsertRow[name] = row[tag] 

317 rowTimespan = row[timespan_tag] 

318 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

319 for diffTimespan in rowTimespan.difference(timespan): 

320 rowsToInsert.append( 

321 TimespanReprClass.update(diffTimespan, result=newInsertRow.copy()) 

322 ) 

323 # Run the DELETE and INSERT queries. 

324 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

325 self._db.insert(self._calibs, *rowsToInsert) 

326 
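    # Illustrative sketch (not part of the original module): the row-splitting
    # in decertify relies on Timespan.difference, which yields the pieces of
    # an existing validity range left after removing the decertified range.
    # The endpoints below are assumptions.
    #
    #     existing = Timespan(t1, t4)
    #     removed = Timespan(t2, t3)
    #     list(existing.difference(removed))
    #     # -> roughly [Timespan(t1, t2), Timespan(t3, t4)]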

327 def make_relation( 

328 self, 

329 *collections: CollectionRecord, 

330 columns: Set[str], 

331 context: SqlQueryContext, 

332 ) -> Relation: 

333 # Docstring inherited from DatasetRecordStorage. 

334 collection_types = {collection.type for collection in collections} 

335 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

336 TimespanReprClass = self._db.getTimespanRepresentation() 

337 # 

338 # There are two kinds of table in play here: 

339 # 

340 # - the static dataset table (with the dataset ID, dataset type ID, 

341 # run ID/name, and ingest date); 

342 # 

343 # - the dynamic tags/calibs table (with the dataset ID, dataset type 

344 # ID, collection ID/name, data ID, and possibly validity 

345 # range). 

346 # 

347 # That means that we might want to return a query against either table 

348 # or a JOIN of both, depending on which quantities the caller wants. 

349 # But the data ID is always included, which means we'll always include 

350 # the tags/calibs table and join in the static dataset table only if we 

351 # need things from it that we can't get from the tags/calibs table. 

352 # 

353 # Note that it's important that we include a WHERE constraint on both 

354 # tables for any column (e.g. dataset_type_id) that is in both when 

355 # it's given explicitly; not doing so can prevent the query planner from 

356 # using very important indexes. At present, we don't include those 

357 # redundant columns in the JOIN ON expression, however, because the 

358 # FOREIGN KEY (and its index) are defined only on dataset_id. 

359 tag_relation: Relation | None = None 

360 calib_relation: Relation | None = None 

361 if collection_types != {CollectionType.CALIBRATION}: 

362 # We'll need a subquery for the tags table if any of the given 

363 # collections are not a CALIBRATION collection. This intentionally 

364 # also fires when the list of collections is empty as a way to 

365 # create a dummy subquery that we know will fail. 

366 # We give the table an alias because it might appear multiple times 

367 # in the same query, for different dataset types. 

368 tags_parts = sql.Payload[LogicalColumn](self._tags.alias(f"{self.datasetType.name}_tags")) 

369 if "timespan" in columns: 

370 tags_parts.columns_available[ 

371 DatasetColumnTag(self.datasetType.name, "timespan") 

372 ] = TimespanReprClass.fromLiteral(Timespan(None, None)) 

373 tag_relation = self._finish_single_relation( 

374 tags_parts, 

375 columns, 

376 [ 

377 (record, rank) 

378 for rank, record in enumerate(collections) 

379 if record.type is not CollectionType.CALIBRATION 

380 ], 

381 context, 

382 ) 

383 assert "calib_pkey" not in columns, "For internal use only, and only for pure-calib queries." 

384 if CollectionType.CALIBRATION in collection_types: 

385 # If at least one collection is a CALIBRATION collection, we'll 

386 # need a subquery for the calibs table, and could include the 

387 # timespan as a result or constraint. 

388 assert ( 

389 self._calibs is not None 

390 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

391 calibs_parts = sql.Payload[LogicalColumn](self._calibs.alias(f"{self.datasetType.name}_calibs")) 

392 if "timespan" in columns: 

393 calibs_parts.columns_available[ 

394 DatasetColumnTag(self.datasetType.name, "timespan") 

395 ] = TimespanReprClass.from_columns(calibs_parts.from_clause.columns) 

396 if "calib_pkey" in columns: 

397 # This is a private extension not included in the base class 

398 # interface, for internal use only in _buildCalibOverlapQuery, 

399 # which needs access to the autoincrement primary key for the 

400 # calib association table. 

401 calibs_parts.columns_available[ 

402 DatasetColumnTag(self.datasetType.name, "calib_pkey") 

403 ] = calibs_parts.from_clause.columns.id 

404 calib_relation = self._finish_single_relation( 

405 calibs_parts, 

406 columns, 

407 [ 

408 (record, rank) 

409 for rank, record in enumerate(collections) 

410 if record.type is CollectionType.CALIBRATION 

411 ], 

412 context, 

413 ) 

414 if tag_relation is not None: 

415 if calib_relation is not None: 

416 # daf_relation's chain operation does not automatically 

417 # deduplicate; it's more like SQL's UNION ALL. To get UNION 

418 # in SQL here, we add an explicit deduplication. 

419 return tag_relation.chain(calib_relation).without_duplicates() 

420 else: 

421 return tag_relation 

422 elif calib_relation is not None: 

423 return calib_relation 

424 else: 

425 raise AssertionError("Branch should be unreachable.") 

426 
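    # Illustrative sketch (not part of the original module): when both
    # non-CALIBRATION and CALIBRATION collections are searched, the relation
    # built above corresponds to SQL of roughly this shape (table and column
    # names are simplified assumptions):
    #
    #     SELECT dataset_id, <data ID columns> FROM <dataset_type>_tags
    #         WHERE dataset_type_id = :dtid AND collection IN (...)
    #     UNION
    #     SELECT dataset_id, <data ID columns> FROM <dataset_type>_calibs
    #         WHERE dataset_type_id = :dtid AND collection IN (...)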

427 def _finish_single_relation( 

428 self, 

429 payload: sql.Payload[LogicalColumn], 

430 requested_columns: Set[str], 

431 collections: Sequence[tuple[CollectionRecord, int]], 

432 context: SqlQueryContext, 

433 ) -> Relation: 

434 """Handle adding columns and WHERE terms that are not specific to 

435 either the tags or calibs tables. 

436 

437 Helper method for `make_relation`. 

438 

439 Parameters 

440 ---------- 

441 payload : `lsst.daf.relation.sql.Payload` 

442 SQL query parts under construction, to be modified in-place and 

443 used to construct the new relation. 

444 requested_columns : `~collections.abc.Set` [ `str` ] 

445 Columns the relation should include. 

446 collections : `~collections.abc.Sequence` [ `tuple` \ 

447 [ `CollectionRecord`, `int` ] ] 

448 Collections to search for the dataset and their ranks. 

449 context : `SqlQueryContext` 

450 Context that manages engines and state for the query. 

451 

452 Returns 

453 ------- 

454 relation : `lsst.daf.relation.Relation` 

455 New dataset query relation. 

456 """ 

457 payload.where.append(payload.from_clause.columns.dataset_type_id == self._dataset_type_id) 

458 dataset_id_col = payload.from_clause.columns.dataset_id 

459 collection_col = payload.from_clause.columns[self._collections.getCollectionForeignKeyName()] 

460 # We always constrain and optionally retrieve the collection(s) via the 

461 # tags/calibs table. 

462 if len(collections) == 1: 

463 payload.where.append(collection_col == collections[0][0].key) 

464 if "collection" in requested_columns: 

465 payload.columns_available[ 

466 DatasetColumnTag(self.datasetType.name, "collection") 

467 ] = sqlalchemy.sql.literal(collections[0][0].key) 

468 else: 

469 assert collections, "The no-collections case should be in calling code for better diagnostics." 

470 payload.where.append(collection_col.in_([collection.key for collection, _ in collections])) 

471 if "collection" in requested_columns: 

472 payload.columns_available[ 

473 DatasetColumnTag(self.datasetType.name, "collection") 

474 ] = collection_col 

475 # Add rank, if requested, as a CASE-based calculation on the collection 

476 # column. 

477 if "rank" in requested_columns: 

478 payload.columns_available[DatasetColumnTag(self.datasetType.name, "rank")] = sqlalchemy.sql.case( 

479 {record.key: rank for record, rank in collections}, 

480 value=collection_col, 

481 ) 

482 # Add more column definitions, starting with the data ID. 

483 for dimension_name in self.datasetType.dimensions.required.names: 

484 payload.columns_available[DimensionKeyColumnTag(dimension_name)] = payload.from_clause.columns[ 

485 dimension_name 

486 ] 

487 # We can always get the dataset_id from the tags/calibs table. 

488 if "dataset_id" in requested_columns: 

489 payload.columns_available[DatasetColumnTag(self.datasetType.name, "dataset_id")] = dataset_id_col 

490 # It's possible we now have everything we need, from just the 

491 # tags/calibs table. The things we might need to get from the static 

492 # dataset table are the run key and the ingest date. 

493 need_static_table = False 

494 if "run" in requested_columns: 

495 if len(collections) == 1 and collections[0][0].type is CollectionType.RUN: 

496 # If we are searching exactly one RUN collection, we 

497 # know that if we find the dataset in that collection, 

498 # then that's the dataset's run; we don't need to 

499 # query for it. 

500 payload.columns_available[ 

501 DatasetColumnTag(self.datasetType.name, "run") 

502 ] = sqlalchemy.sql.literal(collections[0][0].key) 

503 else: 

504 payload.columns_available[ 

505 DatasetColumnTag(self.datasetType.name, "run") 

506 ] = self._static.dataset.columns[self._runKeyColumn] 

507 need_static_table = True 

508 # Ingest date can only come from the static table. 

509 if "ingest_date" in requested_columns: 

510 need_static_table = True 

511 payload.columns_available[ 

512 DatasetColumnTag(self.datasetType.name, "ingest_date") 

513 ] = self._static.dataset.columns.ingest_date 

514 # If we need the static table, join it in via dataset_id and 

515 # dataset_type_id 

516 if need_static_table: 

517 payload.from_clause = payload.from_clause.join( 

518 self._static.dataset, onclause=(dataset_id_col == self._static.dataset.columns.id) 

519 ) 

520 # Also constrain dataset_type_id in the static table in case that helps 

521 # generate a better plan. 

522 # We could also include this in the JOIN ON clause, but my guess is 

523 # that that's a good idea IFF it's in the foreign key, and right 

524 # now it isn't. 

525 payload.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

526 leaf = context.sql_engine.make_leaf( 

527 payload.columns_available.keys(), 

528 payload=payload, 

529 name=self.datasetType.name, 

530 parameters={record.name: rank for record, rank in collections}, 

531 ) 

532 return leaf 

533 
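    # Illustrative sketch (not part of the original module): the "rank" column
    # above uses SQLAlchemy's dict form of case(), which renders as a simple
    # CASE over the collection column; the keys below are assumptions.
    #
    #     sqlalchemy.sql.case({key_a: 0, key_b: 1}, value=collection_col)
    #     # -> CASE <collection column> WHEN :key_a THEN 0 WHEN :key_b THEN 1 END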

534 def getDataId(self, id: DatasetId) -> DataCoordinate: 

535 """Return DataId for a dataset. 

536 

537 Parameters 

538 ---------- 

539 id : `DatasetId` 

540 Unique dataset identifier. 

541 

542 Returns 

543 ------- 

544 dataId : `DataCoordinate` 

545 DataId for the dataset. 

546 """ 

547 # This query could return multiple rows (one for each tagged collection 

548 # the dataset is in, plus one for its run collection), and we don't 

549 # care which of those we get. 

550 sql = ( 

551 self._tags.select() 

552 .where( 

553 sqlalchemy.sql.and_( 

554 self._tags.columns.dataset_id == id, 

555 self._tags.columns.dataset_type_id == self._dataset_type_id, 

556 ) 

557 ) 

558 .limit(1) 

559 ) 

560 with self._db.query(sql) as sql_result: 

561 row = sql_result.mappings().fetchone() 

562 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

563 return DataCoordinate.from_required_values( 

564 self.datasetType.dimensions.as_group(), 

565 tuple(row[dimension] for dimension in self.datasetType.dimensions.required.names), 

566 ) 

567 
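    # Illustrative sketch (not part of the original module): recovering the
    # data ID for a known dataset UUID; variable names are assumptions.
    #
    #     data_id = storage.getDataId(ref.id)
    #     # data_id.required matches the dimension values stored in the tags table.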

568 

569class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

570 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

571 dataset IDs. 

572 """ 

573 

574 idMaker = DatasetIdFactory() 

575 """Factory for dataset IDs. In the future this factory may be shared with 

576 other classes (e.g. Registry).""" 

577 

578 def insert( 

579 self, 

580 run: RunRecord, 

581 dataIds: Iterable[DataCoordinate], 

582 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

583 ) -> Iterator[DatasetRef]: 

584 # Docstring inherited from DatasetRecordStorage. 

585 

586 # Current timestamp, type depends on schema version. Use microsecond 

587 # precision for astropy time to keep things consistent with 

588 # TIMESTAMP(6) SQL type. 

589 timestamp: datetime.datetime | astropy.time.Time 

590 if self._use_astropy: 

591 # Astropy `now()` precision should be the same as that of `datetime.now()`, which 

592 # should mean microsecond. 

593 timestamp = astropy.time.Time.now() 

594 else: 

595 timestamp = datetime.datetime.now(datetime.UTC) 

596 

597 # Iterate over data IDs, transforming a possibly-single-pass iterable 

598 # into a list. 

599 dataIdList: list[DataCoordinate] = [] 

600 rows = [] 

601 summary = CollectionSummary() 

602 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

603 dataIdList.append(dataId) 

604 rows.append( 

605 { 

606 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

607 "dataset_type_id": self._dataset_type_id, 

608 self._runKeyColumn: run.key, 

609 "ingest_date": timestamp, 

610 } 

611 ) 

612 

613 with self._db.transaction(): 

614 # Insert into the static dataset table. 

615 self._db.insert(self._static.dataset, *rows) 

616 # Update the summary tables for this collection in case this is the 

617 # first time this dataset type or these governor values will be 

618 # inserted there. 

619 self._summaries.update(run, [self._dataset_type_id], summary) 

620 # Combine the generated dataset_id values and data ID fields to 

621 # form rows to be inserted into the tags table. 

622 protoTagsRow = { 

623 "dataset_type_id": self._dataset_type_id, 

624 self._collections.getCollectionForeignKeyName(): run.key, 

625 } 

626 tagsRows = [ 

627 dict(protoTagsRow, dataset_id=row["id"], **dataId.required) 

628 for dataId, row in zip(dataIdList, rows, strict=True) 

629 ] 

630 # Insert those rows into the tags table. 

631 self._db.insert(self._tags, *tagsRows) 

632 

633 for dataId, row in zip(dataIdList, rows, strict=True): 

634 yield DatasetRef( 

635 datasetType=self.datasetType, 

636 dataId=dataId, 

637 id=row["id"], 

638 run=run.name, 

639 ) 

640 
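    # Illustrative sketch (not part of the original module): inserting new
    # datasets into a RUN and consuming the lazily-yielded refs. Variable
    # names are assumptions; the signature matches insert() above.
    #
    #     refs = list(storage.insert(run_record, data_ids))
    #     # Each ref carries a newly generated UUID and run=run_record.name.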

641 def import_( 

642 self, 

643 run: RunRecord, 

644 datasets: Iterable[DatasetRef], 

645 ) -> Iterator[DatasetRef]: 

646 # Docstring inherited from DatasetRecordStorage. 

647 

648 # Current timestamp, type depends on schema version. 

649 if self._use_astropy: 

650 # Astropy `now()` precision should be the same as that of `datetime.now()`, which 

651 # should mean microsecond. 

652 timestamp = sqlalchemy.sql.literal(astropy.time.Time.now(), type_=ddl.AstropyTimeNsecTai) 

653 else: 

654 timestamp = sqlalchemy.sql.literal(datetime.datetime.now(datetime.UTC)) 

655 

656 # Iterate over datasets, transforming a possibly-single-pass iterable 

657 # into a dict keyed by dataset ID. 

658 dataIds: dict[DatasetId, DataCoordinate] = {} 

659 summary = CollectionSummary() 

660 for dataset in summary.add_datasets_generator(datasets): 

661 dataIds[dataset.id] = dataset.dataId 

662 

663 # We'll insert all new rows into a temporary table 

664 tableSpec = makeTagTableSpec(self.datasetType, type(self._collections), ddl.GUID, constraints=False) 

665 collFkName = self._collections.getCollectionForeignKeyName() 

666 protoTagsRow = { 

667 "dataset_type_id": self._dataset_type_id, 

668 collFkName: run.key, 

669 } 

670 tmpRows = [ 

671 dict(protoTagsRow, dataset_id=dataset_id, **dataId.required) 

672 for dataset_id, dataId in dataIds.items() 

673 ] 

674 with self._db.transaction(for_temp_tables=True), self._db.temporary_table(tableSpec) as tmp_tags: 

675 # store all incoming data in a temporary table 

676 self._db.insert(tmp_tags, *tmpRows) 

677 

678 # There are some checks that we want to make for consistency 

679 # of the new datasets with existing ones. 

680 self._validateImport(tmp_tags, run) 

681 

682 # Before we merge the temporary table into dataset/tags we need to 

683 # drop datasets which are already there (and do not conflict). 

684 self._db.deleteWhere( 

685 tmp_tags, 

686 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

687 ) 

688 

689 # Copy it into the dataset table; we need to re-label some columns. 

690 self._db.insert( 

691 self._static.dataset, 

692 select=sqlalchemy.sql.select( 

693 tmp_tags.columns.dataset_id.label("id"), 

694 tmp_tags.columns.dataset_type_id, 

695 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

696 timestamp.label("ingest_date"), 

697 ), 

698 ) 

699 

700 # Update the summary tables for this collection in case this 

701 # is the first time this dataset type or these governor values 

702 # will be inserted there. 

703 self._summaries.update(run, [self._dataset_type_id], summary) 

704 

705 # Copy it into the tags table. 

706 self._db.insert(self._tags, select=tmp_tags.select()) 

707 

708 # Return refs in the same order as in the input list. 

709 for dataset_id, dataId in dataIds.items(): 

710 yield DatasetRef( 

711 datasetType=self.datasetType, 

712 id=dataset_id, 

713 dataId=dataId, 

714 run=run.name, 

715 ) 

716 
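    # Illustrative sketch (not part of the original module): after validation,
    # the temporary-table merge in import_() reduces to SQL of roughly this
    # shape (table and column names simplified):
    #
    #     DELETE FROM tmp_tags WHERE dataset_id IN (SELECT id FROM dataset);
    #     INSERT INTO dataset (id, dataset_type_id, <run fk>, ingest_date)
    #         SELECT dataset_id, dataset_type_id, <collection fk>, :timestamp FROM tmp_tags;
    #     INSERT INTO <dataset_type>_tags SELECT * FROM tmp_tags;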

717 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

718 """Validate imported refs against existing datasets. 

719 

720 Parameters 

721 ---------- 

722 tmp_tags : `sqlalchemy.schema.Table` 

723 Temporary table with new datasets and the same schema as tags 

724 table. 

725 run : `RunRecord` 

726 The record object describing the `~CollectionType.RUN` collection. 

727 

728 Raises 

729 ------ 

730 ConflictingDefinitionError 

731 Raised if new datasets conflict with existing ones. 

732 """ 

733 dataset = self._static.dataset 

734 tags = self._tags 

735 collFkName = self._collections.getCollectionForeignKeyName() 

736 

737 # Check that existing datasets have the same dataset type and 

738 # run. 

739 query = ( 

740 sqlalchemy.sql.select( 

741 dataset.columns.id.label("dataset_id"), 

742 dataset.columns.dataset_type_id.label("dataset_type_id"), 

743 tmp_tags.columns.dataset_type_id.label("new_dataset_type_id"), 

744 dataset.columns[self._runKeyColumn].label("run"), 

745 tmp_tags.columns[collFkName].label("new_run"), 

746 ) 

747 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

748 .where( 

749 sqlalchemy.sql.or_( 

750 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

751 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

752 ) 

753 ) 

754 .limit(1) 

755 ) 

756 with self._db.query(query) as result: 

757 # Only include the first one in the exception message 

758 if (row := result.first()) is not None: 

759 existing_run = self._collections[row.run].name 

760 new_run = self._collections[row.new_run].name 

761 if row.dataset_type_id == self._dataset_type_id: 

762 if row.new_dataset_type_id == self._dataset_type_id: 762 ↛ 768 (line 762 didn't jump to line 768, because the condition on line 762 was never false)

763 raise ConflictingDefinitionError( 

764 f"Current run {existing_run!r} and new run {new_run!r} do not agree for " 

765 f"dataset {row.dataset_id}." 

766 ) 

767 else: 

768 raise ConflictingDefinitionError( 

769 f"Dataset {row.dataset_id} was provided with type {self.datasetType.name!r} " 

770 f"in run {new_run!r}, but was already defined with type ID {row.dataset_type_id} " 

771 f"in run {run!r}." 

772 ) 

773 else: 

774 raise ConflictingDefinitionError( 

775 f"Dataset {row.dataset_id} was provided with type ID {row.new_dataset_type_id} " 

776 f"in run {new_run!r}, but was already defined with type {self.datasetType.name!r} " 

777 f"in run {run!r}." 

778 ) 

779 

780 # Check that matching dataset in tags table has the same DataId. 

781 query = ( 

782 sqlalchemy.sql.select( 

783 tags.columns.dataset_id, 

784 tags.columns.dataset_type_id.label("type_id"), 

785 tmp_tags.columns.dataset_type_id.label("new_type_id"), 

786 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

787 *[ 

788 tmp_tags.columns[dim].label(f"new_{dim}") 

789 for dim in self.datasetType.dimensions.required.names 

790 ], 

791 ) 

792 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

793 .where( 

794 sqlalchemy.sql.or_( 

795 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

796 *[ 

797 tags.columns[dim] != tmp_tags.columns[dim] 

798 for dim in self.datasetType.dimensions.required.names 

799 ], 

800 ) 

801 ) 

802 .limit(1) 

803 ) 

804 

805 with self._db.query(query) as result: 

806 if (row := result.first()) is not None: 

807 # Only include the first one in the exception message 

808 raise ConflictingDefinitionError( 

809 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

810 ) 

811 

812 # Check that matching run+dataId have the same dataset ID. 

813 query = ( 

814 sqlalchemy.sql.select( 

815 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

816 tags.columns.dataset_id, 

817 tmp_tags.columns.dataset_id.label("new_dataset_id"), 

818 tags.columns[collFkName], 

819 tmp_tags.columns[collFkName].label(f"new_{collFkName}"), 

820 ) 

821 .select_from( 

822 tags.join( 

823 tmp_tags, 

824 sqlalchemy.sql.and_( 

825 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

826 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

827 *[ 

828 tags.columns[dim] == tmp_tags.columns[dim] 

829 for dim in self.datasetType.dimensions.required.names 

830 ], 

831 ), 

832 ) 

833 ) 

834 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

835 .limit(1) 

836 ) 

837 with self._db.query(query) as result: 

838 # only include the first one in the exception message 

839 if (row := result.first()) is not None: 

840 data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names} 

841 existing_collection = self._collections[getattr(row, collFkName)].name 

842 new_collection = self._collections[getattr(row, f"new_{collFkName}")].name 

843 raise ConflictingDefinitionError( 

844 f"Dataset with type {self.datasetType.name!r} and data ID {data_id} " 

845 f"has ID {row.dataset_id} in existing collection {existing_collection!r} " 

846 f"but ID {row.new_dataset_id} in new collection {new_collection!r}." 

847 )