Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 89%

298 statements  

coverage.py v6.4.4, created at 2022-09-22 02:04 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22 

23from __future__ import annotations 

24 

25__all__ = ("ByDimensionsDatasetRecordStorage",) 

26 

27import uuid 

28from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set 

29 

30import sqlalchemy 

31from lsst.daf.butler import ( 

32 CollectionType, 

33 DataCoordinate, 

34 DataCoordinateSet, 

35 DatasetId, 

36 DatasetRef, 

37 DatasetType, 

38 SimpleQuery, 

39 Timespan, 

40 ddl, 

41) 

42from lsst.daf.butler.registry import ( 

43 CollectionSummary, 

44 CollectionTypeError, 

45 ConflictingDefinitionError, 

46 UnsupportedIdGeneratorError, 

47) 

48from lsst.daf.butler.registry.interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage 

49 

50from .tables import makeTagTableSpec 

51 

52if TYPE_CHECKING:  [52 ↛ 53: line 52 didn't jump to line 53, because the condition on line 52 was never true]

53 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

54 from .summaries import CollectionSummaryManager 

55 from .tables import StaticDatasetTablesTuple 

56 

57 

58class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

59 """Dataset record storage implementation paired with 

60 `ByDimensionsDatasetRecordStorageManager`; see that class for more 

61 information. 

62 

63 Instances of this class should never be constructed directly; use 

64 `DatasetRecordStorageManager.register` instead. 

65 """ 

66 

67 def __init__( 

68 self, 

69 *, 

70 datasetType: DatasetType, 

71 db: Database, 

72 dataset_type_id: int, 

73 collections: CollectionManager, 

74 static: StaticDatasetTablesTuple, 

75 summaries: CollectionSummaryManager, 

76 tags: sqlalchemy.schema.Table, 

77 calibs: Optional[sqlalchemy.schema.Table], 

78 ): 

79 super().__init__(datasetType=datasetType) 

80 self._dataset_type_id = dataset_type_id 

81 self._db = db 

82 self._collections = collections 

83 self._static = static 

84 self._summaries = summaries 

85 self._tags = tags 

86 self._calibs = calibs 

87 self._runKeyColumn = collections.getRunForeignKeyName() 

88 

89 def find( 

90 self, collection: CollectionRecord, dataId: DataCoordinate, timespan: Optional[Timespan] = None 

91 ) -> Optional[DatasetRef]: 

92 # Docstring inherited from DatasetRecordStorage. 

93 assert dataId.graph == self.datasetType.dimensions 

94 if collection.type is CollectionType.CALIBRATION and timespan is None:  [94 ↛ 95: line 94 didn't jump to line 95, because the condition on line 94 was never true]

95 raise TypeError( 

96 f"Cannot search for dataset in CALIBRATION collection {collection.name} " 

97 f"without an input timespan." 

98 ) 

99 sql = self.select( 

100 collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan 

101 ) 

102 results = self._db.query(sql) 

103 row = results.fetchone() 

104 if row is None: 

105 return None 

106 if collection.type is CollectionType.CALIBRATION: 

107 # For temporal calibration lookups (only!) our invariants do not 

108 # guarantee that the number of result rows is <= 1. 

109 # They would if `select` constrained the given timespan to be 

110 # _contained_ by the validity range in the self._calibs table, 

111 # instead of simply _overlapping_ it, because we do guarantee that 

112 # the validity ranges are disjoint for a particular dataset type, 

113 # collection, and data ID. But using an overlap test and a check 

114 # for multiple result rows here allows us to provide a more useful 

115 # diagnostic, as well as allowing `select` to support more general 

116 # queries where multiple results are not an error. 

117 if results.fetchone() is not None: 

118 raise RuntimeError( 

119 f"Multiple matches found for calibration lookup in {collection.name} for " 

120 f"{self.datasetType.name} with {dataId} overlapping {timespan}. " 

121 ) 

122 return DatasetRef( 

123 datasetType=self.datasetType, 

124 dataId=dataId, 

125 id=row.id, 

126 run=self._collections[row._mapping[self._runKeyColumn]].name, 

127 ) 

128 
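The multiple-match check above exists because `select` tests the given timespan for overlap with, rather than containment in, the stored validity ranges. A schematic illustration using plain Python tuples as stand-ins for validity ranges (this is not the butler Timespan API): two disjoint ranges certified for the same data ID can both overlap a query window that straddles their boundary.

# Schematic only: integers stand in for times; intervals are half-open, [begin, end).
validity_ranges = [(0, 10), (10, 20)]  # disjoint ranges certified for one data ID
query_window = (5, 15)                 # timespan passed by the caller

def overlaps(a, b):
    return a[0] < b[1] and b[0] < a[1]

matches = [r for r in validity_ranges if overlaps(r, query_window)]
assert len(matches) == 2  # both rows satisfy an overlap test, hence the explicit check above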

129 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

130 # Docstring inherited from DatasetRecordStorage. 

131 # Only delete from common dataset table; ON DELETE foreign key clauses 

132 # will handle the rest. 

133 self._db.delete( 

134 self._static.dataset, 

135 ["id"], 

136 *[{"id": dataset.getCheckedId()} for dataset in datasets], 

137 ) 

138 

139 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

140 # Docstring inherited from DatasetRecordStorage. 

141 if collection.type is not CollectionType.TAGGED:  [141 ↛ 142: line 141 didn't jump to line 142, because the condition on line 141 was never true]

142 raise TypeError( 

143 f"Cannot associate into collection '{collection.name}' " 

144 f"of type {collection.type.name}; must be TAGGED." 

145 ) 

146 protoRow = { 

147 self._collections.getCollectionForeignKeyName(): collection.key, 

148 "dataset_type_id": self._dataset_type_id, 

149 } 

150 rows = [] 

151 summary = CollectionSummary() 

152 for dataset in summary.add_datasets_generator(datasets): 

153 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

154 for dimension, value in dataset.dataId.items(): 

155 row[dimension.name] = value 

156 rows.append(row) 

157 # Update the summary tables for this collection in case this is the 

158 # first time this dataset type or these governor values will be 

159 # inserted there. 

160 self._summaries.update(collection, [self._dataset_type_id], summary) 

161 # Update the tag table itself. 

162 self._db.replace(self._tags, *rows) 

163 

164 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

165 # Docstring inherited from DatasetRecordStorage. 

166 if collection.type is not CollectionType.TAGGED:  [166 ↛ 167: line 166 didn't jump to line 167, because the condition on line 166 was never true]

167 raise TypeError( 

168 f"Cannot disassociate from collection '{collection.name}' " 

169 f"of type {collection.type.name}; must be TAGGED." 

170 ) 

171 rows = [ 

172 { 

173 "dataset_id": dataset.getCheckedId(), 

174 self._collections.getCollectionForeignKeyName(): collection.key, 

175 } 

176 for dataset in datasets 

177 ] 

178 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

179 

180 def _buildCalibOverlapQuery( 

181 self, collection: CollectionRecord, dataIds: Optional[DataCoordinateSet], timespan: Timespan 

182 ) -> SimpleQuery: 

183 assert self._calibs is not None 

184 # Start by building a SELECT query for any rows that would overlap 

185 # this one. 

186 query = SimpleQuery() 

187 query.join(self._calibs) 

188 # Add a WHERE clause matching the dataset type and collection. 

189 query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id) 

190 query.where.append( 

191 self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key 

192 ) 

193 # Add a WHERE clause matching any of the given data IDs. 

194 if dataIds is not None: 

195 dataIds.constrain( 

196 query, 

197 lambda name: self._calibs.columns[name], # type: ignore 

198 ) 

199 # Add WHERE clause for timespan overlaps. 

200 TimespanReprClass = self._db.getTimespanRepresentation() 

201 query.where.append( 

202 TimespanReprClass.from_columns(self._calibs.columns).overlaps( 

203 TimespanReprClass.fromLiteral(timespan) 

204 ) 

205 ) 

206 return query 

207 
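For readers unfamiliar with SimpleQuery, the statement this helper accumulates has roughly the following shape when written directly in SQLAlchemy Core. This is an illustrative sketch with a toy table and explicit begin/end columns rather than the butler's timespan representation classes; the table and column names are assumptions, not the real schema.

import sqlalchemy

metadata = sqlalchemy.MetaData()
calibs = sqlalchemy.Table(
    "calibs", metadata,
    sqlalchemy.Column("dataset_type_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
    sqlalchemy.Column("valid_begin", sqlalchemy.Integer),
    sqlalchemy.Column("valid_end", sqlalchemy.Integer),
)

def calib_overlap_select(dataset_type_id, collection_id, begin, end):
    # WHERE dataset_type_id = :x AND collection_id = :y AND the validity range
    # overlaps [begin, end), using the usual half-open overlap test.
    return sqlalchemy.select(calibs).where(
        calibs.c.dataset_type_id == dataset_type_id,
        calibs.c.collection_id == collection_id,
        calibs.c.valid_begin < end,
        calibs.c.valid_end > begin,
    )

print(calib_overlap_select(1, 42, 0, 10))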

208 def certify( 

209 self, collection: CollectionRecord, datasets: Iterable[DatasetRef], timespan: Timespan 

210 ) -> None: 

211 # Docstring inherited from DatasetRecordStorage. 

212 if self._calibs is None:  [212 ↛ 213: line 212 didn't jump to line 213, because the condition on line 212 was never true]

213 raise CollectionTypeError( 

214 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

215 f"DatasetType.isCalibration() is False." 

216 ) 

217 if collection.type is not CollectionType.CALIBRATION:  [217 ↛ 218: line 217 didn't jump to line 218, because the condition on line 217 was never true]

218 raise CollectionTypeError( 

219 f"Cannot certify into collection '{collection.name}' " 

220 f"of type {collection.type.name}; must be CALIBRATION." 

221 ) 

222 TimespanReprClass = self._db.getTimespanRepresentation() 

223 protoRow = { 

224 self._collections.getCollectionForeignKeyName(): collection.key, 

225 "dataset_type_id": self._dataset_type_id, 

226 } 

227 rows = [] 

228 dataIds: Optional[Set[DataCoordinate]] = ( 

229 set() if not TimespanReprClass.hasExclusionConstraint() else None 

230 ) 

231 summary = CollectionSummary() 

232 for dataset in summary.add_datasets_generator(datasets): 

233 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

234 for dimension, value in dataset.dataId.items(): 

235 row[dimension.name] = value 

236 TimespanReprClass.update(timespan, result=row) 

237 rows.append(row) 

238 if dataIds is not None:  [238 ↛ 232: line 238 didn't jump to line 232, because the condition on line 238 was never false]

239 dataIds.add(dataset.dataId) 

240 # Update the summary tables for this collection in case this is the 

241 # first time this dataset type or these governor values will be 

242 # inserted there. 

243 self._summaries.update(collection, [self._dataset_type_id], summary) 

244 # Update the association table itself. 

245 if TimespanReprClass.hasExclusionConstraint():  [245 ↛ 248: line 245 didn't jump to line 248, because the condition on line 245 was never true]

246 # Rely on database constraint to enforce invariants; we just 

247 # reraise the exception for consistency across DB engines. 

248 try: 

249 self._db.insert(self._calibs, *rows) 

250 except sqlalchemy.exc.IntegrityError as err: 

251 raise ConflictingDefinitionError( 

252 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

253 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

254 ) from err 

255 else: 

256 # Have to implement exclusion constraint ourselves. 

257 # Start by building a SELECT query for any rows that would overlap 

258 # this one. 

259 query = self._buildCalibOverlapQuery( 

260 collection, 

261 DataCoordinateSet(dataIds, graph=self.datasetType.dimensions), # type: ignore 

262 timespan, 

263 ) 

264 query.columns.append(sqlalchemy.sql.func.count()) 

265 sql = query.combine() 

266 # Acquire a table lock to ensure there are no concurrent writes 

267 # that could invalidate our checking before we finish the inserts. We 

268 # use a SAVEPOINT in case there is an outer transaction that a 

269 # failure here should not roll back. 

270 with self._db.transaction(lock=[self._calibs], savepoint=True): 

271 # Run the check SELECT query. 

272 conflicting = self._db.query(sql).scalar() 

273 if conflicting > 0: 

274 raise ConflictingDefinitionError( 

275 f"{conflicting} validity range conflicts certifying datasets of type " 

276 f"{self.datasetType.name} into {collection.name} for range " 

277 f"[{timespan.begin}, {timespan.end})." 

278 ) 

279 # Proceed with the insert. 

280 self._db.insert(self._calibs, *rows) 

281 
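The else branch above emulates a range exclusion constraint by hand: take a table lock, count overlapping rows, and only insert if the count is zero. A minimal standalone sketch of that check-then-insert pattern with SQLAlchemy and in-memory SQLite (toy schema, no table lock, so it does not reproduce the concurrency guarantees of the real implementation):

import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
calibs = sqlalchemy.Table(
    "calibs", metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("valid_begin", sqlalchemy.Integer),
    sqlalchemy.Column("valid_end", sqlalchemy.Integer),
)
metadata.create_all(engine)

def certify(conn, dataset_id, begin, end):
    # Count rows whose validity range overlaps [begin, end); reject if any exist.
    conflicting = conn.execute(
        sqlalchemy.select(sqlalchemy.func.count())
        .select_from(calibs)
        .where(calibs.c.valid_begin < end, calibs.c.valid_end > begin)
    ).scalar()
    if conflicting:
        raise RuntimeError(f"{conflicting} validity range conflict(s) for [{begin}, {end}).")
    conn.execute(calibs.insert().values(dataset_id=dataset_id, valid_begin=begin, valid_end=end))

with engine.begin() as conn:
    certify(conn, 1, 0, 10)
    try:
        certify(conn, 2, 5, 15)  # overlaps [0, 10), so it is rejected
    except RuntimeError as exc:
        print(exc)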

282 def decertify( 

283 self, 

284 collection: CollectionRecord, 

285 timespan: Timespan, 

286 *, 

287 dataIds: Optional[Iterable[DataCoordinate]] = None, 

288 ) -> None: 

289 # Docstring inherited from DatasetRecordStorage. 

290 if self._calibs is None:  [290 ↛ 291: line 290 didn't jump to line 291, because the condition on line 290 was never true]

291 raise CollectionTypeError( 

292 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

293 f"DatasetType.isCalibration() is False." 

294 ) 

295 if collection.type is not CollectionType.CALIBRATION:  [295 ↛ 296: line 295 didn't jump to line 296, because the condition on line 295 was never true]

296 raise CollectionTypeError( 

297 f"Cannot decertify from collection '{collection.name}' " 

298 f"of type {collection.type.name}; must be CALIBRATION." 

299 ) 

300 TimespanReprClass = self._db.getTimespanRepresentation() 

301 # Construct a SELECT query to find all rows that overlap our inputs. 

302 dataIdSet: Optional[DataCoordinateSet] 

303 if dataIds is not None: 

304 dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions) 

305 else: 

306 dataIdSet = None 

307 query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan) 

308 query.columns.extend(self._calibs.columns) 

309 sql = query.combine() 

310 # Set up collections to populate with the rows we'll want to modify. 

311 # The insert rows will have the same values for collection and 

312 # dataset type. 

313 protoInsertRow = { 

314 self._collections.getCollectionForeignKeyName(): collection.key, 

315 "dataset_type_id": self._dataset_type_id, 

316 } 

317 rowsToDelete = [] 

318 rowsToInsert = [] 

319 # Acquire a table lock to ensure there are no concurrent writes 

320 # between the SELECT and the DELETE and INSERT queries based on it. 

321 with self._db.transaction(lock=[self._calibs], savepoint=True): 

322 for row in self._db.query(sql).mappings(): 

323 rowsToDelete.append({"id": row["id"]}) 

324 # Construct the insert row(s) by copying the prototype row, 

325 # then adding the dimension column values, then adding what's 

326 # left of the timespan from that row after we subtract the 

327 # given timespan. 

328 newInsertRow = protoInsertRow.copy() 

329 newInsertRow["dataset_id"] = row["dataset_id"] 

330 for name in self.datasetType.dimensions.required.names: 

331 newInsertRow[name] = row[name] 

332 rowTimespan = TimespanReprClass.extract(row) 

333 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

334 for diffTimespan in rowTimespan.difference(timespan): 

335 rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())) 

336 # Run the DELETE and INSERT queries. 

337 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

338 self._db.insert(self._calibs, *rowsToInsert) 

339 
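The effect of the DELETE/INSERT pair above is easiest to see on a concrete range: decertifying the middle of an existing validity range deletes that row and re-inserts the leftover pieces on either side. A schematic with plain half-open integer intervals (again not the butler Timespan API):

def difference(row, removed):
    # Pieces of half-open interval `row` not covered by `removed`; a schematic,
    # not the actual Timespan.difference implementation.
    pieces = []
    if row[0] < removed[0]:
        pieces.append((row[0], min(row[1], removed[0])))
    if removed[1] < row[1]:
        pieces.append((max(row[0], removed[1]), row[1]))
    return pieces

existing = (0, 20)      # validity range currently stored in the calibs table
decertified = (5, 15)   # timespan passed to decertify()
print(difference(existing, decertified))  # [(0, 5), (15, 20)]: one DELETE, two INSERTs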

340 def select( 

341 self, 

342 *collections: CollectionRecord, 

343 dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select, 

344 id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select, 

345 run: SimpleQuery.Select.Or[None] = SimpleQuery.Select, 

346 timespan: SimpleQuery.Select.Or[Optional[Timespan]] = SimpleQuery.Select, 

347 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]] = None, 

348 rank: SimpleQuery.Select.Or[None] = None, 

349 ) -> sqlalchemy.sql.Selectable: 

350 # Docstring inherited from DatasetRecordStorage. 

351 collection_types = {collection.type for collection in collections} 

352 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

353 TimespanReprClass = self._db.getTimespanRepresentation() 

354 # 

355 # There are two kinds of table in play here: 

356 # 

357 # - the static dataset table (with the dataset ID, dataset type ID, 

358 # run ID/name, and ingest date); 

359 # 

360 # - the dynamic tags/calibs table (with the dataset ID, dataset type 

361 # ID, collection ID/name, data ID, and possibly validity 

362 # range). 

363 # 

364 # That means that we might want to return a query against either table 

365 # or a JOIN of both, depending on which quantities the caller wants. 

366 # But this method is documented/typed such that ``dataId`` is never 

367 # `None` - i.e. we always constrain or retrieve the data ID. That 

368 # means we'll always include the tags/calibs table and join in the 

369 # static dataset table only if we need things from it that we can't get 

370 # from the tags/calibs table. 

371 # 

372 # Note that it's important that we include a WHERE constraint on both 

373 # tables for any column (e.g. dataset_type_id) that is in both when 

374 # it's given explicitly; not doing so can prevent the query planner from 

375 # using very important indexes. At present, we don't include those 

376 # redundant columns in the JOIN ON expression, however, because the 

377 # FOREIGN KEY (and its index) are defined only on dataset_id. 

378 # 

379 # We'll start by accumulating kwargs to pass to SimpleQuery.join when 

380 # we bring in the tags/calibs table. We get the data ID or constrain 

381 # it in the tags/calibs table(s), but that's multiple columns, not one, 

382 # so we need to transform the one Select.Or argument into a dictionary 

383 # of them. 

384 kwargs: Dict[str, Any] 

385 if dataId is SimpleQuery.Select: 

386 kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required} 

387 else: 

388 kwargs = dict(dataId.byName()) 

389 # We always constrain (never retrieve) the dataset type in at least the 

390 # tags/calibs table. 

391 kwargs["dataset_type_id"] = self._dataset_type_id 

392 # Join in the tags and/or calibs tables, turning those 'kwargs' entries 

393 # into WHERE constraints or SELECT columns as appropriate. 

394 if collection_types != {CollectionType.CALIBRATION}: 

395 # We'll need a subquery for the tags table if any of the given 

396 # collections are not a CALIBRATION collection. This intentionally 

397 # also fires when the list of collections is empty as a way to 

398 # create a dummy subquery that we know will yield no results. 

399 tags_query = SimpleQuery() 

400 tags_query.join(self._tags, **kwargs) 

401 # If the timespan is requested, simulate a potentially compound 

402 # column whose values are the maximum and minimum timespan 

403 # bounds. 

404 # If the timespan is constrained, ignore the constraint, since 

405 # it'd be guaranteed to evaluate to True. 

406 if timespan is SimpleQuery.Select: 

407 tags_query.columns.extend(TimespanReprClass.fromLiteral(Timespan(None, None)).flatten()) 

408 self._finish_single_select( 

409 tags_query, 

410 self._tags, 

411 collections, 

412 id=id, 

413 run=run, 

414 ingestDate=ingestDate, 

415 rank=rank, 

416 ) 

417 else: 

418 tags_query = None 

419 if CollectionType.CALIBRATION in collection_types: 

420 # If at least one collection is a CALIBRATION collection, we'll 

421 # need a subquery for the calibs table, and could include the 

422 # timespan as a result or constraint. 

423 calibs_query = SimpleQuery() 

424 assert ( 

425 self._calibs is not None 

426 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

427 calibs_query.join(self._calibs, **kwargs) 

428 # Add the timespan column(s) to the result columns, or constrain 

429 # the timespan via an overlap condition. 

430 if timespan is SimpleQuery.Select: 

431 calibs_query.columns.extend(TimespanReprClass.from_columns(self._calibs.columns).flatten()) 

432 elif timespan is not None: 

433 calibs_query.where.append( 

434 TimespanReprClass.from_columns(self._calibs.columns).overlaps( 

435 TimespanReprClass.fromLiteral(timespan) 

436 ) 

437 ) 

438 self._finish_single_select( 

439 calibs_query, 

440 self._calibs, 

441 collections, 

442 id=id, 

443 run=run, 

444 ingestDate=ingestDate, 

445 rank=rank, 

446 ) 

447 else: 

448 calibs_query = None 

449 if calibs_query is not None: 

450 if tags_query is not None: 

451 return tags_query.combine().union(calibs_query.combine()) 

452 else: 

453 return calibs_query.combine() 

454 else: 

455 assert tags_query is not None, "Earlier logic should guarantee that at least one is not None." 

456 return tags_query.combine() 

457 
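When both branches above produce a subquery, the result is a UNION of two selects projected to the same column set. A toy SQLAlchemy illustration of that overall shape (hypothetical tables and collection keys, far simpler than the real tags/calibs schema):

import sqlalchemy

metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table(
    "tags", metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)
calibs = sqlalchemy.Table(
    "calibs", metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)

# One subquery per table kind, each constrained to its own collections and
# projected to the same columns, then combined into a single result set.
tags_sel = sqlalchemy.select(tags.c.dataset_id.label("id")).where(tags.c.collection_id.in_([1, 2]))
calibs_sel = sqlalchemy.select(calibs.c.dataset_id.label("id")).where(calibs.c.collection_id == 3)
print(tags_sel.union(calibs_sel))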

458 def _finish_single_select( 

459 self, 

460 query: SimpleQuery, 

461 table: sqlalchemy.schema.Table, 

462 collections: Sequence[CollectionRecord], 

463 id: SimpleQuery.Select.Or[Optional[int]], 

464 run: SimpleQuery.Select.Or[None], 

465 ingestDate: SimpleQuery.Select.Or[Optional[Timespan]], 

466 rank: SimpleQuery.Select.Or[None], 

467 ) -> None: 

468 dataset_id_col = table.columns.dataset_id 

469 collection_col = table.columns[self._collections.getCollectionForeignKeyName()] 

470 # We always constrain (never retrieve) the collection(s) in the 

471 # tags/calibs table. 

472 if len(collections) == 1: 

473 query.where.append(collection_col == collections[0].key) 

474 elif len(collections) == 0: 

475 # We support the case where there are no collections as a way to 

476 # generate a valid SQL query that can't yield results. This should 

477 # never get executed, but lots of downstream code will still try 

478 # to access the SQLAlchemy objects representing the columns in the 

479 # subquery. That's not ideal, but it'd take a lot of refactoring 

480 # to fix it (DM-31725). 

481 query.where.append(sqlalchemy.sql.literal(False)) 

482 else: 

483 query.where.append(collection_col.in_([collection.key for collection in collections])) 

484 # Add rank, if requested, as a CASE-based calculation on the collection 

485 # column. 

486 if rank is not None: 

487 assert rank is SimpleQuery.Select, "Cannot constrain rank, only select it." 

488 query.columns.append( 

489 sqlalchemy.sql.case( 

490 {record.key: n for n, record in enumerate(collections)}, 

491 value=collection_col, 

492 ).label("rank") 

493 ) 

494 # We can always get the dataset_id from the tags/calibs table or 

495 # constrain it there. Can't use kwargs for that because we need to 

496 # alias it to 'id'. 

497 if id is SimpleQuery.Select: 

498 query.columns.append(dataset_id_col.label("id")) 

499 elif id is not None:  [499 ↛ 500: line 499 didn't jump to line 500, because the condition on line 499 was never true]

500 query.where.append(dataset_id_col == id) 

501 # It's possible we now have everything we need, from just the 

502 # tags/calibs table. The things we might need to get from the static 

503 # dataset table are the run key and the ingest date. 

504 need_static_table = False 

505 static_kwargs: Dict[str, Any] = {} 

506 if run is not None: 

507 assert run is SimpleQuery.Select, "To constrain the run name, pass a RunRecord as a collection." 

508 if len(collections) == 1 and collections[0].type is CollectionType.RUN: 

509 # If we are searching exactly one RUN collection, we 

510 # know that if we find the dataset in that collection, 

511 # then that's the dataset's run; we don't need to 

512 # query for it. 

513 query.columns.append(sqlalchemy.sql.literal(collections[0].key).label(self._runKeyColumn)) 

514 else: 

515 static_kwargs[self._runKeyColumn] = SimpleQuery.Select 

516 need_static_table = True 

517 # Ingest date can only come from the static table. 

518 if ingestDate is not None: 

519 need_static_table = True 

520 if ingestDate is SimpleQuery.Select:  [520 ↛ 523: line 520 didn't jump to line 523, because the condition on line 520 was never false]

521 static_kwargs["ingest_date"] = SimpleQuery.Select 

522 else: 

523 assert isinstance(ingestDate, Timespan) 

524 # Timespan bounds are astropy Time (usually in TAI) and ingest_date is 

525 # a TIMESTAMP column; convert values to Python datetime for SQLAlchemy. 

526 if ingestDate.isEmpty(): 

527 raise RuntimeError("Empty timespan constraint provided for ingest_date.") 

528 if ingestDate.begin is not None: 

529 begin = ingestDate.begin.utc.datetime # type: ignore 

530 query.where.append(self._static.dataset.columns.ingest_date >= begin) 

531 if ingestDate.end is not None: 

532 end = ingestDate.end.utc.datetime # type: ignore 

533 query.where.append(self._static.dataset.columns.ingest_date < end) 

534 # If we need the static table, join it in via dataset_id and 

535 # dataset_type_id 

536 if need_static_table: 

537 query.join( 

538 self._static.dataset, 

539 onclause=(dataset_id_col == self._static.dataset.columns.id), 

540 **static_kwargs, 

541 ) 

542 # Also constrain dataset_type_id in static table in case that helps 

543 # generate a better plan. 

544 # We could also include this in the JOIN ON clause, but my guess is 

545 # that that's a good idea IFF it's in the foreign key, and right 

546 # now it isn't. 

547 query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 

548 
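The rank column built above is a CASE over the collection key that records each collection's position in the search path, so downstream query logic can prefer matches from earlier collections. A standalone SQLAlchemy sketch of the same construct (toy table and keys; SQLAlchemy accepts a mapping plus value= for the shorthand CASE form):

import sqlalchemy

metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table(
    "tags", metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("collection_id", sqlalchemy.Integer),
)

search_path = [7, 3, 9]  # collection keys in search order
rank = sqlalchemy.case(
    {key: n for n, key in enumerate(search_path)},
    value=tags.c.collection_id,
).label("rank")

query = sqlalchemy.select(tags.c.dataset_id, rank).where(tags.c.collection_id.in_(search_path))
print(query)  # renders a CASE tags.collection_id WHEN ... THEN ... END AS rank expression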

549 def getDataId(self, id: DatasetId) -> DataCoordinate: 

550 """Return DataId for a dataset. 

551 

552 Parameters 

553 ---------- 

554 id : `DatasetId` 

555 Unique dataset identifier. 

556 

557 Returns 

558 ------- 

559 dataId : `DataCoordinate` 

560 DataId for the dataset. 

561 """ 

562 # This query could return multiple rows (one for each tagged collection 

563 # the dataset is in, plus one for its run collection), and we don't 

564 # care which of those we get. 

565 sql = ( 

566 self._tags.select() 

567 .where( 

568 sqlalchemy.sql.and_( 

569 self._tags.columns.dataset_id == id, 

570 self._tags.columns.dataset_type_id == self._dataset_type_id, 

571 ) 

572 ) 

573 .limit(1) 

574 ) 

575 row = self._db.query(sql).mappings().fetchone() 

576 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

577 return DataCoordinate.standardize( 

578 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required}, 

579 graph=self.datasetType.dimensions, 

580 ) 

581 

582 

583class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage): 

584 """Implementation of ByDimensionsDatasetRecordStorage which uses integer 

585 auto-incremented column for dataset IDs. 

586 """ 

587 

588 def insert( 

589 self, 

590 run: RunRecord, 

591 dataIds: Iterable[DataCoordinate], 

592 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

593 ) -> Iterator[DatasetRef]: 

594 # Docstring inherited from DatasetRecordStorage. 

595 

596 # We only support UNIQUE mode for integer dataset IDs 

597 if idMode != DatasetIdGenEnum.UNIQUE:  [597 ↛ 598: line 597 didn't jump to line 598, because the condition on line 597 was never true]

598 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.") 

599 

600 # Transform a possibly-single-pass iterable into a list. 

601 dataIdList = list(dataIds) 

602 yield from self._insert(run, dataIdList) 

603 

604 def import_( 

605 self, 

606 run: RunRecord, 

607 datasets: Iterable[DatasetRef], 

608 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

609 reuseIds: bool = False, 

610 ) -> Iterator[DatasetRef]: 

611 # Docstring inherited from DatasetRecordStorage. 

612 

613 # We only support UNIQUE mode for integer dataset IDs 

614 if idGenerationMode != DatasetIdGenEnum.UNIQUE:  [614 ↛ 615: line 614 didn't jump to line 615, because the condition on line 614 was never true]

615 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.") 

616 

617 # Make a list of dataIds and optionally dataset IDs. 

618 dataIdList: List[DataCoordinate] = [] 

619 datasetIdList: List[int] = [] 

620 for dataset in datasets: 

621 dataIdList.append(dataset.dataId) 

622 

623 # We only accept integer dataset IDs, but also allow None. 

624 datasetId = dataset.id 

625 if datasetId is None:  [625 ↛ 627: line 625 didn't jump to line 627, because the condition on line 625 was never true]

626 # if reuseIds is set then all IDs must be known 

627 if reuseIds: 

628 raise TypeError("All dataset IDs must be known if `reuseIds` is set") 

629 elif isinstance(datasetId, int):  [629 ↛ 633: line 629 didn't jump to line 633, because the condition on line 629 was never false]

630 if reuseIds: 

631 datasetIdList.append(datasetId) 

632 else: 

633 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}") 

634 

635 yield from self._insert(run, dataIdList, datasetIdList) 

636 

637 def _insert( 

638 self, run: RunRecord, dataIdList: List[DataCoordinate], datasetIdList: Optional[List[int]] = None 

639 ) -> Iterator[DatasetRef]: 

640 """Common part of implementation of `insert` and `import_` methods.""" 

641 

642 # Remember any governor dimension values we see. 

643 summary = CollectionSummary() 

644 summary.add_data_ids(self.datasetType, dataIdList) 

645 

646 staticRow = { 

647 "dataset_type_id": self._dataset_type_id, 

648 self._runKeyColumn: run.key, 

649 } 

650 with self._db.transaction(): 

651 # Insert into the static dataset table, generating autoincrement 

652 # dataset_id values. 

653 if datasetIdList: 

654 # reuse existing IDs 

655 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList] 

656 self._db.insert(self._static.dataset, *rows) 

657 else: 

658 # use auto-incremented IDs 

659 datasetIdList = self._db.insert( 

660 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True 

661 ) 

662 assert datasetIdList is not None 

663 # Update the summary tables for this collection in case this is the 

664 # first time this dataset type or these governor values will be 

665 # inserted there. 

666 self._summaries.update(run, [self._dataset_type_id], summary) 

667 # Combine the generated dataset_id values and data ID fields to 

668 # form rows to be inserted into the tags table. 

669 protoTagsRow = { 

670 "dataset_type_id": self._dataset_type_id, 

671 self._collections.getCollectionForeignKeyName(): run.key, 

672 } 

673 tagsRows = [ 

674 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

675 for dataId, dataset_id in zip(dataIdList, datasetIdList) 

676 ] 

677 # Insert those rows into the tags table. This is where we'll 

678 # get any unique constraint violations. 

679 self._db.insert(self._tags, *tagsRows) 

680 

681 for dataId, datasetId in zip(dataIdList, datasetIdList): 

682 yield DatasetRef( 

683 datasetType=self.datasetType, 

684 dataId=dataId, 

685 id=datasetId, 

686 run=run.name, 

687 ) 

688 

689 

690class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

691 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

692 dataset IDs. 

693 """ 

694 

695 idMaker = DatasetIdFactory() 

696 """Factory for dataset IDs. In the future this factory may be shared with 

697 other classes (e.g. Registry).""" 

698 

699 def insert( 

700 self, 

701 run: RunRecord, 

702 dataIds: Iterable[DataCoordinate], 

703 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

704 ) -> Iterator[DatasetRef]: 

705 # Docstring inherited from DatasetRecordStorage. 

706 

707 # Iterate over data IDs, transforming a possibly-single-pass iterable 

708 # into a list. 

709 dataIdList = [] 

710 rows = [] 

711 summary = CollectionSummary() 

712 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

713 dataIdList.append(dataId) 

714 rows.append( 

715 { 

716 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

717 "dataset_type_id": self._dataset_type_id, 

718 self._runKeyColumn: run.key, 

719 } 

720 ) 

721 

722 with self._db.transaction(): 

723 # Insert into the static dataset table. 

724 self._db.insert(self._static.dataset, *rows) 

725 # Update the summary tables for this collection in case this is the 

726 # first time this dataset type or these governor values will be 

727 # inserted there. 

728 self._summaries.update(run, [self._dataset_type_id], summary) 

729 # Combine the generated dataset_id values and data ID fields to 

730 # form rows to be inserted into the tags table. 

731 protoTagsRow = { 

732 "dataset_type_id": self._dataset_type_id, 

733 self._collections.getCollectionForeignKeyName(): run.key, 

734 } 

735 tagsRows = [ 

736 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName()) 

737 for dataId, row in zip(dataIdList, rows) 

738 ] 

739 # Insert those rows into the tags table. 

740 self._db.insert(self._tags, *tagsRows) 

741 

742 for dataId, row in zip(dataIdList, rows): 

743 yield DatasetRef( 

744 datasetType=self.datasetType, 

745 dataId=dataId, 

746 id=row["id"], 

747 run=run.name, 

748 ) 

749 
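The idMaker call above is what distinguishes this implementation from the integer one: dataset IDs are produced client-side before the INSERT. A rough sketch of the two flavours of ID generation, random for UNIQUE-style modes and name-based (uuid5) for reproducible modes; the namespace, argument names, and string layout here are illustrative assumptions, not the actual DatasetIdFactory implementation.

import uuid

# Illustrative namespace only; the real factory defines its own.
_NAMESPACE = uuid.uuid5(uuid.NAMESPACE_DNS, "butler-sketch.example.invalid")

def make_dataset_id(run, dataset_type_name, data_id, reproducible=False):
    if not reproducible:
        return uuid.uuid4()  # UNIQUE-style: a fresh random ID per insert
    # Reproducible modes hash a deterministic rendering of the inputs.
    rendered = ",".join(f"{k}={v}" for k, v in sorted(data_id.items()))
    return uuid.uuid5(_NAMESPACE, f"{run}:{dataset_type_name}:{rendered}")

print(make_dataset_id("HSC/runs/demo", "calexp", {"visit": 903334, "detector": 42}, reproducible=True))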

750 def import_( 

751 self, 

752 run: RunRecord, 

753 datasets: Iterable[DatasetRef], 

754 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

755 reuseIds: bool = False, 

756 ) -> Iterator[DatasetRef]: 

757 # Docstring inherited from DatasetRecordStorage. 

758 

759 # Iterate over data IDs, transforming a possibly-single-pass iterable 

760 # into a list. 

761 dataIds = {} 

762 summary = CollectionSummary() 

763 for dataset in summary.add_datasets_generator(datasets): 

764 # Ignore unknown ID types; normally all IDs have the same type, but 

765 # this code supports mixed types or missing IDs. 

766 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None 

767 if datasetId is None: 

768 datasetId = self.idMaker.makeDatasetId( 

769 run.name, self.datasetType, dataset.dataId, idGenerationMode 

770 ) 

771 dataIds[datasetId] = dataset.dataId 

772 

773 with self._db.session() as session: 

774 

775 # insert all new rows into a temporary table 

776 tableSpec = makeTagTableSpec( 

777 self.datasetType, type(self._collections), ddl.GUID, constraints=False 

778 ) 

779 tmp_tags = session.makeTemporaryTable(tableSpec) 

780 

781 collFkName = self._collections.getCollectionForeignKeyName() 

782 protoTagsRow = { 

783 "dataset_type_id": self._dataset_type_id, 

784 collFkName: run.key, 

785 } 

786 tmpRows = [ 

787 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

788 for dataset_id, dataId in dataIds.items() 

789 ] 

790 

791 with self._db.transaction(): 

792 

793 # store all incoming data in a temporary table 

794 self._db.insert(tmp_tags, *tmpRows) 

795 

796 # There are some checks that we want to make for consistency 

797 # of the new datasets with existing ones. 

798 self._validateImport(tmp_tags, run) 

799 

800 # Before we merge temporary table into dataset/tags we need to 

801 # drop datasets which are already there (and do not conflict). 

802 self._db.deleteWhere( 

803 tmp_tags, 

804 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

805 ) 

806 

807 # Copy it into dataset table, need to re-label some columns. 

808 self._db.insert( 

809 self._static.dataset, 

810 select=sqlalchemy.sql.select( 

811 tmp_tags.columns.dataset_id.label("id"), 

812 tmp_tags.columns.dataset_type_id, 

813 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

814 ), 

815 ) 

816 

817 # Update the summary tables for this collection in case this 

818 # is the first time this dataset type or these governor values 

819 # will be inserted there. 

820 self._summaries.update(run, [self._dataset_type_id], summary) 

821 

822 # Copy it into tags table. 

823 self._db.insert(self._tags, select=tmp_tags.select()) 

824 

825 # Return refs in the same order as in the input list. 

826 for dataset_id, dataId in dataIds.items(): 

827 yield DatasetRef( 

828 datasetType=self.datasetType, 

829 id=dataset_id, 

830 dataId=dataId, 

831 run=run.name, 

832 ) 

833 
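The import path above stages everything in a temporary table, validates it, strips out dataset IDs that already exist, and only then copies rows into the permanent tables with INSERT ... SELECT. A compact standalone sketch of that staging pattern with SQLAlchemy and in-memory SQLite (toy two-column schema, ordinary table standing in for the temporary one, and no validation step):

import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
dataset = sqlalchemy.Table("dataset", metadata, sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True))
tmp = sqlalchemy.Table("tmp_tags", metadata, sqlalchemy.Column("dataset_id", sqlalchemy.Integer))
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(dataset.insert().values(id=1))                          # already present
    conn.execute(tmp.insert(), [{"dataset_id": 1}, {"dataset_id": 2}])   # staged import
    # Drop staged rows that already exist in the permanent table ...
    conn.execute(tmp.delete().where(tmp.c.dataset_id.in_(sqlalchemy.select(dataset.c.id))))
    # ... then merge what is left with INSERT ... SELECT.
    conn.execute(dataset.insert().from_select(["id"], sqlalchemy.select(tmp.c.dataset_id)))
    print(conn.execute(sqlalchemy.select(dataset.c.id).order_by(dataset.c.id)).all())  # [(1,), (2,)]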

834 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

835 """Validate imported refs against existing datasets. 

836 

837 Parameters 

838 ---------- 

839 tmp_tags : `sqlalchemy.schema.Table` 

840 Temporary table with new datasets and the same schema as tags 

841 table. 

842 run : `RunRecord` 

843 The record object describing the `~CollectionType.RUN` collection. 

844 

845 Raises 

846 ------ 

847 ConflictingDefinitionError 

848 Raise if new datasets conflict with existing ones. 

849 """ 

850 dataset = self._static.dataset 

851 tags = self._tags 

852 collFkName = self._collections.getCollectionForeignKeyName() 

853 

854 # Check that existing datasets have the same dataset type and 

855 # run. 

856 query = ( 

857 sqlalchemy.sql.select( 

858 dataset.columns.id.label("dataset_id"), 

859 dataset.columns.dataset_type_id.label("dataset_type_id"), 

860 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), 

861 dataset.columns[self._runKeyColumn].label("run"), 

862 tmp_tags.columns[collFkName].label("new run"), 

863 ) 

864 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

865 .where( 

866 sqlalchemy.sql.or_( 

867 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

868 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

869 ) 

870 ) 

871 ) 

872 result = self._db.query(query) 

873 if (row := result.first()) is not None: 

874 # Only include the first one in the exception message 

875 raise ConflictingDefinitionError( 

876 f"Existing dataset type or run do not match new dataset: {row._asdict()}" 

877 ) 

878 

879 # Check that matching dataset in tags table has the same DataId. 

880 query = ( 

881 sqlalchemy.sql.select( 

882 tags.columns.dataset_id, 

883 tags.columns.dataset_type_id.label("type_id"), 

884 tmp_tags.columns.dataset_type_id.label("new type_id"), 

885 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

886 *[ 

887 tmp_tags.columns[dim].label(f"new {dim}") 

888 for dim in self.datasetType.dimensions.required.names 

889 ], 

890 ) 

891 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

892 .where( 

893 sqlalchemy.sql.or_( 

894 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

895 *[ 

896 tags.columns[dim] != tmp_tags.columns[dim] 

897 for dim in self.datasetType.dimensions.required.names 

898 ], 

899 ) 

900 ) 

901 ) 

902 result = self._db.query(query) 

903 if (row := result.first()) is not None: 

904 # Only include the first one in the exception message 

905 raise ConflictingDefinitionError( 

906 f"Existing dataset type or dataId do not match new dataset: {row._asdict()}" 

907 ) 

908 

909 # Check that matching run+dataId have the same dataset ID. 

910 query = ( 

911 sqlalchemy.sql.select( 

912 tags.columns.dataset_type_id.label("dataset_type_id"), 

913 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

914 tags.columns.dataset_id, 

915 tmp_tags.columns.dataset_id.label("new dataset_id"), 

916 tags.columns[collFkName], 

917 tmp_tags.columns[collFkName].label(f"new {collFkName}"), 

918 ) 

919 .select_from( 

920 tags.join( 

921 tmp_tags, 

922 sqlalchemy.sql.and_( 

923 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

924 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

925 *[ 

926 tags.columns[dim] == tmp_tags.columns[dim] 

927 for dim in self.datasetType.dimensions.required.names 

928 ], 

929 ), 

930 ) 

931 ) 

932 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

933 ) 

934 result = self._db.query(query) 

935 if (row := result.first()) is not None: 

936 # only include the first one in the exception message 

937 raise ConflictingDefinitionError( 

938 f"Existing dataset type and dataId does not match new dataset: {row._asdict()}" 

939 )
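Each of the three checks above joins the temporary table against an existing table and selects only mismatching rows, so any returned row is evidence of a conflict. A compact standalone illustration of the second check (same dataset_id, different data ID) with SQLAlchemy and in-memory SQLite, using a toy single-dimension schema:

import sqlalchemy

engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
tags = sqlalchemy.Table(
    "tags", metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("detector", sqlalchemy.Integer),
)
tmp_tags = sqlalchemy.Table(
    "tmp_tags", metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.Integer),
    sqlalchemy.Column("detector", sqlalchemy.Integer),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(tags.insert().values(dataset_id=1, detector=10))      # existing dataset
    conn.execute(tmp_tags.insert().values(dataset_id=1, detector=11))  # same ID, different data ID
    conflict = conn.execute(
        sqlalchemy.select(
            tags.c.dataset_id,
            tags.c.detector,
            tmp_tags.c.detector.label("new_detector"),
        )
        .select_from(tags.join(tmp_tags, tags.c.dataset_id == tmp_tags.c.dataset_id))
        .where(tags.c.detector != tmp_tags.c.detector)
    ).first()
    print(conflict)  # (1, 10, 11): the real code raises ConflictingDefinitionError here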