Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 27%

161 statements  

coverage.py v7.4.4, created at 2024-03-27 02:40 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Bunch of common classes and methods for use in unit tests.
"""
from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing.

    Parameters
    ----------
    *args : `~typing.Any`
        Ignored positional arguments.
    **kwargs : `~typing.Any`
        Unused keyword arguments.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input, two outputs,
    plus one init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
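

# Illustrative sketch (not part of the original module): how the connection
# name templates above expand.  With the class defaults, "add_dataset{in_tmpl}"
# becomes "add_dataset_in"; with the per-task overrides applied by
# makeSimplePipeline() below (in_tmpl=0 and out_tmpl=1 for "task0"), the same
# template becomes "add_dataset0", so each task's output feeds the next task's
# input.  The helper name is hypothetical.
def _example_connection_names() -> None:
    template = "add_dataset{in_tmpl}"
    assert template.format(in_tmpl="_in") == "add_dataset_in"  # class default
    assert template.format(in_tmpl=0) == "add_dataset0"  # makeSimplePipeline override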


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: AddTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        if self.taskFactory:
            # Do some bookkeeping.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.config = cast(AddTaskConfig, self.config)
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
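

# Illustrative sketch (not part of the original module): with the default
# addend of 3, AddTask.run() applies the addend twice, so an input of 0 yields
# output=3 and output2=6.  A chain of such tasks built by makeSimplePipeline()
# therefore produces predictable values at every stage.  The helper name is
# hypothetical.
def _example_add_task_arithmetic() -> None:
    addend = 3  # AddTaskConfig default
    input_value = 0
    output = input_value + addend
    output2 = output + addend
    assert (output, output2) == (3, 6)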


class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Number of times to call `run` before stopping.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(
        self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None
    ) -> PipelineTask:
        taskClass = taskDef.taskClass
        assert taskClass is not None
        task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label)
        task.taskFactory = self  # type: ignore
        return task
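

# Illustrative sketch (not part of the original module): a factory configured
# so that the third call to AddTask.run() raises, which a test can use to
# exercise error handling.  With stopAt=2, countExec is 0 and 1 on the first
# two calls and equals stopAt on the third.  The helper name is hypothetical.
def _example_failing_factory() -> AddTaskFactoryMock:
    return AddTaskFactoryMock(stopAt=2)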


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent, and it raises if it is inconsistent.  Components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
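

# Illustrative sketch (not part of the original module): registering the
# dataset types of an expanded pipeline before inserting any data.  It relies
# only on functions defined in this module; the helper name is hypothetical.
def _example_register_dataset_types(butler: Butler) -> None:
    pipeline = makeSimplePipeline(2)
    registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())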


def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()`: call this first and pass the result to that
    function.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
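

# Illustrative sketch (not part of the original module): building a three-task
# pipeline.  "task0" reads "add_dataset0" and writes "add_dataset1", "task1"
# reads "add_dataset1" and writes "add_dataset2", and so on, forming a simple
# linear chain.  Passing the test instrument is optional; the helper name is
# hypothetical.
def _example_make_simple_pipeline() -> Pipeline:
    instrument_name = get_full_type_name(SimpleInstrument)
    return makeSimplePipeline(3, instrument=instrument_name)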


def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root not {root_path}")
    butler_config = Config()
    if config:
        butler_config.update(Config(config))
    if not inMemory:
        butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config)
    butler = Butler.from_config(butler=repo, run=run)
    return butler
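

# Illustrative sketch (not part of the original module): creating a throwaway
# repository under a temporary directory.  The temporary-directory handling is
# an assumption made for the example; tests normally manage their own roots.
# The helper name is hypothetical.
def _example_make_simple_butler() -> Butler:
    import tempfile

    root = tempfile.mkdtemp()  # hypothetical scratch location
    return makeSimpleButler(root, run="test", inMemory=True)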


def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:

    - registers dataset types which are defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline is missing an instrument definition.
    The type of each dataset is guessed from the dataset type name (this
    assumes that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    """
    # Add dataset types to registry.
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"
        instrumentClass = None

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs}
    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from a matching task name or make a new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
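

# Illustrative sketch (not part of the original module): pre-loading inputs
# into a named RUN collection in addition to the default one.  The collection
# name "extra_inputs" and the per-task dataset type names (following the
# "<label>_metadata"/"<label>_log" pattern handled above) are assumptions made
# for the example; the helper name is hypothetical.
def _example_populate_butler(pipeline: Pipeline, butler: Butler) -> None:
    populateButler(
        pipeline,
        butler,
        datasetTypes={
            None: ["add_dataset0"],  # default input dataset in butler.run
            "extra_inputs": ["add_dataset1", "task0_metadata", "task0_log"],
        },
    )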


def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make simple `QuantumGraph` for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a `QuantumGraph`
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created by
        calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a graph.
        If `False`, the ``butler`` argument must not be `None` and must be
        pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if ``butler``
        is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default). Only used if ``pipeline`` is `None`.
    skipExistingIn : `~typing.Any`
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type "add_dataset0"
        is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(
        registry=butler.registry,
        skipExistingIn=skipExistingIn,
        datastore=butler._datastore if makeDatastoreRecords else None,
    )
    if not run:
        assert butler.run is not None, "Butler must have run defined"
        run = butler.run
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s",
        butler.collections,
        run,
        userQuery,
        bind,
    )
    if not metadata:
        metadata = {}
    metadata["output_run"] = run

    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
        bind=bind,
        metadata=metadata,
    )

    return butler, qgraph
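

# Illustrative sketch (not part of the original module): the typical
# end-to-end call a unit test would make.  The temporary-directory handling is
# an assumption made for the example, and the helper name is hypothetical.
def _example_make_simple_qgraph() -> tuple[Butler, QuantumGraph]:
    import tempfile

    root = tempfile.mkdtemp()  # hypothetical scratch repository location
    butler, qgraph = makeSimpleQGraph(nQuanta=3, root=root)
    # The resulting graph should contain one quantum per AddTask in the chain.
    return butler, qgraph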