Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 88%

301 statements  

coverage.py v6.4.4, created at 2022-08-26 09:23 +0000

1from __future__ import annotations 

2 

3__all__ = ("ByDimensionsDatasetRecordStorage",) 

4 

5import uuid 

6from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set 

7 

8import sqlalchemy 

9from lsst.daf.butler import ( 

10 CollectionType, 

11 DataCoordinate, 

12 DataCoordinateSet, 

13 DatasetId, 

14 DatasetRef, 

15 DatasetType, 

16 SimpleQuery, 

17 Timespan, 

18 ddl, 

19) 

20from lsst.daf.butler.registry import ( 

21 CollectionTypeError, 

22 ConflictingDefinitionError, 

23 UnsupportedIdGeneratorError, 

24) 

25from lsst.daf.butler.registry.interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage 

26 

27from ...summaries import GovernorDimensionRestriction 

28from .tables import makeTagTableSpec 

29 

30 if TYPE_CHECKING:  [30 ↛ 31: line 30 didn't jump to line 31, because the condition on line 30 was never true]

31 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

32 from .summaries import CollectionSummaryManager 

33 from .tables import StaticDatasetTablesTuple 

34 

35 

36class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

37 """Dataset record storage implementation paired with 

38 `ByDimensionsDatasetRecordStorageManager`; see that class for more 

39 information. 

40 

41 Instances of this class should never be constructed directly; use 

42 `DatasetRecordStorageManager.register` instead. 

43 """ 

44 

45 def __init__( 

46 self, 

47 *, 

48 datasetType: DatasetType, 

49 db: Database, 

50 dataset_type_id: int, 

51 collections: CollectionManager, 

52 static: StaticDatasetTablesTuple, 

53 summaries: CollectionSummaryManager, 

54 tags: sqlalchemy.schema.Table, 

55 calibs: Optional[sqlalchemy.schema.Table], 

56 ): 

57 super().__init__(datasetType=datasetType) 

58 self._dataset_type_id = dataset_type_id 

59 self._db = db 

60 self._collections = collections 

61 self._static = static 

62 self._summaries = summaries 

63 self._tags = tags 

64 self._calibs = calibs 

65 self._runKeyColumn = collections.getRunForeignKeyName() 

66 

67 def find( 

68 self, collection: CollectionRecord, dataId: DataCoordinate, timespan: Optional[Timespan] = None 

69 ) -> Optional[DatasetRef]: 

70 # Docstring inherited from DatasetRecordStorage. 

71 assert dataId.graph == self.datasetType.dimensions 

72 if collection.type is CollectionType.CALIBRATION and timespan is None:  [72 ↛ 73: line 72 didn't jump to line 73, because the condition on line 72 was never true]

73 raise TypeError( 

74 f"Cannot search for dataset in CALIBRATION collection {collection.name} " 

75 f"without an input timespan." 

76 ) 

77 sql = self.select( 

78 collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan 

79 ) 

80 results = self._db.query(sql) 

81 row = results.fetchone() 

82 if row is None: 

83 return None 

84 if collection.type is CollectionType.CALIBRATION: 

85 # For temporal calibration lookups (only!) our invariants do not 

86 # guarantee that the number of result rows is <= 1. 

87 # They would if `select` constrained the given timespan to be 

88 # _contained_ by the validity range in the self._calibs table, 

89 # instead of simply _overlapping_ it, because we do guarantee that 

90 # the validity ranges are disjoint for a particular dataset type, 

91 # collection, and data ID. But using an overlap test and a check 

92 # for multiple result rows here allows us to provide a more useful 

93 # diagnostic, as well as allowing `select` to support more general 

94 # queries where multiple results are not an error. 

95 if results.fetchone() is not None: 

96 raise RuntimeError( 

97 f"Multiple matches found for calibration lookup in {collection.name} for " 

98 f"{self.datasetType.name} with {dataId} overlapping {timespan}. " 

99 ) 

100 return DatasetRef( 

101 datasetType=self.datasetType, 

102 dataId=dataId, 

103 id=row.id, 

104 run=self._collections[row._mapping[self._runKeyColumn]].name, 

105 ) 

106 
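
The multiple-match check above is needed because `select` uses an overlap test rather than a containment test: a lookup timespan that straddles the boundary between two disjoint validity ranges legitimately matches both rows. A small, self-contained illustration of that situation, with plain numeric half-open intervals standing in for `Timespan` (a sketch, not the butler API):

def overlaps(a: tuple, b: tuple) -> bool:
    """Return True if half-open intervals [a0, a1) and [b0, b1) overlap."""
    return a[0] < b[1] and b[0] < a[1]

# Disjoint validity ranges, as the certify invariants guarantee.
validity_ranges = [(0, 10), (10, 20)]
# A lookup timespan that straddles the boundary between them.
query_range = (5, 15)
matches = [r for r in validity_ranges if overlaps(r, query_range)]
assert len(matches) == 2  # hence the explicit multiple-match diagnostic above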

107 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

108 # Docstring inherited from DatasetRecordStorage. 

109 # Only delete from common dataset table; ON DELETE foreign key clauses 

110 # will handle the rest. 

111 self._db.delete( 

112 self._static.dataset, 

113 ["id"], 

114 *[{"id": dataset.getCheckedId()} for dataset in datasets], 

115 ) 

116 

117 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

118 # Docstring inherited from DatasetRecordStorage. 

119 if collection.type is not CollectionType.TAGGED:  [119 ↛ 120: line 119 didn't jump to line 120, because the condition on line 119 was never true]

120 raise TypeError( 

121 f"Cannot associate into collection '{collection.name}' " 

122 f"of type {collection.type.name}; must be TAGGED." 

123 ) 

124 protoRow = { 

125 self._collections.getCollectionForeignKeyName(): collection.key, 

126 "dataset_type_id": self._dataset_type_id, 

127 } 

128 rows = [] 

129 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe) 

130 for dataset in datasets: 

131 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

132 for dimension, value in dataset.dataId.items(): 

133 row[dimension.name] = value 

134 governorValues.update_extract(dataset.dataId) 

135 rows.append(row) 

136 # Update the summary tables for this collection in case this is the 

137 # first time this dataset type or these governor values will be 

138 # inserted there. 

139 self._summaries.update(collection, self.datasetType, self._dataset_type_id, governorValues) 

140 # Update the tag table itself. 

141 self._db.replace(self._tags, *rows) 

142 

143 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

144 # Docstring inherited from DatasetRecordStorage. 

145 if collection.type is not CollectionType.TAGGED:  [145 ↛ 146: line 145 didn't jump to line 146, because the condition on line 145 was never true]

146 raise TypeError( 

147 f"Cannot disassociate from collection '{collection.name}' " 

148 f"of type {collection.type.name}; must be TAGGED." 

149 ) 

150 rows = [ 

151 { 

152 "dataset_id": dataset.getCheckedId(), 

153 self._collections.getCollectionForeignKeyName(): collection.key, 

154 } 

155 for dataset in datasets 

156 ] 

157 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

158 

159 def _buildCalibOverlapQuery( 

160 self, collection: CollectionRecord, dataIds: Optional[DataCoordinateSet], timespan: Timespan 

161 ) -> SimpleQuery: 

162 assert self._calibs is not None 

163 # Start by building a SELECT query for any rows that would overlap 

164 # this one. 

165 query = SimpleQuery() 

166 query.join(self._calibs) 

167 # Add a WHERE clause matching the dataset type and collection. 

168 query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id) 

169 query.where.append( 

170 self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key 

171 ) 

172 # Add a WHERE clause matching any of the given data IDs. 

173 if dataIds is not None: 

174 dataIds.constrain( 

175 query, 

176 lambda name: self._calibs.columns[name], # type: ignore 

177 ) 

178 # Add WHERE clause for timespan overlaps. 

179 TimespanReprClass = self._db.getTimespanRepresentation() 

180 query.where.append( 

181 TimespanReprClass.fromSelectable(self._calibs).overlaps(TimespanReprClass.fromLiteral(timespan)) 

182 ) 

183 return query 

184 

185 def certify( 

186 self, collection: CollectionRecord, datasets: Iterable[DatasetRef], timespan: Timespan 

187 ) -> None: 

188 # Docstring inherited from DatasetRecordStorage. 

189 if self._calibs is None:  [189 ↛ 190: line 189 didn't jump to line 190, because the condition on line 189 was never true]

190 raise CollectionTypeError( 

191 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

192 f"DatasetType.isCalibration() is False." 

193 ) 

194 if collection.type is not CollectionType.CALIBRATION:  [194 ↛ 195: line 194 didn't jump to line 195, because the condition on line 194 was never true]

195 raise CollectionTypeError( 

196 f"Cannot certify into collection '{collection.name}' " 

197 f"of type {collection.type.name}; must be CALIBRATION." 

198 ) 

199 TimespanReprClass = self._db.getTimespanRepresentation() 

200 protoRow = { 

201 self._collections.getCollectionForeignKeyName(): collection.key, 

202 "dataset_type_id": self._dataset_type_id, 

203 } 

204 rows = [] 

205 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe) 

206 dataIds: Optional[Set[DataCoordinate]] = ( 

207 set() if not TimespanReprClass.hasExclusionConstraint() else None 

208 ) 

209 for dataset in datasets: 

210 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

211 for dimension, value in dataset.dataId.items(): 

212 row[dimension.name] = value 

213 TimespanReprClass.update(timespan, result=row) 

214 governorValues.update_extract(dataset.dataId) 

215 rows.append(row) 

216 if dataIds is not None:  [216 ↛ 209: line 216 didn't jump to line 209, because the condition on line 216 was never false]

217 dataIds.add(dataset.dataId) 

218 # Update the summary tables for this collection in case this is the 

219 # first time this dataset type or these governor values will be 

220 # inserted there. 

221 self._summaries.update(collection, self.datasetType, self._dataset_type_id, governorValues) 

222 # Update the association table itself. 

223 if TimespanReprClass.hasExclusionConstraint():  [223 ↛ 226: line 223 didn't jump to line 226, because the condition on line 223 was never true]

224 # Rely on database constraint to enforce invariants; we just 

225 # reraise the exception for consistency across DB engines. 

226 try: 

227 self._db.insert(self._calibs, *rows) 

228 except sqlalchemy.exc.IntegrityError as err: 

229 raise ConflictingDefinitionError( 

230 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

231 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

232 ) from err 

233 else: 

234 # Have to implement exclusion constraint ourselves. 

235 # Start by building a SELECT query for any rows that would overlap 

236 # this one. 

237 query = self._buildCalibOverlapQuery( 

238 collection, 

239 DataCoordinateSet(dataIds, graph=self.datasetType.dimensions), # type: ignore 

240 timespan, 

241 ) 

242 query.columns.append(sqlalchemy.sql.func.count()) 

243 sql = query.combine() 

244 # Acquire a table lock to ensure there are no concurrent writes that

245 # could invalidate our checking before we finish the inserts. We 

246 # use a SAVEPOINT in case there is an outer transaction that a 

247 # failure here should not roll back. 

248 with self._db.transaction(lock=[self._calibs], savepoint=True): 

249 # Run the check SELECT query. 

250 conflicting = self._db.query(sql).scalar() 

251 if conflicting > 0: 

252 raise ConflictingDefinitionError( 

253 f"{conflicting} validity range conflicts certifying datasets of type " 

254 f"{self.datasetType.name} into {collection.name} for range " 

255 f"[{timespan.begin}, {timespan.end})." 

256 ) 

257 # Proceed with the insert. 

258 self._db.insert(self._calibs, *rows) 

259 
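
When the configured timespan representation supports it, the first branch of `certify` above delegates the overlap check to a native database exclusion constraint and simply re-raises the resulting `IntegrityError`. As a rough illustration of what such a constraint looks like in PostgreSQL (illustrative DDL only; the actual schema, column names, and timespan representation are generated by daf_butler and will differ):

# Illustrative PostgreSQL DDL for a range-based exclusion constraint; requires the
# btree_gist extension for the equality operators. "detector" stands in for the
# data ID columns. This is not the real butler schema.
EXAMPLE_EXCLUSION_DDL = """
CREATE TABLE calibs_example (
    dataset_type_id BIGINT    NOT NULL,
    collection_id   BIGINT    NOT NULL,
    dataset_id      BIGINT    NOT NULL,
    detector        BIGINT    NOT NULL,
    validity        INT8RANGE NOT NULL,
    EXCLUDE USING gist (
        dataset_type_id WITH =,
        collection_id   WITH =,
        detector        WITH =,
        validity        WITH &&
    )
);
"""
# Inserting two rows with overlapping validity ranges for the same dataset type,
# collection, and data ID then fails at the database level, and certify() converts
# that failure into ConflictingDefinitionError.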

260 def decertify( 

261 self, 

262 collection: CollectionRecord, 

263 timespan: Timespan, 

264 *, 

265 dataIds: Optional[Iterable[DataCoordinate]] = None, 

266 ) -> None: 

267 # Docstring inherited from DatasetRecordStorage. 

268 if self._calibs is None:  [268 ↛ 269: line 268 didn't jump to line 269, because the condition on line 268 was never true]

269 raise CollectionTypeError( 

270 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

271 f"DatasetType.isCalibration() is False." 

272 ) 

273 if collection.type is not CollectionType.CALIBRATION:  [273 ↛ 274: line 273 didn't jump to line 274, because the condition on line 273 was never true]

274 raise CollectionTypeError( 

275 f"Cannot decertify from collection '{collection.name}' " 

276 f"of type {collection.type.name}; must be CALIBRATION." 

277 ) 

278 TimespanReprClass = self._db.getTimespanRepresentation() 

279 # Construct a SELECT query to find all rows that overlap our inputs. 

280 dataIdSet: Optional[DataCoordinateSet] 

281 if dataIds is not None: 

282 dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions) 

283 else: 

284 dataIdSet = None 

285 query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan) 

286 query.columns.extend(self._calibs.columns) 

287 sql = query.combine() 

288 # Set up collections to populate with the rows we'll want to modify. 

289 # The insert rows will have the same values for collection and 

290 # dataset type. 

291 protoInsertRow = { 

292 self._collections.getCollectionForeignKeyName(): collection.key, 

293 "dataset_type_id": self._dataset_type_id, 

294 } 

295 rowsToDelete = [] 

296 rowsToInsert = [] 

297 # Acquire a table lock to ensure there are no concurrent writes 

298 # between the SELECT and the DELETE and INSERT queries based on it. 

299 with self._db.transaction(lock=[self._calibs], savepoint=True): 

300 for row in self._db.query(sql).mappings(): 

301 rowsToDelete.append({"id": row["id"]}) 

302 # Construct the insert row(s) by copying the prototype row, 

303 # then adding the dimension column values, then adding what's 

304 # left of the timespan from that row after we subtract the 

305 # given timespan. 

306 newInsertRow = protoInsertRow.copy() 

307 newInsertRow["dataset_id"] = row["dataset_id"] 

308 for name in self.datasetType.dimensions.required.names: 

309 newInsertRow[name] = row[name] 

310 rowTimespan = TimespanReprClass.extract(row) 

311 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

312 for diffTimespan in rowTimespan.difference(timespan): 

313 rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())) 

314 # Run the DELETE and INSERT queries. 

315 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

316 self._db.insert(self._calibs, *rowsToInsert) 

317 
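
Each overlapping row found by `decertify` above is deleted and replaced by whatever remains of its validity range after the decertified timespan is removed, which can be zero, one, or two pieces. A self-contained sketch of that splitting, using numeric half-open intervals in place of `Timespan.difference` (not the butler API):

def difference(existing: tuple, removed: tuple) -> list:
    """Pieces of half-open interval `existing` not covered by `removed`."""
    pieces = []
    if existing[0] < removed[0]:
        pieces.append((existing[0], min(existing[1], removed[0])))
    if removed[1] < existing[1]:
        pieces.append((max(existing[0], removed[1]), existing[1]))
    return pieces

print(difference((0, 20), (5, 10)))   # [(0, 5), (10, 20)] -> two replacement rows
print(difference((0, 20), (10, 30)))  # [(0, 10)]          -> one replacement row
print(difference((0, 20), (0, 30)))   # []                 -> the row is simply deleted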

318 def select( 

319 self, 

320 *collections: CollectionRecord, 

321 dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select, 

322 id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select, 

323 run: SimpleQuery.Select.Or[None] = SimpleQuery.Select, 

324 timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select, 

325 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None, 

326 ) -> sqlalchemy.sql.Selectable: 

327 # Docstring inherited from DatasetRecordStorage. 

328 collection_types = {collection.type for collection in collections} 

329 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

330 # 

331 # There are two kinds of table in play here: 

332 # 

333 # - the static dataset table (with the dataset ID, dataset type ID, 

334 # run ID/name, and ingest date); 

335 # 

336 # - the dynamic tags/calibs table (with the dataset ID, dataset type 

337 # ID, collection ID/name, data ID, and possibly validity

338 # range). 

339 # 

340 # That means that we might want to return a query against either table 

341 # or a JOIN of both, depending on which quantities the caller wants. 

342 # But this method is documented/typed such that ``dataId`` is never 

343 # `None` - i.e. we always constrain or retrieve the data ID. That

344 # means we'll always include the tags/calibs table and join in the 

345 # static dataset table only if we need things from it that we can't get 

346 # from the tags/calibs table. 

347 # 

348 # Note that it's important that we include a WHERE constraint on both 

349 # tables for any column (e.g. dataset_type_id) that is in both when 

350 # it's given explicitly; not doing so can prevent the query planner from

351 # using very important indexes. At present, we don't include those 

352 # redundant columns in the JOIN ON expression, however, because the 

353 # FOREIGN KEY (and its index) are defined only on dataset_id. 

354 # 

355 # We'll start by accumulating kwargs to pass to SimpleQuery.join when 

356 # we bring in the tags/calibs table. We get the data ID or constrain 

357 # it in the tags/calibs table(s), but that's multiple columns, not one, 

358 # so we need to transform the one Select.Or argument into a dictionary 

359 # of them. 

360 kwargs: Dict[str, Any] 

361 if dataId is SimpleQuery.Select: 

362 kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required} 

363 else: 

364 kwargs = dict(dataId.byName()) 

365 # We always constrain (never retrieve) the dataset type in at least the 

366 # tags/calibs table. 

367 kwargs["dataset_type_id"] = self._dataset_type_id 

368 # Join in the tags and/or calibs tables, turning those 'kwargs' entries 

369 # into WHERE constraints or SELECT columns as appropriate. 

370 if collection_types != {CollectionType.CALIBRATION}: 

371 # We'll need a subquery for the tags table if any of the given 

372 # collections are not a CALIBRATION collection. This intentionally 

373 # also fires when the list of collections is empty as a way to 

374 # create a dummy subquery that we know will fail. 

375 tags_query = SimpleQuery() 

376 tags_query.join(self._tags, **kwargs) 

377 self._finish_single_select( 

378 tags_query, self._tags, collections, id=id, run=run, ingestDate=ingestDate 

379 ) 

380 else: 

381 tags_query = None 

382 if CollectionType.CALIBRATION in collection_types: 

383 # If at least one collection is a CALIBRATION collection, we'll 

384 # need a subquery for the calibs table, and could include the 

385 # timespan as a result or constraint. 

386 calibs_query = SimpleQuery() 

387 assert ( 

388 self._calibs is not None 

389 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

390 TimespanReprClass = self._db.getTimespanRepresentation() 

391 # Add the timespan column(s) to the result columns, or constrain 

392 # the timespan via an overlap condition. 

393 if timespan is SimpleQuery.Select: 

394 kwargs.update({k: SimpleQuery.Select for k in TimespanReprClass.getFieldNames()}) 

395 elif timespan is not None: 

396 calibs_query.where.append( 

397 TimespanReprClass.fromSelectable(self._calibs).overlaps( 

398 TimespanReprClass.fromLiteral(timespan) 

399 ) 

400 ) 

401 calibs_query.join(self._calibs, **kwargs) 

402 self._finish_single_select( 

403 calibs_query, self._calibs, collections, id=id, run=run, ingestDate=ingestDate 

404 ) 

405 else: 

406 calibs_query = None 

407 if calibs_query is not None: 

408 if tags_query is not None: 

409 if timespan is not None:  [409 ↛ 410: line 409 didn't jump to line 410, because the condition on line 409 was never true]

410 raise TypeError( 

411 "Cannot query for timespan when the collections include both calibration and " 

412 "non-calibration collections." 

413 ) 

414 return tags_query.combine().union(calibs_query.combine()) 

415 else: 

416 return calibs_query.combine() 

417 else: 

418 assert tags_query is not None, "Earlier logic should guarantee at least one is not None."

419 return tags_query.combine() 

420 
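
Putting the branches of `select` together: for a mix of TAGGED/RUN and CALIBRATION collections, the returned selectable is a UNION of two similarly shaped subqueries, each joining its tags or calibs table to the static dataset table only when the run key or ingest date is actually needed. Roughly, and purely as an illustration (the real table and column names are generated dynamically by the manager classes, and the dimension columns shown are placeholders):

# Rough shape of the generated SQL; all identifiers here are placeholders.
EXAMPLE_QUERY_SHAPE = """
SELECT tags.dataset_id AS id, static.run_id, tags.instrument, tags.detector
    FROM dataset_tags_N AS tags
    JOIN dataset AS static ON tags.dataset_id = static.id
    WHERE tags.dataset_type_id = :dtid
      AND static.dataset_type_id = :dtid
      AND tags.collection_id = :tagged_collection
UNION
SELECT calibs.dataset_id AS id, static.run_id, calibs.instrument, calibs.detector
    FROM dataset_calibs_N AS calibs
    JOIN dataset AS static ON calibs.dataset_id = static.id
    WHERE calibs.dataset_type_id = :dtid
      AND static.dataset_type_id = :dtid
      AND calibs.collection_id = :calib_collection
"""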

421 def _finish_single_select( 

422 self, 

423 query: SimpleQuery, 

424 table: sqlalchemy.schema.Table, 

425 collections: Sequence[CollectionRecord], 

426 id: SimpleQuery.Select.Or[Optional[int]], 

427 run: SimpleQuery.Select.Or[None], 

428 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]], 

429 ) -> None: 

430 dataset_id_col = table.columns.dataset_id 

431 collection_col = table.columns[self._collections.getCollectionForeignKeyName()] 

432 # We always constrain (never retrieve) the collection(s) in the 

433 # tags/calibs table. 

434 if len(collections) == 1: 

435 query.where.append(collection_col == collections[0].key) 

436 elif len(collections) == 0: 

437 # We support the case where there are no collections as a way to 

438 # generate a valid SQL query that can't yield results. This should 

439 # never get executed, but lots of downstream code will still try 

440 # to access the SQLAlchemy objects representing the columns in the 

441 # subquery. That's not ideal, but it'd take a lot of refactoring 

442 # to fix it (DM-31725). 

443 query.where.append(sqlalchemy.sql.literal(False)) 

444 else: 

445 query.where.append(collection_col.in_([collection.key for collection in collections])) 

446 # We can always get the dataset_id from the tags/calibs table or 

447 # constrain it there. Can't use kwargs for that because we need to 

448 # alias it to 'id'. 

449 if id is SimpleQuery.Select: 

450 query.columns.append(dataset_id_col.label("id")) 

451 elif id is not None:  [451 ↛ 452: line 451 didn't jump to line 452, because the condition on line 451 was never true]

452 query.where.append(dataset_id_col == id) 

453 # It's possible we now have everything we need, from just the 

454 # tags/calibs table. The things we might need to get from the static 

455 # dataset table are the run key and the ingest date. 

456 need_static_table = False 

457 static_kwargs: Dict[str, Any] = {} 

458 if run is not None: 

459 assert run is SimpleQuery.Select, "To constrain the run name, pass a RunRecord as a collection." 

460 if len(collections) == 1 and collections[0].type is CollectionType.RUN: 

461 # If we are searching exactly one RUN collection, we 

462 # know that if we find the dataset in that collection, 

463 # then that's the dataset's run; we don't need to

464 # query for it. 

465 query.columns.append(sqlalchemy.sql.literal(collections[0].key).label(self._runKeyColumn)) 

466 else: 

467 static_kwargs[self._runKeyColumn] = SimpleQuery.Select 

468 need_static_table = True 

469 # Ingest date can only come from the static table. 

470 if ingestDate is not None: 

471 need_static_table = True 

472 if ingestDate is SimpleQuery.Select:  [472 ↛ 475: line 472 didn't jump to line 475, because the condition on line 472 was never false]

473 static_kwargs["ingest_date"] = SimpleQuery.Select 

474 else: 

475 assert isinstance(ingestDate, Timespan) 

476 # Timespan is astropy Time (usually in TAI) and ingest_date is 

477 # TIMESTAMP, so convert values to Python datetime for sqlalchemy.

478 if ingestDate.isEmpty(): 

479 raise RuntimeError("Empty timespan constraint provided for ingest_date.") 

480 if ingestDate.begin is not None: 

481 begin = ingestDate.begin.utc.datetime # type: ignore 

482 query.where.append(self._static.dataset.columns.ingest_date >= begin) 

483 if ingestDate.end is not None: 

484 end = ingestDate.end.utc.datetime # type: ignore 

485 query.where.append(self._static.dataset.columns.ingest_date < end) 

486 # If we need the static table, join it in via dataset_id and 

487 # dataset_type_id 

488 if need_static_table: 

489 query.join( 

490 self._static.dataset, 

491 onclause=(dataset_id_col == self._static.dataset.columns.id), 

492 **static_kwargs, 

493 ) 

494 # Also constrain dataset_type_id in static table in case that helps 

495 # generate a better plan. 

496 # We could also include this in the JOIN ON clause, but my guess is 

497 # that that's a good idea IFF it's in the foreign key, and right 

498 # now it isn't. 

499 query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

500 
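
The ingest-date handling above converts the astropy `Time` bounds of the given `Timespan` (nominally TAI) into naive UTC `datetime` objects before comparing them to the `ingest_date` TIMESTAMP column. The conversion itself is the one used in the code, shown here in isolation (assumes astropy is installed):

from astropy.time import Time

t_tai = Time("2022-08-26T09:23:00", scale="tai")
as_datetime = t_tai.utc.datetime  # naive UTC datetime, suitable for a TIMESTAMP comparison
print(as_datetime)                # 2022-08-26 09:22:23 (TAI is 37 s ahead of UTC at this epoch)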

501 def getDataId(self, id: DatasetId) -> DataCoordinate: 

502 """Return DataId for a dataset. 

503 

504 Parameters 

505 ---------- 

506 id : `DatasetId` 

507 Unique dataset identifier. 

508 

509 Returns 

510 ------- 

511 dataId : `DataCoordinate` 

512 DataId for the dataset. 

513 """ 

514 # This query could return multiple rows (one for each tagged collection 

515 # the dataset is in, plus one for its run collection), and we don't 

516 # care which of those we get. 

517 sql = ( 

518 self._tags.select() 

519 .where( 

520 sqlalchemy.sql.and_( 

521 self._tags.columns.dataset_id == id, 

522 self._tags.columns.dataset_type_id == self._dataset_type_id, 

523 ) 

524 ) 

525 .limit(1) 

526 ) 

527 row = self._db.query(sql).mappings().fetchone() 

528 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

529 return DataCoordinate.standardize( 

530 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required}, 

531 graph=self.datasetType.dimensions, 

532 ) 

533 

534 

535class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage): 

536 """Implementation of ByDimensionsDatasetRecordStorage which uses integer 

537 auto-incremented column for dataset IDs. 

538 """ 

539 

540 def insert( 

541 self, 

542 run: RunRecord, 

543 dataIds: Iterable[DataCoordinate], 

544 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

545 ) -> Iterator[DatasetRef]: 

546 # Docstring inherited from DatasetRecordStorage. 

547 

548 # We only support UNIQUE mode for integer dataset IDs 

549 if idMode != DatasetIdGenEnum.UNIQUE:  [549 ↛ 550: line 549 didn't jump to line 550, because the condition on line 549 was never true]

550 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.") 

551 

552 # Transform a possibly-single-pass iterable into a list. 

553 dataIdList = list(dataIds) 

554 yield from self._insert(run, dataIdList) 

555 

556 def import_( 

557 self, 

558 run: RunRecord, 

559 datasets: Iterable[DatasetRef], 

560 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

561 reuseIds: bool = False, 

562 ) -> Iterator[DatasetRef]: 

563 # Docstring inherited from DatasetRecordStorage. 

564 

565 # We only support UNIQUE mode for integer dataset IDs 

566 if idGenerationMode != DatasetIdGenEnum.UNIQUE:  [566 ↛ 567: line 566 didn't jump to line 567, because the condition on line 566 was never true]

567 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.") 

568 

569 # Make a list of dataIds and optionally dataset IDs. 

570 dataIdList: List[DataCoordinate] = [] 

571 datasetIdList: List[int] = [] 

572 for dataset in datasets: 

573 dataIdList.append(dataset.dataId) 

574 

575 # We only accept integer dataset IDs, but also allow None. 

576 datasetId = dataset.id 

577 if datasetId is None:  [577 ↛ 579: line 577 didn't jump to line 579, because the condition on line 577 was never true]

578 # if reuseIds is set then all IDs must be known 

579 if reuseIds: 

580 raise TypeError("All dataset IDs must be known if `reuseIds` is set") 

581 elif isinstance(datasetId, int):  [581 ↛ 585: line 581 didn't jump to line 585, because the condition on line 581 was never false]

582 if reuseIds: 

583 datasetIdList.append(datasetId) 

584 else: 

585 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}") 

586 

587 yield from self._insert(run, dataIdList, datasetIdList) 

588 

589 def _insert( 

590 self, run: RunRecord, dataIdList: List[DataCoordinate], datasetIdList: Optional[List[int]] = None 

591 ) -> Iterator[DatasetRef]: 

592 """Common part of implementation of `insert` and `import_` methods.""" 

593 

594 # Remember any governor dimension values we see. 

595 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe) 

596 for dataId in dataIdList: 

597 governorValues.update_extract(dataId) 

598 

599 staticRow = { 

600 "dataset_type_id": self._dataset_type_id, 

601 self._runKeyColumn: run.key, 

602 } 

603 with self._db.transaction(): 

604 # Insert into the static dataset table, generating autoincrement 

605 # dataset_id values. 

606 if datasetIdList: 

607 # reuse existing IDs 

608 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList] 

609 self._db.insert(self._static.dataset, *rows) 

610 else: 

611 # use auto-incremented IDs 

612 datasetIdList = self._db.insert( 

613 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True 

614 ) 

615 assert datasetIdList is not None 

616 # Update the summary tables for this collection in case this is the 

617 # first time this dataset type or these governor values will be 

618 # inserted there. 

619 self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues) 

620 # Combine the generated dataset_id values and data ID fields to 

621 # form rows to be inserted into the tags table. 

622 protoTagsRow = { 

623 "dataset_type_id": self._dataset_type_id, 

624 self._collections.getCollectionForeignKeyName(): run.key, 

625 } 

626 tagsRows = [ 

627 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

628 for dataId, dataset_id in zip(dataIdList, datasetIdList) 

629 ] 

630 # Insert those rows into the tags table. This is where we'll 

631 # get any unique constraint violations. 

632 self._db.insert(self._tags, *tagsRows) 

633 

634 for dataId, datasetId in zip(dataIdList, datasetIdList): 

635 yield DatasetRef( 

636 datasetType=self.datasetType, 

637 dataId=dataId, 

638 id=datasetId, 

639 run=run.name, 

640 ) 

641 

642 

643class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

644 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

645 dataset IDs. 

646 """ 

647 

648 idMaker = DatasetIdFactory() 

649 """Factory for dataset IDs. In the future this factory may be shared with 

650 other classes (e.g. Registry).""" 

651 

652 def insert( 

653 self, 

654 run: RunRecord, 

655 dataIds: Iterable[DataCoordinate], 

656 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

657 ) -> Iterator[DatasetRef]: 

658 # Docstring inherited from DatasetRecordStorage. 

659 

660 # Remember any governor dimension values we see. 

661 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe) 

662 

663 # Iterate over data IDs, transforming a possibly-single-pass iterable 

664 # into a list. 

665 dataIdList = [] 

666 rows = [] 

667 for dataId in dataIds: 

668 dataIdList.append(dataId) 

669 rows.append( 

670 { 

671 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

672 "dataset_type_id": self._dataset_type_id, 

673 self._runKeyColumn: run.key, 

674 } 

675 ) 

676 governorValues.update_extract(dataId) 

677 

678 with self._db.transaction(): 

679 # Insert into the static dataset table. 

680 self._db.insert(self._static.dataset, *rows) 

681 # Update the summary tables for this collection in case this is the 

682 # first time this dataset type or these governor values will be 

683 # inserted there. 

684 self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues) 

685 # Combine the generated dataset_id values and data ID fields to 

686 # form rows to be inserted into the tags table. 

687 protoTagsRow = { 

688 "dataset_type_id": self._dataset_type_id, 

689 self._collections.getCollectionForeignKeyName(): run.key, 

690 } 

691 tagsRows = [ 

692 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName()) 

693 for dataId, row in zip(dataIdList, rows) 

694 ] 

695 # Insert those rows into the tags table. 

696 self._db.insert(self._tags, *tagsRows) 

697 

698 for dataId, row in zip(dataIdList, rows): 

699 yield DatasetRef( 

700 datasetType=self.datasetType, 

701 dataId=dataId, 

702 id=row["id"], 

703 run=run.name, 

704 ) 

705 
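
Depending on `idMode`, the `DatasetIdFactory` used above can mint either a random ID or a reproducible one derived from the run, dataset type, and data ID. The general technique behind reproducible IDs is a name-based (version 5) UUID; a minimal sketch of that idea, with a made-up namespace and key format rather than the factory's actual algorithm:

import uuid

# Hypothetical namespace and key layout, for illustration only; DatasetIdFactory
# defines its own.
_NS = uuid.UUID("00000000-0000-0000-0000-000000000000")

def reproducible_dataset_id(run: str, dataset_type: str, data_id: dict) -> uuid.UUID:
    key = ",".join([run, dataset_type] + [f"{k}={v}" for k, v in sorted(data_id.items())])
    return uuid.uuid5(_NS, key)

random_id = uuid.uuid4()  # the UNIQUE-style case: a fresh ID on every call
stable_id = reproducible_dataset_id("HSC/runs/demo", "calexp", {"visit": 903334, "detector": 22})
assert stable_id == reproducible_dataset_id("HSC/runs/demo", "calexp", {"visit": 903334, "detector": 22})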

706 def import_( 

707 self, 

708 run: RunRecord, 

709 datasets: Iterable[DatasetRef], 

710 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

711 reuseIds: bool = False, 

712 ) -> Iterator[DatasetRef]: 

713 # Docstring inherited from DatasetRecordStorage. 

714 

715 # Remember any governor dimension values we see. 

716 governorValues = GovernorDimensionRestriction.makeEmpty(self.datasetType.dimensions.universe) 

717 

718 # Iterate over data IDs, transforming a possibly-single-pass iterable 

719 # into a list. 

720 dataIds = {} 

721 for dataset in datasets: 

722 # Ignore unknown ID types; normally all IDs have the same type, but

723 # this code supports mixed types or missing IDs. 

724 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None 

725 if datasetId is None: 

726 datasetId = self.idMaker.makeDatasetId( 

727 run.name, self.datasetType, dataset.dataId, idGenerationMode 

728 ) 

729 dataIds[datasetId] = dataset.dataId 

730 governorValues.update_extract(dataset.dataId) 

731 

732 with self._db.session() as session: 

733 

734 # insert all new rows into a temporary table 

735 tableSpec = makeTagTableSpec( 

736 self.datasetType, type(self._collections), ddl.GUID, constraints=False 

737 ) 

738 tmp_tags = session.makeTemporaryTable(tableSpec) 

739 

740 collFkName = self._collections.getCollectionForeignKeyName() 

741 protoTagsRow = { 

742 "dataset_type_id": self._dataset_type_id, 

743 collFkName: run.key, 

744 } 

745 tmpRows = [ 

746 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

747 for dataset_id, dataId in dataIds.items() 

748 ] 

749 

750 with self._db.transaction(): 

751 

752 # store all incoming data in a temporary table 

753 self._db.insert(tmp_tags, *tmpRows) 

754 

755 # There are some checks that we want to make for consistency 

756 # of the new datasets with existing ones. 

757 self._validateImport(tmp_tags, run) 

758 

759 # Before we merge the temporary table into dataset/tags we need to

760 # drop datasets which are already there (and do not conflict). 

761 self._db.deleteWhere( 

762 tmp_tags, 

763 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

764 ) 

765 

766 # Copy it into the dataset table; we need to re-label some columns.

767 self._db.insert( 

768 self._static.dataset, 

769 select=sqlalchemy.sql.select( 

770 tmp_tags.columns.dataset_id.label("id"), 

771 tmp_tags.columns.dataset_type_id, 

772 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

773 ), 

774 ) 

775 

776 # Update the summary tables for this collection in case this 

777 # is the first time this dataset type or these governor values 

778 # will be inserted there. 

779 self._summaries.update(run, self.datasetType, self._dataset_type_id, governorValues) 

780 

781 # Copy it into the tags table.

782 self._db.insert(self._tags, select=tmp_tags.select()) 

783 

784 # Return refs in the same order as in the input list. 

785 for dataset_id, dataId in dataIds.items(): 

786 yield DatasetRef( 

787 datasetType=self.datasetType, 

788 id=dataset_id, 

789 dataId=dataId, 

790 run=run.name, 

791 ) 

792 

793 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

794 """Validate imported refs against existing datasets. 

795 

796 Parameters 

797 ---------- 

798 tmp_tags : `sqlalchemy.schema.Table` 

799 Temporary table with new datasets and the same schema as tags 

800 table. 

801 run : `RunRecord` 

802 The record object describing the `~CollectionType.RUN` collection. 

803 

804 Raises 

805 ------ 

806 ConflictingDefinitionError 

807 Raise if new datasets conflict with existing ones. 

808 """ 

809 dataset = self._static.dataset 

810 tags = self._tags 

811 collFkName = self._collections.getCollectionForeignKeyName() 

812 

813 # Check that existing datasets have the same dataset type and 

814 # run. 

815 query = ( 

816 sqlalchemy.sql.select( 

817 dataset.columns.id.label("dataset_id"), 

818 dataset.columns.dataset_type_id.label("dataset_type_id"), 

819 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), 

820 dataset.columns[self._runKeyColumn].label("run"), 

821 tmp_tags.columns[collFkName].label("new run"), 

822 ) 

823 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

824 .where( 

825 sqlalchemy.sql.or_( 

826 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

827 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

828 ) 

829 ) 

830 ) 

831 result = self._db.query(query) 

832 if (row := result.first()) is not None: 

833 # Only include the first one in the exception message 

834 raise ConflictingDefinitionError( 

835 f"Existing dataset type or run do not match new dataset: {row._asdict()}" 

836 ) 

837 

838 # Check that matching dataset in tags table has the same DataId. 

839 query = ( 

840 sqlalchemy.sql.select( 

841 tags.columns.dataset_id, 

842 tags.columns.dataset_type_id.label("type_id"), 

843 tmp_tags.columns.dataset_type_id.label("new type_id"), 

844 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

845 *[ 

846 tmp_tags.columns[dim].label(f"new {dim}") 

847 for dim in self.datasetType.dimensions.required.names 

848 ], 

849 ) 

850 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

851 .where( 

852 sqlalchemy.sql.or_( 

853 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

854 *[ 

855 tags.columns[dim] != tmp_tags.columns[dim] 

856 for dim in self.datasetType.dimensions.required.names 

857 ], 

858 ) 

859 ) 

860 ) 

861 result = self._db.query(query) 

862 if (row := result.first()) is not None: 

863 # Only include the first one in the exception message 

864 raise ConflictingDefinitionError( 

865 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

866 ) 

867 

868 # Check that matching run+dataId have the same dataset ID. 

869 query = ( 

870 sqlalchemy.sql.select( 

871 tags.columns.dataset_type_id.label("dataset_type_id"), 

872 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

873 tags.columns.dataset_id, 

874 tmp_tags.columns.dataset_id.label("new dataset_id"), 

875 tags.columns[collFkName], 

876 tmp_tags.columns[collFkName].label(f"new {collFkName}"), 

877 ) 

878 .select_from( 

879 tags.join( 

880 tmp_tags, 

881 sqlalchemy.sql.and_( 

882 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

883 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

884 *[ 

885 tags.columns[dim] == tmp_tags.columns[dim] 

886 for dim in self.datasetType.dimensions.required.names 

887 ], 

888 ), 

889 ) 

890 ) 

891 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

892 ) 

893 result = self._db.query(query) 

894 if (row := result.first()) is not None: 

895 # only include the first one in the exception message 

896 raise ConflictingDefinitionError( 

897 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}" 

898 )
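
Taken together, these checks make `import_` idempotent for refs that are already stored and loud for refs that contradict the stored data: reusing an existing dataset ID with a different run, dataset type, or data ID, or reusing an existing run/data ID combination with a different dataset ID, raises `ConflictingDefinitionError`, while exact duplicates are silently dropped from the temporary table before the merge. A toy restatement of the first two checks with plain dictionaries (not the butler API):

existing = {"id": 1, "run": "runs/a", "data_id": {"detector": 1}}

def check_against(new: dict, old: dict) -> None:
    """Raise if `new` reuses `old`'s dataset ID with a different run or data ID."""
    if new["id"] == old["id"] and (new["run"], new["data_id"]) != (old["run"], old["data_id"]):
        raise ValueError("existing dataset does not match new dataset")

check_against({"id": 1, "run": "runs/a", "data_id": {"detector": 1}}, existing)  # exact duplicate: fine
# check_against({"id": 1, "run": "runs/b", "data_id": {"detector": 1}}, existing)  # would raise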