# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["PreExecInit"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import abc
import logging
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DataCoordinate, DatasetIdFactory, DatasetRef, DatasetType
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.pipe.base import PipelineDatasetTypes
from lsst.utils.packages import Packages

from .mock_task import MockButlerQuantumContext

if TYPE_CHECKING:
    from lsst.daf.butler import Butler, LimitedButler
    from lsst.pipe.base import QuantumGraph, TaskDef, TaskFactory

_LOG = logging.getLogger(__name__)


class MissingReferenceError(Exception):
    """Exception raised when a resolved reference is missing from the graph."""

    pass


def _compare_packages(old_packages: Packages, new_packages: Packages) -> None:
    """Compare two versions of Packages.

    Parameters
    ----------
    old_packages : `Packages`
        Previously recorded package versions.
    new_packages : `Packages`
        New set of package versions.

    Raises
    ------
    TypeError
        Raised if any package version in ``new_packages`` differs from the
        version recorded in ``old_packages``.
    """
    diff = new_packages.difference(old_packages)
    if diff:
        versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
        raise TypeError(f"Package versions mismatch: ({versions_str})")
    else:
        _LOG.debug("new packages are consistent with old")
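
# Illustrative sketch of how `_compare_packages` behaves, not part of the
# module API. `Packages` acts like a mapping from package name to version and
# `difference()` returns the packages whose versions disagree; the dict-style
# constructor below is an assumption for illustration only:
#
#   old = Packages({"numpy": "1.22.0"})
#   new = Packages({"numpy": "1.24.1"})
#   _compare_packages(old, old)  # no-op, versions agree
#   _compare_packages(old, new)  # TypeError: Package versions mismatch: (numpy: ...)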



class PreExecInitBase(abc.ABC):
    """Common part of the implementation of PreExecInit classes that does not
    depend on Butler type.
    """

    def __init__(self, butler: LimitedButler, taskFactory: TaskFactory):
        self.butler = butler
        self.taskFactory = taskFactory

    def initialize(
        self,
        graph: QuantumGraph,
        saveInitOutputs: bool = True,
        registerDatasetTypes: bool = False,
        saveVersions: bool = True,
    ) -> None:
        """Perform all initialization steps.

        Convenience method to execute all initialization steps. Instead of
        calling this method and providing all options it is also possible to
        call the individual methods directly.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Execution graph.
        saveInitOutputs : `bool`, optional
            If ``True`` (default) then save "init outputs", configurations,
            and package versions to butler.
        registerDatasetTypes : `bool`, optional
            If ``True`` then register dataset types in registry, otherwise
            they must already be registered.
        saveVersions : `bool`, optional
            If ``False`` then do not save package versions even if
            ``saveInitOutputs`` is set to ``True``.
        """
        # Register dataset types or check consistency.
        self.initializeDatasetTypes(graph, registerDatasetTypes)

        # Save task initialization data or check that saved data
        # is consistent with what tasks would save.
        if saveInitOutputs:
            self.saveInitOutputs(graph)
            self.saveConfigs(graph)
            if saveVersions:
                self.savePackageVersions(graph)
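
    # Typical driver usage (illustrative sketch; ``butler``, ``task_factory``
    # and ``qgraph`` are assumed to be constructed elsewhere, e.g. by the
    # ``pipetask`` command line tooling):
    #
    #   pre_exec_init = PreExecInit(butler, task_factory, extendRun=False)
    #   pre_exec_init.initialize(
    #       qgraph,
    #       saveInitOutputs=True,
    #       registerDatasetTypes=True,
    #       saveVersions=True,
    #   )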


    @abc.abstractmethod
    def initializeDatasetTypes(self, graph: QuantumGraph, registerDatasetTypes: bool = False) -> None:
        """Save or check DatasetTypes output by the tasks in a graph.

        Iterates over all DatasetTypes for all tasks in a graph and either
        tries to add them to registry or compares them to existing ones.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Execution graph.
        registerDatasetTypes : `bool`, optional
            If ``True`` then register dataset types in registry, otherwise
            they must already be registered.

        Raises
        ------
        ValueError
            Raised if an existing DatasetType is different from the
            DatasetType in the graph.
        KeyError
            Raised if ``registerDatasetTypes`` is ``False`` and a DatasetType
            does not exist in registry.
        """
        raise NotImplementedError()

    def saveInitOutputs(self, graph: QuantumGraph) -> None:
        """Write any datasets produced by initializing tasks in a graph.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Execution graph.

        Raises
        ------
        TypeError
            Raised if the type of an existing object in butler is different
            from the type of the new data.
        """
        _LOG.debug("Will save InitOutputs for all tasks")
        for taskDef in graph.iterTaskGraph():
            init_input_refs = self.find_init_input_refs(taskDef, graph)
            task = self.taskFactory.makeTask(taskDef, self.butler, init_input_refs)
            for name in taskDef.connections.initOutputs:
                attribute = getattr(taskDef.connections, name)
                obj_from_store, init_output_ref = self.find_init_output(taskDef, attribute.name, graph)
                if init_output_ref is None:
                    raise ValueError(f"Cannot find or make dataset reference for init output {name}")
                init_output_var = getattr(task, name)

                if obj_from_store is not None:
                    _LOG.debug(
                        "Retrieving InitOutputs for task=%s key=%s dsTypeName=%s", task, name, attribute.name
                    )
                    obj_from_store = self.butler.getDirect(init_output_ref)
                    # Types are supposed to be identical.
                    # TODO: Check that object contents are identical too.
                    if type(obj_from_store) is not type(init_output_var):
                        raise TypeError(
                            f"Stored initOutput object type {type(obj_from_store)} "
                            "is different from task-generated type "
                            f"{type(init_output_var)} for task {taskDef}"
                        )
                else:
                    _LOG.debug("Saving InitOutputs for task=%s key=%s", taskDef.label, name)
                    # This can still raise if there is a concurrent write.
                    self.butler.putDirect(init_output_var, init_output_ref)

    def saveConfigs(self, graph: QuantumGraph) -> None:
        """Write configurations for pipeline tasks to butler or check that
        existing configurations are equal to the new ones.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Execution graph.

        Raises
        ------
        TypeError
            Raised if an existing object in butler is different from the new
            data.
        Exception
            Raised if ``extendRun`` is `False` and datasets already exist.
            The contents of the butler collection should not be changed if an
            exception is raised.
        """

        def logConfigMismatch(msg: str) -> None:
            """Log messages about configuration mismatch."""
            _LOG.fatal("Comparing configuration: %s", msg)

        _LOG.debug("Will save Configs for all tasks")
        # Start a transaction to roll back any changes on exceptions.
        with self.transaction():
            for taskDef in graph.iterTaskGraph():
                config_name = taskDef.configDatasetName

                old_config, dataset_ref = self.find_init_output(taskDef, config_name, graph)

                if old_config is not None:
                    if not taskDef.config.compare(old_config, shortcut=False, output=logConfigMismatch):
                        raise TypeError(
                            f"Config does not match existing task config {config_name!r} in "
                            "butler; task configurations must be consistent within the same run "
                            "collection"
                        )
                else:
                    # Butler will raise an exception if the dataset is
                    # already there.
                    _LOG.debug("Saving Config for task=%s dataset type=%s", taskDef.label, config_name)
                    self.butler.putDirect(taskDef.config, dataset_ref)

    def savePackageVersions(self, graph: QuantumGraph) -> None:
        """Write versions of software packages to butler.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Execution graph.

        Raises
        ------
        TypeError
            Raised if an existing object in butler is incompatible with the
            new data.
        """
        packages = Packages.fromSystem()
        _LOG.debug("want to save packages: %s", packages)

        # Start a transaction to roll back any changes on exceptions.
        with self.transaction():
            old_packages, dataset_ref = self.find_packages(graph)

            if old_packages is not None:
                # Note that because we can only detect python modules that
                # have been imported, the stored list of products may be more
                # or less complete than what we have now. What's important is
                # that the products that are in common have the same version.
                _compare_packages(old_packages, packages)
                # Update the old set of packages in case we have more packages
                # that haven't been persisted.
                extra = packages.extra(old_packages)
                if extra:
                    _LOG.debug("extra packages: %s", extra)
                    old_packages.update(packages)
                    # Have to remove the existing dataset first; butler has no
                    # replace option.
                    self.butler.pruneDatasets([dataset_ref], unstore=True, purge=True)
                    self.butler.putDirect(old_packages, dataset_ref)
            else:
                self.butler.putDirect(packages, dataset_ref)
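
    # Sketch of the merge above, with hypothetical version data: the packages
    # in common must match exactly, and any extra package seen in the current
    # process is appended before the dataset is re-written:
    #
    #   stored:  {"numpy": "1.24.1"}
    #   current: {"numpy": "1.24.1", "astropy": "5.2"}
    #   saved:   {"numpy": "1.24.1", "astropy": "5.2"}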


    @abc.abstractmethod
    def find_init_input_refs(self, taskDef: TaskDef, graph: QuantumGraph) -> Iterable[DatasetRef]:
        """Return the list of resolved dataset references for task init
        inputs.

        Parameters
        ----------
        taskDef : `~lsst.pipe.base.TaskDef`
            Pipeline task definition.
        graph : `~lsst.pipe.base.QuantumGraph`
            Quantum graph.

        Returns
        -------
        refs : `~collections.abc.Iterable` [`~lsst.daf.butler.DatasetRef`]
            Resolved dataset references.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def find_init_output(
        self, taskDef: TaskDef, dataset_type: str, graph: QuantumGraph
    ) -> tuple[Any | None, DatasetRef]:
        """Find task init output for given dataset type.

        Parameters
        ----------
        taskDef : `~lsst.pipe.base.TaskDef`
            Pipeline task definition.
        dataset_type : `str`
            Dataset type name.
        graph : `~lsst.pipe.base.QuantumGraph`
            Quantum graph.

        Returns
        -------
        data
            Existing init output object retrieved from butler, `None` if
            butler has no existing object.
        ref : `~lsst.daf.butler.DatasetRef`
            Resolved reference for init output to be stored in butler.

        Raises
        ------
        MissingReferenceError
            Raised if reference cannot be found or generated.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def find_packages(self, graph: QuantumGraph) -> tuple[Packages | None, DatasetRef]:
        """Find packages information.

        Parameters
        ----------
        graph : `~lsst.pipe.base.QuantumGraph`
            Quantum graph.

        Returns
        -------
        packages : `lsst.utils.packages.Packages` or `None`
            Existing packages data retrieved from butler, or `None`.
        ref : `~lsst.daf.butler.DatasetRef`
            Resolved reference for packages to be stored in butler.

        Raises
        ------
        MissingReferenceError
            Raised if reference cannot be found or generated.
        """
        raise NotImplementedError()

    @contextmanager
    def transaction(self) -> Iterator[None]:
        """Context manager for transaction.

        Default implementation has no transaction support.
        """
        yield


class PreExecInit(PreExecInitBase):
    """Initialization of registry for QuantumGraph execution.

    This class encapsulates all necessary operations that have to be performed
    on butler and registry to prepare them for QuantumGraph execution.

    Parameters
    ----------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    taskFactory : `~lsst.pipe.base.TaskFactory`
        Task factory.
    extendRun : `bool`, optional
        If `True` then do not try to overwrite any datasets that might exist
        in ``butler.run``; instead compare them when appropriate/possible. If
        `False`, then any existing conflicting dataset will cause a butler
        exception to be raised.
    mock : `bool`, optional
        If `True` then also do initialization needed for pipeline mocking.
    """

    def __init__(
        self, butler: Butler, taskFactory: TaskFactory, extendRun: bool = False, mock: bool = False
    ):
        super().__init__(butler, taskFactory)
        self.full_butler = butler
        self.extendRun = extendRun
        self.mock = mock
        if self.extendRun and self.full_butler.run is None:
            raise RuntimeError(
                "Cannot perform extendRun logic unless butler is initialized "
                "with a default output RUN collection."
            )
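
    # Note (illustrative): ``extendRun=True`` requires a butler with a default
    # output RUN collection. The constructor arguments below are assumptions
    # for the sketch, not a prescription:
    #
    #   butler = Butler(repo, run="u/user/run1", writeable=True)
    #   PreExecInit(butler, task_factory, extendRun=True)   # OK
    #
    #   butler = Butler(repo, writeable=True)               # no default run
    #   PreExecInit(butler, task_factory, extendRun=True)   # RuntimeError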


    @contextmanager
    def transaction(self) -> Iterator[None]:
        # docstring inherited
        with self.full_butler.transaction():
            yield

    def initializeDatasetTypes(self, graph: QuantumGraph, registerDatasetTypes: bool = False) -> None:
        # docstring inherited
        pipeline = graph.taskGraph
        pipelineDatasetTypes = PipelineDatasetTypes.fromPipeline(
            pipeline, registry=self.full_butler.registry, include_configs=True, include_packages=True
        )

        for datasetTypes, is_input in (
            (pipelineDatasetTypes.initIntermediates, True),
            (pipelineDatasetTypes.initOutputs, False),
            (pipelineDatasetTypes.intermediates, True),
            (pipelineDatasetTypes.outputs, False),
        ):
            self._register_output_dataset_types(registerDatasetTypes, datasetTypes, is_input)

        if self.mock:
            # Register special mock dataset types, but skip logs and metadata.
            skipDatasetTypes = {taskDef.metadataDatasetName for taskDef in pipeline}
            skipDatasetTypes |= {taskDef.logOutputDatasetName for taskDef in pipeline}
            for datasetTypes, is_input in (
                (pipelineDatasetTypes.intermediates, True),
                (pipelineDatasetTypes.outputs, False),
            ):
                mockDatasetTypes = []
                for datasetType in datasetTypes:
                    if not (datasetType.name in skipDatasetTypes or datasetType.isComponent()):
                        mockDatasetTypes.append(
                            DatasetType(
                                MockButlerQuantumContext.mockDatasetTypeName(datasetType.name),
                                datasetType.dimensions,
                                "StructuredDataDict",
                            )
                        )
                if mockDatasetTypes:
                    self._register_output_dataset_types(registerDatasetTypes, mockDatasetTypes, is_input)

    def _register_output_dataset_types(
        self, registerDatasetTypes: bool, datasetTypes: Iterable[DatasetType], is_input: bool
    ) -> None:
        def _check_compatibility(datasetType: DatasetType, expected: DatasetType, is_input: bool) -> bool:
            # These are output dataset types so check for compatibility on
            # put.
            is_compatible = expected.is_compatible_with(datasetType)

            if is_input:
                # This dataset type is also used for input so must be
                # compatible on get as well.
                is_compatible = is_compatible and datasetType.is_compatible_with(expected)

            if is_compatible:
                _LOG.debug(
                    "The dataset type configurations differ (%s from task != %s from registry) "
                    "but the storage classes are compatible. Can continue.",
                    datasetType,
                    expected,
                )
            return is_compatible

        missing_datasetTypes = set()
        for datasetType in datasetTypes:
            # Only composites are registered, no components, and by this point
            # the composite should already exist.
            if registerDatasetTypes and not datasetType.isComponent():
                _LOG.debug("Registering DatasetType %s with registry", datasetType)
                # This is a no-op if it already exists and is consistent,
                # and it raises if it is inconsistent.
                try:
                    self.full_butler.registry.registerDatasetType(datasetType)
                except ConflictingDefinitionError:
                    if not _check_compatibility(
                        datasetType, self.full_butler.registry.getDatasetType(datasetType.name), is_input
                    ):
                        raise
            else:
                _LOG.debug("Checking DatasetType %s against registry", datasetType)
                try:
                    expected = self.full_butler.registry.getDatasetType(datasetType.name)
                except KeyError:
                    # Likely means that --register-dataset-types was
                    # forgotten.
                    missing_datasetTypes.add(datasetType.name)
                    continue
                if expected != datasetType:
                    if not _check_compatibility(datasetType, expected, is_input):
                        raise ValueError(
                            f"DatasetType configuration does not match Registry: {datasetType} != {expected}"
                        )

        if missing_datasetTypes:
            plural = "s" if len(missing_datasetTypes) != 1 else ""
            raise KeyError(
                f"Missing dataset type definition{plural}: {', '.join(missing_datasetTypes)}. "
                "Dataset types have to be registered with either `butler register-dataset-type` or "
                "passing `--register-dataset-types` option to `pipetask run`."
            )

    def find_init_input_refs(self, taskDef: TaskDef, graph: QuantumGraph) -> Iterable[DatasetRef]:
        # docstring inherited
        refs: list[DatasetRef] = []
        for name in taskDef.connections.initInputs:
            attribute = getattr(taskDef.connections, name)
            dataId = DataCoordinate.makeEmpty(self.full_butler.dimensions)
            dataset_type = DatasetType(attribute.name, graph.universe.empty, attribute.storageClass)
            ref = self.full_butler.registry.findDataset(dataset_type, dataId)
            if ref is None:
                raise ValueError(f"InitInput does not exist in butler for dataset type {dataset_type}")
            refs.append(ref)
        return refs

    def find_init_output(
        self, taskDef: TaskDef, dataset_type_name: str, graph: QuantumGraph
    ) -> tuple[Any | None, DatasetRef]:
        # docstring inherited
        dataset_type = self.full_butler.registry.getDatasetType(dataset_type_name)
        dataId = DataCoordinate.makeEmpty(self.full_butler.dimensions)
        return self._find_existing(dataset_type, dataId)

    def find_packages(self, graph: QuantumGraph) -> tuple[Packages | None, DatasetRef]:
        # docstring inherited
        dataset_type = self.full_butler.registry.getDatasetType(PipelineDatasetTypes.packagesDatasetName)
        dataId = DataCoordinate.makeEmpty(self.full_butler.dimensions)
        return self._find_existing(dataset_type, dataId)

    def _find_existing(
        self, dataset_type: DatasetType, dataId: DataCoordinate
    ) -> tuple[Any | None, DatasetRef]:
        """Make a reference of a given dataset type and try to retrieve it
        from butler. If not found then generate a new resolved reference.
        """
        run = self.full_butler.run
        assert run is not None

        ref = self.full_butler.registry.findDataset(dataset_type, dataId, collections=[run])
        if self.extendRun and ref is not None:
            try:
                config = self.butler.getDirect(ref)
                return config, ref
            except (LookupError, FileNotFoundError):
                return None, ref
        else:
            # Make a new resolved dataset ref.
            ref = DatasetRef(dataset_type, dataId)
            ref = DatasetIdFactory().resolveRef(ref, run)
            return None, ref


class PreExecInitLimited(PreExecInitBase):
    """Initialization of registry for QuantumGraph execution.

    This class works with LimitedButler and expects that all references in
    QuantumGraph are resolved.

    Parameters
    ----------
    butler : `~lsst.daf.butler.LimitedButler`
        Limited data butler instance.
    taskFactory : `~lsst.pipe.base.TaskFactory`
        Task factory.
    """

    def __init__(self, butler: LimitedButler, taskFactory: TaskFactory):
        super().__init__(butler, taskFactory)

    def initializeDatasetTypes(self, graph: QuantumGraph, registerDatasetTypes: bool = False) -> None:
        # docstring inherited
        # With LimitedButler we never create or check dataset types.
        pass

    def find_init_input_refs(self, taskDef: TaskDef, graph: QuantumGraph) -> Iterable[DatasetRef]:
        # docstring inherited
        return graph.initInputRefs(taskDef) or []

    def find_init_output(
        self, taskDef: TaskDef, dataset_type: str, graph: QuantumGraph
    ) -> tuple[Any | None, DatasetRef]:
        # docstring inherited
        return self._find_existing(graph.initOutputRefs(taskDef) or [], dataset_type)

    def find_packages(self, graph: QuantumGraph) -> tuple[Packages | None, DatasetRef]:
        # docstring inherited
        return self._find_existing(graph.globalInitOutputRefs(), PipelineDatasetTypes.packagesDatasetName)

    def _find_existing(self, refs: Iterable[DatasetRef], dataset_type: str) -> tuple[Any | None, DatasetRef]:
        """Find a reference of a given dataset type in the list of references
        and try to retrieve it from butler.
        """
        for ref in refs:
            if ref.datasetType.name == dataset_type:
                try:
                    data = self.butler.getDirect(ref)
                    return data, ref
                except (LookupError, FileNotFoundError):
                    return None, ref
        raise MissingReferenceError(f"Failed to find reference for dataset type {dataset_type}")
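

# Illustrative usage of PreExecInitLimited (names are assumptions; a
# LimitedButler, e.g. a QuantumBackedButler, is constructed elsewhere from a
# fully resolved quantum graph):
#
#   pre_exec_init = PreExecInitLimited(limited_butler, task_factory)
#   pre_exec_init.initialize(qgraph)  # dataset types are neither registered nor checked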