# python/lsst/pipe/base/tests/simpleQGraph.py

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

28"""Bunch of common classes and methods for use in unit tests. 

29""" 

from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from collections.abc import Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, cast

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy
from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler
from lsst.daf.butler.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from .. import connectionTypes as cT
from .._instrument import Instrument
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

_LOG = logging.getLogger(__name__)


class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing."""

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input, two outputs, and one
    init output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Second output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
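
    A minimal sketch of its arithmetic (with the default ``addend=3``)::

        task = AddTask(config=AddTaskConfig())
        result = task.run(input=1)  # result.output == 4, result.output2 == 7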

127 """ 

128 

129 ConfigClass = AddTaskConfig 

130 _DefaultName = "add_task" 

131 

132 initout = numpy.array([999]) 

133 """InitOutputs for this task""" 

134 

135 taskFactory: AddTaskFactoryMock | None = None 

136 """Factory that makes instances""" 

137 

138 def run(self, input: int) -> Struct: 

139 if self.taskFactory: 

140 # do some bookkeeping 

141 if self.taskFactory.stopAt == self.taskFactory.countExec: 

142 raise RuntimeError("pretend something bad happened") 

143 self.taskFactory.countExec += 1 

144 

145 self.config = cast(AddTaskConfig, self.config) 

146 self.metadata.add("add", self.config.addend) 

147 output = input + self.config.addend 

148 output2 = output + self.config.addend 

149 _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2) 

150 return Struct(output=output, output2=output2) 

151 

152 

class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
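
    A minimal usage sketch (``stopAt=3`` is an illustrative value)::

        factory = AddTaskFactoryMock(stopAt=3)
        # Pass ``factory`` wherever a TaskFactory is expected; the fourth
        # call to AddTask.run() (when countExec == 3) raises RuntimeError.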

158 """ 

159 

160 def __init__(self, stopAt: int = -1): 

161 self.countExec = 0 # incremented by AddTask 

162 self.stopAt = stopAt # AddTask raises exception at this call to run() 

163 

164 def makeTask( 

165 self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None 

166 ) -> PipelineTask: 

167 taskClass = taskDef.taskClass 

168 assert taskClass is not None 

169 task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label) 

170 task.taskFactory = self # type: ignore 

171 return task 

172 

173 

def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `~lsst.pipe.base.Pipeline` or `~collections.abc.Iterable` of `TaskDef`
        Pipeline or iterable of TaskDef instances, likely the output of the
        method `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline`
        object.
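
    Examples
    --------
    A minimal sketch (assumes an existing ``butler``)::

        pipeline = makeSimplePipeline(nQuanta=2)
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())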

186 """ 

187 for taskDef in pipeline: 

188 configDatasetType = DatasetType( 

189 taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions 

190 ) 

191 storageClass = "Packages" 

192 packagesDatasetType = DatasetType( 

193 "packages", {}, storageClass=storageClass, universe=registry.dimensions 

194 ) 

195 datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry) 

196 for datasetType in itertools.chain( 

197 datasetTypes.initInputs, 

198 datasetTypes.initOutputs, 

199 datasetTypes.inputs, 

200 datasetTypes.outputs, 

201 datasetTypes.prerequisites, 

202 [configDatasetType, packagesDatasetType], 

203 ): 

204 _LOG.info("Registering %s with registry", datasetType) 

205 # this is a no-op if it already exists and is consistent, 

206 # and it raises if it is inconsistent. But components must be 

207 # skipped 

208 if not datasetType.isComponent(): 

209 registry.registerDatasetType(datasetType) 

210 

211 

def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline,
        or, if no instrument should be added, an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
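
    Examples
    --------
    A minimal sketch (``tmpdir`` here is assumed to be a writable local
    directory)::

        pipeline = makeSimplePipeline(nQuanta=3)
        butler, qgraph = makeSimpleQGraph(pipeline=pipeline, root=str(tmpdir))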

233 """ 

234 pipeline = Pipeline("test pipeline") 

235 # make a bunch of tasks that execute in well defined order (via data 

236 # dependencies) 

237 for lvl in range(nQuanta): 

238 pipeline.addTask(AddTask, f"task{lvl}") 

239 pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl) 

240 pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1) 

241 if instrument: 

242 pipeline.addInstrument(instrument) 

243 return pipeline 

244 

245 

def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
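
    Examples
    --------
    A minimal sketch (``tmpdir`` here is assumed to be a writable local
    directory)::

        butler = makeSimpleButler(str(tmpdir), run="test", inMemory=True)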

268 """ 

269 root_path = ResourcePath(root, forceDirectory=True) 

270 if not root_path.isLocal: 

271 raise ValueError(f"Only works with local root not {root_path}") 

272 butler_config = Config() 

273 if config: 

274 butler_config.update(Config(config)) 

275 if not inMemory: 

276 butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite" 

277 butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore" 

278 repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config) 

279 butler = Butler(butler=repo, run=run) 

280 return butler 

281 

282 

def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes the data butler with several items:

    - registers dataset types defined by the pipeline;
    - creates dimension data for ``(instrument, detector)``;
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing, a single dataset with type "add_dataset0" is
      added.

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline
    ("INSTR" is used if the pipeline has no instrument definition). The type
    of each dataset is guessed from its dataset type name (this assumes the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names (`None` means the default
        ``butler.run`` collection) and values are lists of dataset type
        names. By default a single dataset of type "add_dataset0" is added
        to the ``butler.run`` collection.
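
    Examples
    --------
    A minimal sketch (``tmpdir`` here is assumed to be a writable local
    directory)::

        pipeline = makeSimplePipeline(nQuanta=3)
        butler = makeSimpleButler(str(tmpdir), run="test")
        populateButler(pipeline, butler)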

310 """ 

311 # Add dataset types to registry 

312 taskDefs = list(pipeline.toExpandedPipeline()) 

313 registerDatasetTypes(butler.registry, taskDefs) 

314 

315 instrument = pipeline.getInstrument() 

316 if instrument is not None: 

317 instrument_class = doImportType(instrument) 

318 instrumentName = cast(Instrument, instrument_class).getName() 

319 instrumentClass = get_full_type_name(instrument_class) 

320 else: 

321 instrumentName = "INSTR" 

322 instrumentClass = None 

323 

324 # Add all needed dimensions to registry 

325 butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass)) 

326 butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0")) 

327 

328 taskDefMap = {taskDef.label: taskDef for taskDef in taskDefs} 

329 # Add inputs to butler 

330 if not datasetTypes: 

331 datasetTypes = {None: ["add_dataset0"]} 

332 for run, dsTypes in datasetTypes.items(): 

333 if run is not None: 

334 butler.registry.registerRun(run) 

335 for dsType in dsTypes: 

336 if dsType == "packages": 

337 # Version is intentionally inconsistent. 

338 # Dict is convertible to Packages if Packages is installed. 

339 data: Any = {"python": "9.9.99"} 

340 butler.put(data, dsType, run=run) 

341 else: 

342 if dsType.endswith("_config"): 

343 # find a config from matching task name or make a new one 

344 taskLabel, _, _ = dsType.rpartition("_") 

345 taskDef = taskDefMap.get(taskLabel) 

346 if taskDef is not None: 

347 data = taskDef.config 

348 else: 

349 data = AddTaskConfig() 

350 elif dsType.endswith("_metadata"): 

351 data = _TASK_FULL_METADATA_TYPE() 

352 elif dsType.endswith("_log"): 

353 data = ButlerLogRecords.from_records([]) 

354 else: 

355 data = numpy.array([0.0, 1.0, 2.0, 5.0]) 

356 butler.put(data, dsType, run=run, instrument=instrumentName, detector=0) 

357 

358 

def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make simple `QuantumGraph` for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    `QuantumGraph` from all of that.

    Parameters
    ----------
    nQuanta : `int`, optional
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    root : `str`, optional
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    callPopulateButler : `bool`, optional
        If `True` insert datasets into the butler prior to building a
        graph. If `False` the ``butler`` argument must not be `None`, and
        must be pre-populated. Defaults to `True`.
    run : `str`, optional
        Name of the RUN collection to add to butler, only used if
        ``butler`` is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline,
        or, if no instrument should be added, an empty string or `None`
        (the default). Only used if ``pipeline`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True` make an in-memory repository, only used if ``butler`` is
        `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`, optional
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
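
    Examples
    --------
    A minimal sketch (``tmpdir`` here is assumed to be a writable local
    directory)::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root=str(tmpdir))
        # The default pipeline yields one quantum per task, so the graph
        # should contain three quanta here.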

436 """ 

437 if pipeline is None: 

438 pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument) 

439 

440 if butler is None: 

441 if root is None: 

442 raise ValueError("Must provide `root` when `butler` is None") 

443 if callPopulateButler is False: 

444 raise ValueError("populateButler can only be False when butler is supplied as an argument") 

445 butler = makeSimpleButler(root, run=run, inMemory=inMemory) 

446 

447 if callPopulateButler: 

448 populateButler(pipeline, butler, datasetTypes=datasetTypes) 

449 

450 # Make the graph 

451 _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn) 

452 builder = GraphBuilder( 

453 registry=butler.registry, 

454 skipExistingIn=skipExistingIn, 

455 datastore=butler._datastore if makeDatastoreRecords else None, 

456 ) 

457 if not run: 

458 assert butler.run is not None, "Butler must have run defined" 

459 run = butler.run 

460 _LOG.debug( 

461 "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s", 

462 butler.collections, 

463 run, 

464 userQuery, 

465 bind, 

466 ) 

467 if not metadata: 

468 metadata = {} 

469 metadata["output_run"] = run 

470 

471 qgraph = builder.makeGraph( 

472 pipeline, 

473 collections=butler.collections, 

474 run=run, 

475 userQuery=userQuery, 

476 datasetQueryConstraint=datasetQueryConstraint, 

477 bind=bind, 

478 metadata=metadata, 

479 ) 

480 

481 return butler, qgraph