Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 23%

188 statements  

coverage.py v7.4.4, created at 2024-04-19 11:28 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

28"""Bunch of common classes and methods for use in unit tests. 

29""" 

30from __future__ import annotations 

31 

32__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

33 

34import logging 

35import warnings 

36from collections.abc import Iterable, Mapping, MutableMapping 

37from typing import TYPE_CHECKING, Any, cast 

38 

39import lsst.daf.butler.tests as butlerTests 

40import lsst.pex.config as pexConfig 

41import numpy 

42from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler 

43from lsst.daf.butler.logging import ButlerLogRecords 

44from lsst.resources import ResourcePath 

45from lsst.utils import doImportType 

46from lsst.utils.introspection import find_outside_stacklevel, get_full_type_name 

47 

48from .. import connectionTypes as cT 

49from .._instrument import Instrument 

50from ..all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder 

51from ..all_dimensions_quantum_graph_builder import DatasetQueryConstraintVariant as DSQVariant 

52from ..automatic_connection_constants import PACKAGES_INIT_OUTPUT_NAME, PACKAGES_INIT_OUTPUT_STORAGE_CLASS 

53from ..config import PipelineTaskConfig 

54from ..connections import PipelineTaskConnections 

55from ..graph import QuantumGraph 

56from ..pipeline import Pipeline, TaskDef 

57from ..pipeline_graph import PipelineGraph, TaskNode 

58from ..pipelineTask import PipelineTask 

59from ..struct import Struct 

60from ..task import _TASK_FULL_METADATA_TYPE 

61from ..taskFactory import TaskFactory 

62 

63if TYPE_CHECKING: 

64 from lsst.daf.butler import Registry 

65 

66_LOG = logging.getLogger(__name__) 

67 

68 

class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing.

    Parameters
    ----------
    *args : `~typing.Any`
        Ignored positional arguments.
    **kwargs : `~typing.Any`
        Unused keyword arguments.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass
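
# A minimal sketch of how SimpleInstrument is referenced by the helpers below:
# pipelines and populateButler take the *importable* name of an instrument
# class (resolved later via doImportType), not an instance.  The
# get_full_type_name helper used here is the same one populateButler uses.
def _simple_instrument_name_sketch() -> str:
    """Return a fully-qualified instrument name suitable for addInstrument."""
    # e.g. "lsst.pipe.base.tests.simpleQGraph.SimpleInstrument"
    return get_full_type_name(SimpleInstrument)
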

class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Second output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init output dataset type for this task",
    )
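
# A short sketch of how the {in_tmpl}/{out_tmpl} templates above resolve.
# With the default templates, the input connection name expands to
# "add_dataset_in"; makeSimplePipeline below overrides the templates per task
# so that task N reads "add_datasetN" and writes "add_datasetN+1".  The
# expansion is effectively plain str.format over the template values stored
# in the task config (an illustration, not the connections machinery itself).
def _template_expansion_sketch() -> str:
    """Illustrate the dataset type name a connection template expands to."""
    return "add_dataset{in_tmpl}".format(in_tmpl="_in")  # -> "add_dataset_in"
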

class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)
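
# A minimal sketch of overriding the addend field: pexConfig.Field values are
# plain attribute assignments on a config instance.
def _add_task_config_sketch() -> AddTaskConfig:
    """Return an AddTaskConfig with a non-default addend."""
    config = AddTaskConfig()
    config.addend = 5  # validated as an `int` by pexConfig
    return config
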

class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: AddTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        if self.taskFactory:
            # Do some bookkeeping.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
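
# A minimal usage sketch of AddTask.run outside any butler or graph
# machinery, assuming the task can be constructed directly from a config
# (it has a _DefaultName, so no explicit name is needed).  With the default
# addend of 3, output is input + 3 and output2 is input + 6.
def _add_task_run_sketch() -> Struct:
    """Run AddTask directly on an integer input."""
    task = AddTask(config=AddTaskConfig())
    result = task.run(1)
    assert result.output == 4 and result.output2 == 7
    return result
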

class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        The execution count at which `AddTask.run` raises an exception; the
        default of ``-1`` means it never raises.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # Incremented by AddTask on each successful run.
        self.stopAt = stopAt  # AddTask raises when countExec reaches this.

    def makeTask(
        self,
        task_node: TaskDef | TaskNode,
        /,
        butler: LimitedButler,
        initInputRefs: Iterable[DatasetRef] | None,
    ) -> PipelineTask:
        if isinstance(task_node, TaskDef):
            # TODO: remove support on DM-40443.
            warnings.warn(
                "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.",
                FutureWarning,
                find_outside_stacklevel("lsst.pipe.base"),
            )
            task_class = task_node.taskClass
            assert task_class is not None
        else:
            task_class = task_node.task_class
        task = task_class(config=task_node.config, initInputs=None, name=task_node.label)
        task.taskFactory = self  # type: ignore
        return task
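
# A minimal sketch of the stopAt/countExec bookkeeping: with stopAt=2, the
# first two run() calls succeed (countExec becomes 1, then 2) and the third
# raises, because AddTask.run compares stopAt to countExec *before*
# incrementing.  Attaching the factory by hand mirrors what makeTask does.
def _factory_stop_sketch() -> None:
    """Demonstrate that the factory's stopAt limit aborts the third run."""
    factory = AddTaskFactoryMock(stopAt=2)
    task = AddTask(config=AddTaskConfig())
    task.taskFactory = factory
    task.run(1)
    task.run(1)
    try:
        task.run(1)
    except RuntimeError:
        pass  # Expected: countExec reached stopAt.
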

def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `.Pipeline`, `~collections.abc.Iterable` of `.TaskDef`, or \
            `.pipeline_graph.PipelineGraph`
        The pipeline whose dataset types should be registered, as a `.Pipeline`
        instance, `.PipelineGraph` instance, or iterable of `.TaskDef`
        instances. Support for `.TaskDef` is deprecated and will be removed
        after v27.
    """
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
        case _:
            warnings.warn(
                "Passing TaskDefs is deprecated and will not be supported after v27.",
                category=FutureWarning,
                stacklevel=find_outside_stacklevel("lsst.pipe.base"),
            )
            pipeline_graph = PipelineGraph()
            for task_def in pipeline:
                pipeline_graph.add_task(
                    task_def.label, task_def.taskClass, task_def.config, connections=task_def.connections
                )
    pipeline_graph.resolve(registry)
    dataset_types = [node.dataset_type for node in pipeline_graph.dataset_types.values()]
    dataset_types.append(
        DatasetType(
            PACKAGES_INIT_OUTPUT_NAME,
            {},
            storageClass=PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
            universe=registry.dimensions,
        )
    )
    for dataset_type in dataset_types:
        _LOG.info("Registering %s with registry", dataset_type)
        registry.registerDatasetType(dataset_type)
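
# A minimal usage sketch: register the dataset types of a small test pipeline
# with a butler's registry.  makeSimplePipeline is defined later in this
# module; registerDatasetTypes resolves the pipeline graph against the
# registry itself, so nothing else needs to be prepared first.
def _register_dataset_types_sketch(butler: Butler) -> None:
    """Register dataset types for a three-task test pipeline."""
    pipeline = makeSimplePipeline(3)
    registerDatasetTypes(butler.registry, pipeline.to_graph())
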

def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline that
    `makeSimpleQGraph()` uses: call this first and pass the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline, or
        an empty string or `None` (the default) if no instrument should be
        added.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
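
# A minimal sketch of the dataset-name chaining makeSimplePipeline sets up:
# each task's out_tmpl override becomes the next task's in_tmpl, so the tasks
# form a linear chain add_dataset0 -> task0 -> add_dataset1 -> task1 -> ...
def _simple_pipeline_sketch() -> Pipeline:
    """Build a three-task chain; task2 writes add_dataset3 and add2_dataset3."""
    return makeSimplePipeline(3, instrument=get_full_type_name(SimpleInstrument))
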

def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with a local root, not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler.from_config(butler=repo, run=run)
    return butler
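
# A minimal usage sketch: create a throwaway repository rooted in a fresh
# temporary directory.  Cleaning up the directory is the caller's
# responsibility; real tests usually use a managed temporary-directory
# helper rather than tempfile.mkdtemp directly.
def _simple_butler_sketch() -> Butler:
    """Create an in-memory test butler rooted in a fresh temp directory."""
    import tempfile

    root = tempfile.mkdtemp()
    return makeSimpleButler(root, run="test", inMemory=True)
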

def populateButler(
    pipeline: Pipeline | PipelineGraph,
    butler: Butler,
    datasetTypes: dict[str | None, list[str]] | None = None,
    instrument: str | None = None,
) -> None:
    """Populate data butler with data needed for test.

    Initializes the data butler with a number of items:

    - registers the dataset types defined by the pipeline;
    - creates dimension data for (instrument, detector);
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added.

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline
    ("INSTR" is used if the pipeline has no instrument definition). The type
    of each dataset is guessed from the dataset type name (this assumes that
    the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `.Pipeline` or `.pipeline_graph.PipelineGraph`
        Pipeline or pipeline graph instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    instrument : `str`, optional
        Fully-qualified name of the instrument class (as it appears in a
        pipeline). This is needed to propagate the instrument from the
        original pipeline if a pipeline graph is passed instead.
    """
    # Add dataset types to registry.
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
            if instrument is None:
                instrument = pipeline.getInstrument()
        case _:
            raise TypeError(f"Unexpected pipeline object: {pipeline!r}.")

    registerDatasetTypes(butler.registry, pipeline_graph)

    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a new
                    # one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    task_node = pipeline_graph.tasks.get(taskLabel)
                    if task_node is not None:
                        data = task_node.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
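
# A minimal usage sketch combining the helpers above: register everything a
# three-task pipeline needs and insert the single default "add_dataset0"
# input into the butler's default run collection.
def _populate_butler_sketch(butler: Butler) -> None:
    """Prepare a butler for a three-task test pipeline."""
    pipeline = makeSimplePipeline(3)
    populateButler(pipeline, butler)  # adds add_dataset0 to butler.run
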

def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | PipelineGraph | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make a simple `QuantumGraph` for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `.Pipeline` or `.pipeline_graph.PipelineGraph`, optional
        Pipeline or pipeline graph to build the graph from. If `None`, a
        pipeline is made with `AddTask` and default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`, optional
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler; only used if ``butler``
        is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline, or
        an empty string or `None` (the default) if no instrument should be
        added. Only used if ``pipeline`` is `None`.
    skipExistingIn : `~typing.Any`
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    match pipeline:
        case PipelineGraph() as pipeline_graph:
            pass
        case Pipeline():
            pipeline_graph = pipeline.to_graph()
            if instrument is None:
                instrument = pipeline.getInstrument()
        case None:
            pipeline_graph = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument).to_graph()
        case _:
            raise TypeError(f"Unexpected pipeline object: {pipeline!r}.")

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline_graph, butler, datasetTypes=datasetTypes, instrument=instrument)

    # Make the graph.
    _LOG.debug(
        "Instantiating QuantumGraphBuilder, "
        "skip_existing_in=%s, input_collections=%r, output_run=%r, where=%r, bind=%s.",
        skipExistingIn,
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    builder = AllDimensionsQuantumGraphBuilder(
        pipeline_graph,
        butler,
        skip_existing_in=skipExistingIn if skipExistingIn is not None else [],
        input_collections=butler.collections if butler.collections is not None else [run],
        output_run=run,
        where=userQuery,
        bind=bind,
        dataset_query_constraint=datasetQueryConstraint,
    )
    _LOG.debug("Calling QuantumGraphBuilder.build.")
    if not metadata:
        metadata = {}
    metadata["output_run"] = run

    qgraph = builder.build(metadata=metadata, attach_datastore_records=makeDatastoreRecords)

    return butler, qgraph
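
# A minimal end-to-end sketch: build a five-quantum graph in a fresh
# temporary repository with a single call.  Cleaning up the directory is the
# caller's responsibility, and the length check assumes the linear chain
# built by makeSimplePipeline yields one quantum per task.
def _simple_qgraph_sketch() -> tuple[Butler, QuantumGraph]:
    """Create a butler and a five-task quantum graph in one call."""
    import tempfile

    root = tempfile.mkdtemp()
    butler, qgraph = makeSimpleQGraph(nQuanta=5, root=root)
    assert len(qgraph) == 5  # expected: one quantum per task in the chain
    return butler, qgraph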