Coverage for python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py: 89%

301 statements  

coverage.py v6.5.0, created at 2022-11-29 02:00 -0800

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22 

23from __future__ import annotations 

24 

25__all__ = ("ByDimensionsDatasetRecordStorage",) 

26 

27import uuid 

28from collections.abc import Iterable, Iterator, Sequence 

29from typing import TYPE_CHECKING, Any 

30 

31import sqlalchemy 

32 

33from ....core import ( 

34 DataCoordinate, 

35 DataCoordinateSet, 

36 DatasetId, 

37 DatasetRef, 

38 DatasetType, 

39 SimpleQuery, 

40 Timespan, 

41 ddl, 

42) 

43from ..._collection_summary import CollectionSummary 

44from ..._collectionType import CollectionType 

45from ..._exceptions import CollectionTypeError, ConflictingDefinitionError, UnsupportedIdGeneratorError 

46from ...interfaces import DatasetIdFactory, DatasetIdGenEnum, DatasetRecordStorage 

47from .tables import makeTagTableSpec 

48 

49if TYPE_CHECKING: 49 ↛ 50 (line 49 didn't jump to line 50, because the condition on line 49 was never true)

50 from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord 

51 from .summaries import CollectionSummaryManager 

52 from .tables import StaticDatasetTablesTuple 

53 

54 

55class ByDimensionsDatasetRecordStorage(DatasetRecordStorage): 

56 """Dataset record storage implementation paired with 

57 `ByDimensionsDatasetRecordStorageManager`; see that class for more 

58 information. 

59 

60 Instances of this class should never be constructed directly; use 

61 `DatasetRecordStorageManager.register` instead. 

62 """ 
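    # Editorial note (added for clarity): each instance works with three kinds
    # of tables handed to it by the manager that constructs it: the shared
    # static ``dataset`` table, a dynamic ``tags`` table recording membership
    # in RUN and TAGGED collections along with the data ID, and an optional
    # ``calibs`` table that additionally stores a validity-range timespan for
    # CALIBRATION collections.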

63 

64 def __init__( 

65 self, 

66 *, 

67 datasetType: DatasetType, 

68 db: Database, 

69 dataset_type_id: int, 

70 collections: CollectionManager, 

71 static: StaticDatasetTablesTuple, 

72 summaries: CollectionSummaryManager, 

73 tags: sqlalchemy.schema.Table, 

74 calibs: sqlalchemy.schema.Table | None, 

75 ): 

76 super().__init__(datasetType=datasetType) 

77 self._dataset_type_id = dataset_type_id 

78 self._db = db 

79 self._collections = collections 

80 self._static = static 

81 self._summaries = summaries 

82 self._tags = tags 

83 self._calibs = calibs 

84 self._runKeyColumn = collections.getRunForeignKeyName() 

85 

86 def find( 

87 self, collection: CollectionRecord, dataId: DataCoordinate, timespan: Timespan | None = None 

88 ) -> DatasetRef | None: 

89 # Docstring inherited from DatasetRecordStorage. 

90 assert dataId.graph == self.datasetType.dimensions 

91 if collection.type is CollectionType.CALIBRATION and timespan is None: 91 ↛ 92 (line 91 didn't jump to line 92, because the condition on line 91 was never true)

92 raise TypeError( 

93 f"Cannot search for dataset in CALIBRATION collection {collection.name} " 

94 f"without an input timespan." 

95 ) 

96 sql = self.select( 

97 collection, dataId=dataId, id=SimpleQuery.Select, run=SimpleQuery.Select, timespan=timespan 

98 ) 

99 results = self._db.query(sql) 

100 row = results.fetchone() 

101 if row is None: 

102 return None 

103 if collection.type is CollectionType.CALIBRATION: 

104 # For temporal calibration lookups (only!) our invariants do not 

105 # guarantee that the number of result rows is <= 1. 

106 # They would if `select` constrained the given timespan to be 

107 # _contained_ by the validity range in the self._calibs table, 

108 # instead of simply _overlapping_ it, because we do guarantee that 

109 # the validity ranges are disjoint for a particular dataset type, 

110 # collection, and data ID. But using an overlap test and a check 

111 # for multiple result rows here allows us to provide a more useful 

112 # diagnostic, as well as allowing `select` to support more general 

113 # queries where multiple results are not an error. 

114 if results.fetchone() is not None: 

115 raise RuntimeError( 

116 f"Multiple matches found for calibration lookup in {collection.name} for " 

117 f"{self.datasetType.name} with {dataId} overlapping {timespan}. " 

118 ) 

119 return DatasetRef( 

120 datasetType=self.datasetType, 

121 dataId=dataId, 

122 id=row.id, 

123 run=self._collections[row._mapping[self._runKeyColumn]].name, 

124 ) 
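    # A rough usage sketch for `find` (editorial addition; ``storage``,
    # ``run_record``, ``calib_record``, ``data_id``, ``t1`` and ``t2`` are
    # hypothetical objects obtained elsewhere):
    #
    #     ref = storage.find(run_record, data_id)
    #     ref = storage.find(calib_record, data_id, timespan=Timespan(t1, t2))
    #
    # The second form is required when searching a CALIBRATION collection,
    # which raises TypeError if no timespan is given.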

125 

126 def delete(self, datasets: Iterable[DatasetRef]) -> None: 

127 # Docstring inherited from DatasetRecordStorage. 

128 # Only delete from common dataset table; ON DELETE foreign key clauses 

129 # will handle the rest. 

130 self._db.delete( 

131 self._static.dataset, 

132 ["id"], 

133 *[{"id": dataset.getCheckedId()} for dataset in datasets], 

134 ) 

135 

136 def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

137 # Docstring inherited from DatasetRecordStorage. 

138 if collection.type is not CollectionType.TAGGED: 138 ↛ 139 (line 138 didn't jump to line 139, because the condition on line 138 was never true)

139 raise TypeError( 

140 f"Cannot associate into collection '{collection.name}' " 

141 f"of type {collection.type.name}; must be TAGGED." 

142 ) 

143 protoRow = { 

144 self._collections.getCollectionForeignKeyName(): collection.key, 

145 "dataset_type_id": self._dataset_type_id, 

146 } 

147 rows = [] 

148 summary = CollectionSummary() 

149 for dataset in summary.add_datasets_generator(datasets): 

150 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

151 for dimension, value in dataset.dataId.items(): 

152 row[dimension.name] = value 

153 rows.append(row) 

154 # Update the summary tables for this collection in case this is the 

155 # first time this dataset type or these governor values will be 

156 # inserted there. 

157 self._summaries.update(collection, [self._dataset_type_id], summary) 

158 # Update the tag table itself. 

159 self._db.replace(self._tags, *rows) 
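    # Editorial note: ``Database.replace`` (rather than a plain insert) is
    # used for the tag rows, so re-associating a dataset that is already in
    # the TAGGED collection should presumably update the existing row rather
    # than raise a constraint error.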

160 

161 def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None: 

162 # Docstring inherited from DatasetRecordStorage. 

163 if collection.type is not CollectionType.TAGGED: 163 ↛ 164 (line 163 didn't jump to line 164, because the condition on line 163 was never true)

164 raise TypeError( 

165 f"Cannot disassociate from collection '{collection.name}' " 

166 f"of type {collection.type.name}; must be TAGGED." 

167 ) 

168 rows = [ 

169 { 

170 "dataset_id": dataset.getCheckedId(), 

171 self._collections.getCollectionForeignKeyName(): collection.key, 

172 } 

173 for dataset in datasets 

174 ] 

175 self._db.delete(self._tags, ["dataset_id", self._collections.getCollectionForeignKeyName()], *rows) 

176 

177 def _buildCalibOverlapQuery( 

178 self, collection: CollectionRecord, dataIds: DataCoordinateSet | None, timespan: Timespan 

179 ) -> SimpleQuery: 

180 assert self._calibs is not None 

181 # Start by building a SELECT query for any rows that would overlap 

182 # this one. 

183 query = SimpleQuery() 

184 query.join(self._calibs) 

185 # Add a WHERE clause matching the dataset type and collection. 

186 query.where.append(self._calibs.columns.dataset_type_id == self._dataset_type_id) 

187 query.where.append( 

188 self._calibs.columns[self._collections.getCollectionForeignKeyName()] == collection.key 

189 ) 

190 # Add a WHERE clause matching any of the given data IDs. 

191 if dataIds is not None: 

192 dataIds.constrain( 

193 query, 

194 lambda name: self._calibs.columns[name], # type: ignore 

195 ) 

196 # Add WHERE clause for timespan overlaps. 

197 TimespanReprClass = self._db.getTimespanRepresentation() 

198 query.where.append( 

199 TimespanReprClass.from_columns(self._calibs.columns).overlaps( 

200 TimespanReprClass.fromLiteral(timespan) 

201 ) 

202 ) 

203 return query 
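    # Editorial note: the query assembled above corresponds roughly to
    #
    #     SELECT ... FROM <calibs table>
    #     WHERE dataset_type_id = <this dataset type>
    #       AND <collection FK column> = <collection key>
    #       AND <data ID columns match the given data IDs, if provided>
    #       AND <row validity range> overlaps <given timespan>
    #
    # with callers appending the SELECT columns they need (a COUNT(*) in
    # `certify`, all calibs columns in `decertify`) before calling
    # ``query.combine()``.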

204 

205 def certify( 

206 self, collection: CollectionRecord, datasets: Iterable[DatasetRef], timespan: Timespan 

207 ) -> None: 

208 # Docstring inherited from DatasetRecordStorage. 

209 if self._calibs is None: 209 ↛ 210 (line 209 didn't jump to line 210, because the condition on line 209 was never true)

210 raise CollectionTypeError( 

211 f"Cannot certify datasets of type {self.datasetType.name}, for which " 

212 f"DatasetType.isCalibration() is False." 

213 ) 

214 if collection.type is not CollectionType.CALIBRATION: 214 ↛ 215 (line 214 didn't jump to line 215, because the condition on line 214 was never true)

215 raise CollectionTypeError( 

216 f"Cannot certify into collection '{collection.name}' " 

217 f"of type {collection.type.name}; must be CALIBRATION." 

218 ) 

219 TimespanReprClass = self._db.getTimespanRepresentation() 

220 protoRow = { 

221 self._collections.getCollectionForeignKeyName(): collection.key, 

222 "dataset_type_id": self._dataset_type_id, 

223 } 

224 rows = [] 

225 dataIds: set[DataCoordinate] | None = ( 

226 set() if not TimespanReprClass.hasExclusionConstraint() else None 

227 ) 

228 summary = CollectionSummary() 

229 for dataset in summary.add_datasets_generator(datasets): 

230 row = dict(protoRow, dataset_id=dataset.getCheckedId()) 

231 for dimension, value in dataset.dataId.items(): 

232 row[dimension.name] = value 

233 TimespanReprClass.update(timespan, result=row) 

234 rows.append(row) 

235 if dataIds is not None: 235 ↛ 229 (line 235 didn't jump to line 229, because the condition on line 235 was never false)

236 dataIds.add(dataset.dataId) 

237 # Update the summary tables for this collection in case this is the 

238 # first time this dataset type or these governor values will be 

239 # inserted there. 

240 self._summaries.update(collection, [self._dataset_type_id], summary) 

241 # Update the association table itself. 

242 if TimespanReprClass.hasExclusionConstraint(): 242 ↛ 245 (line 242 didn't jump to line 245, because the condition on line 242 was never true)

243 # Rely on database constraint to enforce invariants; we just 

244 # reraise the exception for consistency across DB engines. 

245 try: 

246 self._db.insert(self._calibs, *rows) 

247 except sqlalchemy.exc.IntegrityError as err: 

248 raise ConflictingDefinitionError( 

249 f"Validity range conflict certifying datasets of type {self.datasetType.name} " 

250 f"into {collection.name} for range [{timespan.begin}, {timespan.end})." 

251 ) from err 

252 else: 

253 # Have to implement exclusion constraint ourselves. 

254 # Start by building a SELECT query for any rows that would overlap 

255 # this one. 

256 query = self._buildCalibOverlapQuery( 

257 collection, 

258 DataCoordinateSet(dataIds, graph=self.datasetType.dimensions), # type: ignore 

259 timespan, 

260 ) 

261 query.columns.append(sqlalchemy.sql.func.count()) 

262 sql = query.combine() 

263 # Acquire a table lock to ensure there are no concurrent writes 

264 # that could invalidate our checking before we finish the inserts. We

265 # use a SAVEPOINT in case there is an outer transaction that a 

266 # failure here should not roll back. 

267 with self._db.transaction(lock=[self._calibs], savepoint=True): 

268 # Run the check SELECT query. 

269 conflicting = self._db.query(sql).scalar() 

270 if conflicting > 0: 

271 raise ConflictingDefinitionError( 

272 f"{conflicting} validity range conflicts certifying datasets of type " 

273 f"{self.datasetType.name} into {collection.name} for range " 

274 f"[{timespan.begin}, {timespan.end})." 

275 ) 

276 # Proceed with the insert. 

277 self._db.insert(self._calibs, *rows) 
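    # A rough calling example for `certify` (editorial addition; the names are
    # hypothetical and the refs must already have been inserted into a RUN):
    #
    #     storage.certify(calib_record, [ref1, ref2], Timespan(begin, end))
    #
    # A conflicting validity range for the same data ID raises
    # ConflictingDefinitionError, either via a database exclusion constraint
    # or via the explicit locked COUNT(*) check above.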

278 

279 def decertify( 

280 self, 

281 collection: CollectionRecord, 

282 timespan: Timespan, 

283 *, 

284 dataIds: Iterable[DataCoordinate] | None = None, 

285 ) -> None: 

286 # Docstring inherited from DatasetRecordStorage. 

287 if self._calibs is None: 287 ↛ 288 (line 287 didn't jump to line 288, because the condition on line 287 was never true)

288 raise CollectionTypeError( 

289 f"Cannot decertify datasets of type {self.datasetType.name}, for which " 

290 f"DatasetType.isCalibration() is False." 

291 ) 

292 if collection.type is not CollectionType.CALIBRATION: 292 ↛ 293 (line 292 didn't jump to line 293, because the condition on line 292 was never true)

293 raise CollectionTypeError( 

294 f"Cannot decertify from collection '{collection.name}' " 

295 f"of type {collection.type.name}; must be CALIBRATION." 

296 ) 

297 TimespanReprClass = self._db.getTimespanRepresentation() 

298 # Construct a SELECT query to find all rows that overlap our inputs. 

299 dataIdSet: DataCoordinateSet | None 

300 if dataIds is not None: 

301 dataIdSet = DataCoordinateSet(set(dataIds), graph=self.datasetType.dimensions) 

302 else: 

303 dataIdSet = None 

304 query = self._buildCalibOverlapQuery(collection, dataIdSet, timespan) 

305 query.columns.extend(self._calibs.columns) 

306 sql = query.combine() 

307 # Set up collections to populate with the rows we'll want to modify. 

308 # The insert rows will have the same values for collection and 

309 # dataset type. 

310 protoInsertRow = { 

311 self._collections.getCollectionForeignKeyName(): collection.key, 

312 "dataset_type_id": self._dataset_type_id, 

313 } 

314 rowsToDelete = [] 

315 rowsToInsert = [] 

316 # Acquire a table lock to ensure there are no concurrent writes 

317 # between the SELECT and the DELETE and INSERT queries based on it. 

318 with self._db.transaction(lock=[self._calibs], savepoint=True): 

319 for row in self._db.query(sql).mappings(): 

320 rowsToDelete.append({"id": row["id"]}) 

321 # Construct the insert row(s) by copying the prototype row, 

322 # then adding the dimension column values, then adding what's 

323 # left of the timespan from that row after we subtract the 

324 # given timespan. 

325 newInsertRow = protoInsertRow.copy() 

326 newInsertRow["dataset_id"] = row["dataset_id"] 

327 for name in self.datasetType.dimensions.required.names: 

328 newInsertRow[name] = row[name] 

329 rowTimespan = TimespanReprClass.extract(row) 

330 assert rowTimespan is not None, "Field should have a NOT NULL constraint." 

331 for diffTimespan in rowTimespan.difference(timespan): 

332 rowsToInsert.append(TimespanReprClass.update(diffTimespan, result=newInsertRow.copy())) 

333 # Run the DELETE and INSERT queries. 

334 self._db.delete(self._calibs, ["id"], *rowsToDelete) 

335 self._db.insert(self._calibs, *rowsToInsert) 
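    # Editorial note: the ``difference`` loop above is standard interval
    # subtraction. As a hypothetical example, decertifying Timespan(t2, t3)
    # from a row whose validity range is Timespan(t1, t4), with
    # t1 < t2 < t3 < t4, deletes that row and re-inserts two rows covering
    # Timespan(t1, t2) and Timespan(t3, t4); if the given timespan fully
    # covers the row's range, nothing is re-inserted for it.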

336 

337 def select( 

338 self, 

339 *collections: CollectionRecord, 

340 dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select, 

341 id: SimpleQuery.Select.Or[int | None] = SimpleQuery.Select, 

342 run: SimpleQuery.Select.Or[None] = SimpleQuery.Select, 

343 timespan: SimpleQuery.Select.Or[Timespan | None] = SimpleQuery.Select, 

344 ingestDate: SimpleQuery.Select.Or[Timespan | None] = None, 

345 rank: SimpleQuery.Select.Or[None] = None, 

346 ) -> sqlalchemy.sql.Selectable: 

347 # Docstring inherited from DatasetRecordStorage. 

348 collection_types = {collection.type for collection in collections} 

349 assert CollectionType.CHAINED not in collection_types, "CHAINED collections must be flattened." 

350 TimespanReprClass = self._db.getTimespanRepresentation() 

351 # 

352 # There are two kinds of table in play here: 

353 # 

354 # - the static dataset table (with the dataset ID, dataset type ID, 

355 # run ID/name, and ingest date); 

356 # 

357 # - the dynamic tags/calibs table (with the dataset ID, dataset

358 # type ID, collection ID/name, data ID, and possibly validity 

359 # range). 

360 # 

361 # That means that we might want to return a query against either table 

362 # or a JOIN of both, depending on which quantities the caller wants. 

363 # But this method is documented/typed such that ``dataId`` is never 

364 # `None` - i.e. we always constrain or retrieve the data ID. That

365 # means we'll always include the tags/calibs table and join in the 

366 # static dataset table only if we need things from it that we can't get 

367 # from the tags/calibs table. 

368 # 

369 # Note that it's important that we include a WHERE constraint on both 

370 # tables for any column (e.g. dataset_type_id) that is in both when 

371 # it's given explicitly; not doing so can prevent the query planner from

372 # using very important indexes. At present, we don't include those 

373 # redundant columns in the JOIN ON expression, however, because the 

374 # FOREIGN KEY (and its index) are defined only on dataset_id. 

375 # 

376 # We'll start by accumulating kwargs to pass to SimpleQuery.join when 

377 # we bring in the tags/calibs table. We get the data ID or constrain 

378 # it in the tags/calibs table(s), but that's multiple columns, not one, 

379 # so we need to transform the one Select.Or argument into a dictionary 

380 # of them. 

381 kwargs: dict[str, Any] 

382 if dataId is SimpleQuery.Select: 

383 kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required} 

384 else: 

385 kwargs = dict(dataId.byName()) 

386 # We always constrain (never retrieve) the dataset type in at least the 

387 # tags/calibs table. 

388 kwargs["dataset_type_id"] = self._dataset_type_id 

389 # Join in the tags and/or calibs tables, turning those 'kwargs' entries 

390 # into WHERE constraints or SELECT columns as appropriate. 

391 if collection_types != {CollectionType.CALIBRATION}: 

392 # We'll need a subquery for the tags table if any of the given 

393 # collections are not a CALIBRATION collection. This intentionally 

394 # also fires when the list of collections is empty as a way to 

395 # create a dummy subquery that we know will yield no results.

396 tags_query = SimpleQuery() 

397 tags_query.join(self._tags, **kwargs) 

398 # If the timespan is requested, simulate a potentially compound 

399 # column whose values are the maximum and minimum timespan 

400 # bounds. 

401 # If the timespan is constrained, ignore the constraint, since 

402 # it'd be guaranteed to evaluate to True. 

403 if timespan is SimpleQuery.Select: 

404 tags_query.columns.extend(TimespanReprClass.fromLiteral(Timespan(None, None)).flatten()) 

405 self._finish_single_select( 

406 tags_query, 

407 self._tags, 

408 collections, 

409 id=id, 

410 run=run, 

411 ingestDate=ingestDate, 

412 rank=rank, 

413 ) 

414 else: 

415 tags_query = None 

416 if CollectionType.CALIBRATION in collection_types: 

417 # If at least one collection is a CALIBRATION collection, we'll 

418 # need a subquery for the calibs table, and could include the 

419 # timespan as a result or constraint. 

420 calibs_query = SimpleQuery() 

421 assert ( 

422 self._calibs is not None 

423 ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." 

424 calibs_query.join(self._calibs, **kwargs) 

425 # Add the timespan column(s) to the result columns, or constrain 

426 # the timespan via an overlap condition. 

427 if timespan is SimpleQuery.Select: 

428 calibs_query.columns.extend(TimespanReprClass.from_columns(self._calibs.columns).flatten()) 

429 elif timespan is not None: 

430 calibs_query.where.append( 

431 TimespanReprClass.from_columns(self._calibs.columns).overlaps( 

432 TimespanReprClass.fromLiteral(timespan) 

433 ) 

434 ) 

435 self._finish_single_select( 

436 calibs_query, 

437 self._calibs, 

438 collections, 

439 id=id, 

440 run=run, 

441 ingestDate=ingestDate, 

442 rank=rank, 

443 ) 

444 else: 

445 calibs_query = None 

446 if calibs_query is not None: 

447 if tags_query is not None: 

448 return tags_query.combine().union(calibs_query.combine()) 

449 else: 

450 return calibs_query.combine() 

451 else: 

452 assert tags_query is not None, "Earlier logic should guarantee that at least one is not None."

453 return tags_query.combine() 
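    # Editorial note: the value returned above is therefore the tags subquery,
    # the calibs subquery, or a UNION of the two, depending on the types of
    # the given collections; CHAINED collections must have been flattened by
    # the caller, as the assertion at the top of this method requires.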

454 

455 def _finish_single_select( 

456 self, 

457 query: SimpleQuery, 

458 table: sqlalchemy.schema.Table, 

459 collections: Sequence[CollectionRecord], 

460 id: SimpleQuery.Select.Or[int | None], 

461 run: SimpleQuery.Select.Or[None], 

462 ingestDate: SimpleQuery.Select.Or[Timespan | None], 

463 rank: SimpleQuery.Select.Or[None], 

464 ) -> None: 

465 dataset_id_col = table.columns.dataset_id 

466 collection_col = table.columns[self._collections.getCollectionForeignKeyName()] 

467 # We always constrain (never retrieve) the collection(s) in the 

468 # tags/calibs table. 

469 if len(collections) == 1: 

470 query.where.append(collection_col == collections[0].key) 

471 elif len(collections) == 0: 

472 # We support the case where there are no collections as a way to 

473 # generate a valid SQL query that can't yield results. This should 

474 # never get executed, but lots of downstream code will still try 

475 # to access the SQLAlchemy objects representing the columns in the 

476 # subquery. That's not ideal, but it'd take a lot of refactoring 

477 # to fix it (DM-31725). 

478 query.where.append(sqlalchemy.sql.literal(False)) 

479 else: 

480 query.where.append(collection_col.in_([collection.key for collection in collections])) 

481 # Add rank, if requested, as a CASE-based calculation on the collection

482 # column. 

483 if rank is not None: 

484 assert rank is SimpleQuery.Select, "Cannot constrain rank, only select it."

485 query.columns.append( 

486 sqlalchemy.sql.case( 

487 {record.key: n for n, record in enumerate(collections)}, 

488 value=collection_col, 

489 ).label("rank") 

490 ) 

491 # We can always get the dataset_id from the tags/calibs table or 

492 # constrain it there. Can't use kwargs for that because we need to 

493 # alias it to 'id'. 

494 if id is SimpleQuery.Select: 

495 query.columns.append(dataset_id_col.label("id")) 

496 elif id is not None: 496 ↛ 497 (line 496 didn't jump to line 497, because the condition on line 496 was never true)

497 query.where.append(dataset_id_col == id) 

498 # It's possible we now have everything we need, from just the 

499 # tags/calibs table. The things we might need to get from the static 

500 # dataset table are the run key and the ingest date. 

501 need_static_table = False 

502 static_kwargs: dict[str, Any] = {} 

503 if run is not None: 

504 assert run is SimpleQuery.Select, "To constrain the run name, pass a RunRecord as a collection." 

505 if len(collections) == 1 and collections[0].type is CollectionType.RUN: 

506 # If we are searching exactly one RUN collection, we 

507 # know that if we find the dataset in that collection, 

508 # then that's the dataset's run; we don't need to

509 # query for it. 

510 query.columns.append(sqlalchemy.sql.literal(collections[0].key).label(self._runKeyColumn)) 

511 else: 

512 static_kwargs[self._runKeyColumn] = SimpleQuery.Select 

513 need_static_table = True 

514 # Ingest date can only come from the static table. 

515 if ingestDate is not None: 

516 need_static_table = True 

517 if ingestDate is SimpleQuery.Select: 517 ↛ 520 (line 517 didn't jump to line 520, because the condition on line 517 was never false)

518 static_kwargs["ingest_date"] = SimpleQuery.Select 

519 else: 

520 assert isinstance(ingestDate, Timespan) 

521 # The timespan bounds are astropy Time (usually TAI) and ingest_date is

522 # TIMESTAMP; convert the values to Python datetime for SQLAlchemy.

523 if ingestDate.isEmpty(): 

524 raise RuntimeError("Empty timespan constraint provided for ingest_date.") 

525 if ingestDate.begin is not None: 

526 begin = ingestDate.begin.utc.datetime # type: ignore 

527 query.where.append(self._static.dataset.columns.ingest_date >= begin) 

528 if ingestDate.end is not None: 

529 end = ingestDate.end.utc.datetime # type: ignore 

530 query.where.append(self._static.dataset.columns.ingest_date < end) 

531 # If we need the static table, join it in via dataset_id and 

532 # dataset_type_id 

533 if need_static_table: 

534 query.join( 

535 self._static.dataset, 

536 onclause=(dataset_id_col == self._static.dataset.columns.id), 

537 **static_kwargs, 

538 ) 

539 # Also constrain dataset_type_id in static table in case that helps 

540 # generate a better plan. 

541 # We could also include this in the JOIN ON clause, but my guess is 

542 # that that's a good idea IFF it's in the foreign key, and right 

543 # now it isn't. 

544 query.where.append(self._static.dataset.columns.dataset_type_id == self._dataset_type_id) 
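    # Editorial note: the static dataset table is joined in only when the run
    # key or the ingest date is actually needed; when exactly one RUN
    # collection is being searched, the run key is emitted as a SQL literal
    # instead, which avoids the join entirely.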

545 

546 def getDataId(self, id: DatasetId) -> DataCoordinate: 

547 """Return DataId for a dataset. 

548 

549 Parameters 

550 ---------- 

551 id : `DatasetId` 

552 Unique dataset identifier. 

553 

554 Returns 

555 ------- 

556 dataId : `DataCoordinate` 

557 DataId for the dataset. 

558 """ 

559 # This query could return multiple rows (one for each tagged collection 

560 # the dataset is in, plus one for its run collection), and we don't 

561 # care which of those we get. 

562 sql = ( 

563 self._tags.select() 

564 .where( 

565 sqlalchemy.sql.and_( 

566 self._tags.columns.dataset_id == id, 

567 self._tags.columns.dataset_type_id == self._dataset_type_id, 

568 ) 

569 ) 

570 .limit(1) 

571 ) 

572 row = self._db.query(sql).mappings().fetchone() 

573 assert row is not None, "Should be guaranteed by caller and foreign key constraints." 

574 return DataCoordinate.standardize( 

575 {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required}, 

576 graph=self.datasetType.dimensions, 

577 ) 

578 

579 

580class ByDimensionsDatasetRecordStorageInt(ByDimensionsDatasetRecordStorage): 

581 """Implementation of ByDimensionsDatasetRecordStorage which uses integer 

582 auto-incremented column for dataset IDs. 

583 """ 

584 

585 def insert( 

586 self, 

587 run: RunRecord, 

588 dataIds: Iterable[DataCoordinate], 

589 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

590 ) -> Iterator[DatasetRef]: 

591 # Docstring inherited from DatasetRecordStorage. 

592 

593 # We only support UNIQUE mode for integer dataset IDs 

594 if idMode != DatasetIdGenEnum.UNIQUE: 594 ↛ 595 (line 594 didn't jump to line 595, because the condition on line 594 was never true)

595 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.") 

596 

597 # Transform a possibly-single-pass iterable into a list. 

598 dataIdList = list(dataIds) 

599 yield from self._insert(run, dataIdList) 

600 

601 def import_( 

602 self, 

603 run: RunRecord, 

604 datasets: Iterable[DatasetRef], 

605 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

606 reuseIds: bool = False, 

607 ) -> Iterator[DatasetRef]: 

608 # Docstring inherited from DatasetRecordStorage. 

609 

610 # We only support UNIQUE mode for integer dataset IDs 

611 if idGenerationMode != DatasetIdGenEnum.UNIQUE: 611 ↛ 612 (line 611 didn't jump to line 612, because the condition on line 611 was never true)

612 raise UnsupportedIdGeneratorError("Only UNIQUE mode can be used with integer dataset IDs.") 

613 

614 # Make a list of dataIds and optionally dataset IDs. 

615 dataIdList: list[DataCoordinate] = [] 

616 datasetIdList: list[int] = [] 

617 for dataset in datasets: 

618 dataIdList.append(dataset.dataId) 

619 

620 # We only accept integer dataset IDs, but also allow None. 

621 datasetId = dataset.id 

622 if datasetId is None: 622 ↛ 624 (line 622 didn't jump to line 624, because the condition on line 622 was never true)

623 # if reuseIds is set then all IDs must be known 

624 if reuseIds: 

625 raise TypeError("All dataset IDs must be known if `reuseIds` is set") 

626 elif isinstance(datasetId, int): 626 ↛ 630 (line 626 didn't jump to line 630, because the condition on line 626 was never false)

627 if reuseIds: 

628 datasetIdList.append(datasetId) 

629 else: 

630 raise TypeError(f"Unsupported type of dataset ID: {type(datasetId)}") 

631 

632 yield from self._insert(run, dataIdList, datasetIdList) 

633 

634 def _insert( 

635 self, run: RunRecord, dataIdList: list[DataCoordinate], datasetIdList: list[int] | None = None 

636 ) -> Iterator[DatasetRef]: 

637 """Common part of implementation of `insert` and `import_` methods.""" 

638 

639 # Remember any governor dimension values we see. 

640 summary = CollectionSummary() 

641 summary.add_data_ids(self.datasetType, dataIdList) 

642 

643 staticRow = { 

644 "dataset_type_id": self._dataset_type_id, 

645 self._runKeyColumn: run.key, 

646 } 

647 with self._db.transaction(): 

648 # Insert into the static dataset table, generating autoincrement 

649 # dataset_id values. 

650 if datasetIdList: 

651 # reuse existing IDs 

652 rows = [dict(staticRow, id=datasetId) for datasetId in datasetIdList] 

653 self._db.insert(self._static.dataset, *rows) 

654 else: 

655 # use auto-incremented IDs 

656 datasetIdList = self._db.insert( 

657 self._static.dataset, *([staticRow] * len(dataIdList)), returnIds=True 

658 ) 

659 assert datasetIdList is not None 

660 # Update the summary tables for this collection in case this is the 

661 # first time this dataset type or these governor values will be 

662 # inserted there. 

663 self._summaries.update(run, [self._dataset_type_id], summary) 

664 # Combine the generated dataset_id values and data ID fields to 

665 # form rows to be inserted into the tags table. 

666 protoTagsRow = { 

667 "dataset_type_id": self._dataset_type_id, 

668 self._collections.getCollectionForeignKeyName(): run.key, 

669 } 

670 tagsRows = [ 

671 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

672 for dataId, dataset_id in zip(dataIdList, datasetIdList) 

673 ] 

674 # Insert those rows into the tags table. This is where we'll 

675 # get any unique constraint violations. 

676 self._db.insert(self._tags, *tagsRows) 

677 

678 for dataId, datasetId in zip(dataIdList, datasetIdList): 

679 yield DatasetRef( 

680 datasetType=self.datasetType, 

681 dataId=dataId, 

682 id=datasetId, 

683 run=run.name, 

684 ) 
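    # Editorial note on ordering: within one transaction the static dataset
    # rows are inserted first (reusing the given IDs or generating
    # autoincrement ones), the collection summaries are updated, and the tags
    # rows are inserted last, which is where any unique-constraint violations
    # surface, as the comment above notes.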

685 

686 

687class ByDimensionsDatasetRecordStorageUUID(ByDimensionsDatasetRecordStorage): 

688 """Implementation of ByDimensionsDatasetRecordStorage which uses UUID for 

689 dataset IDs. 

690 """ 

691 

692 idMaker = DatasetIdFactory() 

693 """Factory for dataset IDs. In the future this factory may be shared with 

694 other classes (e.g. Registry).""" 

695 

696 def insert( 

697 self, 

698 run: RunRecord, 

699 dataIds: Iterable[DataCoordinate], 

700 idMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

701 ) -> Iterator[DatasetRef]: 

702 # Docstring inherited from DatasetRecordStorage. 

703 

704 # Iterate over data IDs, transforming a possibly-single-pass iterable 

705 # into a list. 

706 dataIdList = [] 

707 rows = [] 

708 summary = CollectionSummary() 

709 for dataId in summary.add_data_ids_generator(self.datasetType, dataIds): 

710 dataIdList.append(dataId) 

711 rows.append( 

712 { 

713 "id": self.idMaker.makeDatasetId(run.name, self.datasetType, dataId, idMode), 

714 "dataset_type_id": self._dataset_type_id, 

715 self._runKeyColumn: run.key, 

716 } 

717 ) 

718 

719 with self._db.transaction(): 

720 # Insert into the static dataset table. 

721 self._db.insert(self._static.dataset, *rows) 

722 # Update the summary tables for this collection in case this is the 

723 # first time this dataset type or these governor values will be 

724 # inserted there. 

725 self._summaries.update(run, [self._dataset_type_id], summary) 

726 # Combine the generated dataset_id values and data ID fields to 

727 # form rows to be inserted into the tags table. 

728 protoTagsRow = { 

729 "dataset_type_id": self._dataset_type_id, 

730 self._collections.getCollectionForeignKeyName(): run.key, 

731 } 

732 tagsRows = [ 

733 dict(protoTagsRow, dataset_id=row["id"], **dataId.byName()) 

734 for dataId, row in zip(dataIdList, rows) 

735 ] 

736 # Insert those rows into the tags table. 

737 self._db.insert(self._tags, *tagsRows) 

738 

739 for dataId, row in zip(dataIdList, rows): 

740 yield DatasetRef( 

741 datasetType=self.datasetType, 

742 dataId=dataId, 

743 id=row["id"], 

744 run=run.name, 

745 ) 
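    # A rough calling example for `insert` (editorial addition; names are
    # hypothetical):
    #
    #     refs = list(storage.insert(run_record, [data_id1, data_id2]))
    #
    # Each returned DatasetRef carries an ID produced by
    # ``idMaker.makeDatasetId(run.name, datasetType, dataId, idMode)``, so the
    # same inputs with a deterministic ``idMode`` would presumably yield the
    # same UUID.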

746 

747 def import_( 

748 self, 

749 run: RunRecord, 

750 datasets: Iterable[DatasetRef], 

751 idGenerationMode: DatasetIdGenEnum = DatasetIdGenEnum.UNIQUE, 

752 reuseIds: bool = False, 

753 ) -> Iterator[DatasetRef]: 

754 # Docstring inherited from DatasetRecordStorage. 

755 

756 # Iterate over data IDs, transforming a possibly-single-pass iterable 

757 # into a list. 

758 dataIds = {} 

759 summary = CollectionSummary() 

760 for dataset in summary.add_datasets_generator(datasets): 

761 # Ignore unknown ID types; normally all IDs have the same type, but

762 # this code supports mixed types or missing IDs. 

763 datasetId = dataset.id if isinstance(dataset.id, uuid.UUID) else None 

764 if datasetId is None: 

765 datasetId = self.idMaker.makeDatasetId( 

766 run.name, self.datasetType, dataset.dataId, idGenerationMode 

767 ) 

768 dataIds[datasetId] = dataset.dataId 

769 

770 with self._db.session() as session: 

771 

772 # insert all new rows into a temporary table 

773 tableSpec = makeTagTableSpec( 

774 self.datasetType, type(self._collections), ddl.GUID, constraints=False 

775 ) 

776 tmp_tags = session.makeTemporaryTable(tableSpec) 

777 

778 collFkName = self._collections.getCollectionForeignKeyName() 

779 protoTagsRow = { 

780 "dataset_type_id": self._dataset_type_id, 

781 collFkName: run.key, 

782 } 

783 tmpRows = [ 

784 dict(protoTagsRow, dataset_id=dataset_id, **dataId.byName()) 

785 for dataset_id, dataId in dataIds.items() 

786 ] 

787 

788 with self._db.transaction(): 

789 

790 # store all incoming data in a temporary table 

791 self._db.insert(tmp_tags, *tmpRows) 

792 

793 # There are some checks that we want to make for consistency 

794 # of the new datasets with existing ones. 

795 self._validateImport(tmp_tags, run) 

796 

797 # Before we merge temporary table into dataset/tags we need to 

798 # drop datasets which are already there (and do not conflict). 

799 self._db.deleteWhere( 

800 tmp_tags, 

801 tmp_tags.columns.dataset_id.in_(sqlalchemy.sql.select(self._static.dataset.columns.id)), 

802 ) 

803 

804 # Copy it into the dataset table; we need to re-label some columns.

805 self._db.insert( 

806 self._static.dataset, 

807 select=sqlalchemy.sql.select( 

808 tmp_tags.columns.dataset_id.label("id"), 

809 tmp_tags.columns.dataset_type_id, 

810 tmp_tags.columns[collFkName].label(self._runKeyColumn), 

811 ), 

812 ) 

813 

814 # Update the summary tables for this collection in case this 

815 # is the first time this dataset type or these governor values 

816 # will be inserted there. 

817 self._summaries.update(run, [self._dataset_type_id], summary) 

818 

819 # Copy it into tags table. 

820 self._db.insert(self._tags, select=tmp_tags.select()) 

821 

822 # Return refs in the same order as in the input list. 

823 for dataset_id, dataId in dataIds.items(): 

824 yield DatasetRef( 

825 datasetType=self.datasetType, 

826 id=dataset_id, 

827 dataId=dataId, 

828 run=run.name, 

829 ) 
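    # Editorial note: `import_` follows a stage-and-merge strategy: all
    # incoming rows are written to the temporary ``tmp_tags`` table, validated
    # against the existing dataset/tags contents, trimmed of dataset IDs that
    # are already registered, and only then bulk-copied into the static
    # dataset table and the tags table, all inside a single transaction.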

830 

831 def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> None: 

832 """Validate imported refs against existing datasets. 

833 

834 Parameters 

835 ---------- 

836 tmp_tags : `sqlalchemy.schema.Table` 

837 Temporary table with new datasets and the same schema as tags 

838 table. 

839 run : `RunRecord` 

840 The record object describing the `~CollectionType.RUN` collection. 

841 

842 Raises 

843 ------ 

844 ConflictingDefinitionError 

845 Raise if new datasets conflict with existing ones. 

846 """ 

847 dataset = self._static.dataset 

848 tags = self._tags 

849 collFkName = self._collections.getCollectionForeignKeyName() 

850 

851 # Check that existing datasets have the same dataset type and 

852 # run. 

853 query = ( 

854 sqlalchemy.sql.select( 

855 dataset.columns.id.label("dataset_id"), 

856 dataset.columns.dataset_type_id.label("dataset_type_id"), 

857 tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), 

858 dataset.columns[self._runKeyColumn].label("run"), 

859 tmp_tags.columns[collFkName].label("new run"), 

860 ) 

861 .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) 

862 .where( 

863 sqlalchemy.sql.or_( 

864 dataset.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

865 dataset.columns[self._runKeyColumn] != tmp_tags.columns[collFkName], 

866 ) 

867 ) 

868 ) 

869 result = self._db.query(query) 

870 if (row := result.first()) is not None: 

871 # Only include the first one in the exception message 

872 raise ConflictingDefinitionError( 

873 f"Existing dataset type or run does not match new dataset: {row._asdict()}"

874 ) 

875 

876 # Check that matching dataset in tags table has the same DataId. 

877 query = ( 

878 sqlalchemy.sql.select( 

879 tags.columns.dataset_id, 

880 tags.columns.dataset_type_id.label("type_id"), 

881 tmp_tags.columns.dataset_type_id.label("new type_id"), 

882 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

883 *[ 

884 tmp_tags.columns[dim].label(f"new {dim}") 

885 for dim in self.datasetType.dimensions.required.names 

886 ], 

887 ) 

888 .select_from(tags.join(tmp_tags, tags.columns.dataset_id == tmp_tags.columns.dataset_id)) 

889 .where( 

890 sqlalchemy.sql.or_( 

891 tags.columns.dataset_type_id != tmp_tags.columns.dataset_type_id, 

892 *[ 

893 tags.columns[dim] != tmp_tags.columns[dim] 

894 for dim in self.datasetType.dimensions.required.names 

895 ], 

896 ) 

897 ) 

898 ) 

899 result = self._db.query(query) 

900 if (row := result.first()) is not None: 

901 # Only include the first one in the exception message 

902 raise ConflictingDefinitionError( 

903 f"Existing dataset type or dataId does not match new dataset: {row._asdict()}"

904 ) 

905 

906 # Check that matching run+dataId have the same dataset ID. 

907 query = ( 

908 sqlalchemy.sql.select( 

909 tags.columns.dataset_type_id.label("dataset_type_id"), 

910 *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], 

911 tags.columns.dataset_id, 

912 tmp_tags.columns.dataset_id.label("new dataset_id"), 

913 tags.columns[collFkName], 

914 tmp_tags.columns[collFkName].label(f"new {collFkName}"), 

915 ) 

916 .select_from( 

917 tags.join( 

918 tmp_tags, 

919 sqlalchemy.sql.and_( 

920 tags.columns.dataset_type_id == tmp_tags.columns.dataset_type_id, 

921 tags.columns[collFkName] == tmp_tags.columns[collFkName], 

922 *[ 

923 tags.columns[dim] == tmp_tags.columns[dim] 

924 for dim in self.datasetType.dimensions.required.names 

925 ], 

926 ), 

927 ) 

928 ) 

929 .where(tags.columns.dataset_id != tmp_tags.columns.dataset_id) 

930 ) 

931 result = self._db.query(query) 

932 if (row := result.first()) is not None: 

933 # only include the first one in the exception message 

934 raise ConflictingDefinitionError( 

935 f"Existing dataset type and dataId do not match new dataset: {row._asdict()}"

936 )
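    # Editorial summary of the three checks above: (1) a dataset ID that
    # already exists must have the same dataset type and run as the incoming
    # row, (2) a dataset ID already present in the tags table must have the
    # same data ID, and (3) an existing row with the same run, dataset type,
    # and data ID must have the same dataset ID. The first offending row, if
    # any, is reported via ConflictingDefinitionError.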