# python/lsst/pipe/base/tests/simpleQGraph.py

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24from __future__ import annotations 

25 

26__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

27 

28import itertools 

29import logging 

30from collections.abc import Iterable, Mapping, MutableMapping 

31from typing import TYPE_CHECKING, Any, cast 

32 

33import lsst.daf.butler.tests as butlerTests 

34import lsst.pex.config as pexConfig 

35import numpy 

36from lsst.daf.butler import Butler, Config, DataId, DatasetRef, DatasetType, Formatter, LimitedButler 

37from lsst.daf.butler.core.logging import ButlerLogRecords 

38from lsst.resources import ResourcePath 

39from lsst.utils import doImportType 

40from lsst.utils.introspection import get_full_type_name 

41 

42from .. import connectionTypes as cT 

43from .._instrument import Instrument 

44from ..config import PipelineTaskConfig 

45from ..connections import PipelineTaskConnections 

46from ..graph import QuantumGraph 

47from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant 

48from ..graphBuilder import GraphBuilder 

49from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef 

50from ..pipelineTask import PipelineTask 

51from ..struct import Struct 

52from ..task import _TASK_FULL_METADATA_TYPE 

53from ..taskFactory import TaskFactory 

54 

55if TYPE_CHECKING: 

56 from lsst.daf.butler import Registry 

57 

58_LOG = logging.getLogger(__name__) 

59 

60 

class SimpleInstrument(Instrument):
    """Simple instrument class suitable for testing."""

    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def getRawFormatter(self, dataId: DataId) -> type[Formatter]:
        return Formatter

    def register(self, registry: Registry, *, update: bool = False) -> None:
        pass


class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask; it has one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing; it includes some extras useful for
    specific unit tests.
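
    Examples
    --------
    A minimal sketch of the arithmetic performed by `run`, assuming the
    default ``addend`` of 3 (this is illustrative, not part of the normal
    test flow)::

        task = AddTask(config=AddTaskConfig())
        result = task.run(1)
        # result.output == 4 and result.output2 == 7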

121 """ 

122 

123 ConfigClass = AddTaskConfig 

124 _DefaultName = "add_task" 

125 

126 initout = numpy.array([999]) 

127 """InitOutputs for this task""" 

128 

129 taskFactory: AddTaskFactoryMock | None = None 

130 """Factory that makes instances""" 

131 

132 def run(self, input: int) -> Struct: # type: ignore 

133 if self.taskFactory: 

134 # do some bookkeeping 

135 if self.taskFactory.stopAt == self.taskFactory.countExec: 

136 raise RuntimeError("pretend something bad happened") 

137 self.taskFactory.countExec += 1 

138 

139 self.config = cast(AddTaskConfig, self.config) 

140 self.metadata.add("add", self.config.addend) 

141 output = input + self.config.addend 

142 output2 = output + self.config.addend 

143 _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2) 

144 return Struct(output=output, output2=output2) 

145 

146 

class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
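
    Examples
    --------
    A minimal sketch of the bookkeeping contract, derived from the counter
    logic in `AddTask.run`::

        factory = AddTaskFactoryMock(stopAt=2)
        # AddTask.run() raises RuntimeError on the call where
        # factory.countExec == 2, i.e. on the third execution.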

152 """ 

153 

154 def __init__(self, stopAt: int = -1): 

155 self.countExec = 0 # incremented by AddTask 

156 self.stopAt = stopAt # AddTask raises exception at this call to run() 

157 

158 def makeTask( 

159 self, taskDef: TaskDef, butler: LimitedButler, initInputRefs: Iterable[DatasetRef] | None 

160 ) -> PipelineTask: 

161 taskClass = taskDef.taskClass 

162 assert taskClass is not None 

163 task = taskClass(config=taskDef.config, initInputs=None, name=taskDef.label) 

164 task.taskFactory = self # type: ignore 

165 return task 

166 

167 

def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `Pipeline.toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
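
    Examples
    --------
    A hedged sketch; ``butler`` and ``pipeline`` are assumed to already
    exist (e.g. from `makeSimpleButler` and `makeSimplePipeline`)::

        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())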

180 """ 

181 for taskDef in pipeline: 

182 configDatasetType = DatasetType( 

183 taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions 

184 ) 

185 storageClass = "Packages" 

186 packagesDatasetType = DatasetType( 

187 "packages", {}, storageClass=storageClass, universe=registry.dimensions 

188 ) 

189 datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry) 

190 for datasetType in itertools.chain( 

191 datasetTypes.initInputs, 

192 datasetTypes.initOutputs, 

193 datasetTypes.inputs, 

194 datasetTypes.outputs, 

195 datasetTypes.prerequisites, 

196 [configDatasetType, packagesDatasetType], 

197 ): 

198 _LOG.info("Registering %s with registry", datasetType) 

199 # this is a no-op if it already exists and is consistent, 

200 # and it raises if it is inconsistent. But components must be 

201 # skipped 

202 if not datasetType.isComponent(): 

203 registry.registerDatasetType(datasetType) 

204 

205 

def makeSimplePipeline(nQuanta: int, instrument: str | None = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by `makeSimpleQGraph()` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    `makeSimpleQGraph()` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
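
    Examples
    --------
    A minimal sketch of customizing the pipeline before graph generation
    (the repository path is illustrative)::

        pipeline = makeSimplePipeline(3)
        butler, qgraph = makeSimpleQGraph(pipeline=pipeline, root="/tmp/repo")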

227 """ 

228 pipeline = Pipeline("test pipeline") 

229 # make a bunch of tasks that execute in well defined order (via data 

230 # dependencies) 

231 for lvl in range(nQuanta): 

232 pipeline.addTask(AddTask, f"task{lvl}") 

233 pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl) 

234 pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1) 

235 if instrument: 

236 pipeline.addInstrument(instrument) 

237 return pipeline 

238 

239 

def makeSimpleButler(
    root: str, run: str = "test", inMemory: bool = True, config: Config | str | None = None
) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.
    config : `~lsst.daf.butler.Config`, optional
        Configuration to use for the new Butler; if `None` then the default
        configuration is used. If ``inMemory`` is `False` then the
        configuration is updated to use a SQLite registry and a file
        datastore in ``root``.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
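
    Examples
    --------
    A minimal sketch (the repository path is illustrative)::

        butler = makeSimpleButler("/tmp/test-repo", run="test")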

262 """ 

263 root_path = ResourcePath(root, forceDirectory=True) 

264 if not root_path.isLocal: 

265 raise ValueError(f"Only works with local root not {root_path}") 

266 butler_config = Config() 

267 if config: 

268 butler_config.update(Config(config)) 

269 if not inMemory: 

270 butler_config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite" 

271 butler_config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore" 

272 repo = butlerTests.makeTestRepo(str(root_path), {}, config=butler_config) 

273 butler = Butler(butler=repo, run=run) 

274 return butler 

275 

276 

def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: dict[str | None, list[str]] | None = None
) -> None:
    """Populate a data butler with the data needed for a test.

    Initializes the data butler with several items:

    - registers the dataset types defined by the pipeline;
    - creates dimension records for (instrument, detector);
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added.

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline is missing an instrument definition.
    The type of each dataset is guessed from the dataset type name (this
    assumes that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names (`None` means the
        ``butler.run`` collection) and values are lists of dataset type
        names. By default a single dataset of type "add_dataset0" is added
        to the ``butler.run`` collection.
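
    Examples
    --------
    A hedged sketch; ``butler`` is assumed to come from `makeSimpleButler`::

        pipeline = makeSimplePipeline(2)
        populateButler(pipeline, butler)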

304 """ 

305 # Add dataset types to registry 

306 taskDefs = list(pipeline.toExpandedPipeline()) 

307 registerDatasetTypes(butler.registry, taskDefs) 

308 

309 instrument = pipeline.getInstrument() 

310 if instrument is not None: 

311 instrument_class = doImportType(instrument) 

312 instrumentName = cast(Instrument, instrument_class).getName() 

313 instrumentClass = get_full_type_name(instrument_class) 

314 else: 

315 instrumentName = "INSTR" 

316 instrumentClass = None 

317 

318 # Add all needed dimensions to registry 

319 butler.registry.insertDimensionData("instrument", dict(name=instrumentName, class_name=instrumentClass)) 

320 butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0")) 

321 

322 taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs) 

323 # Add inputs to butler 

324 if not datasetTypes: 

325 datasetTypes = {None: ["add_dataset0"]} 

326 for run, dsTypes in datasetTypes.items(): 

327 if run is not None: 

328 butler.registry.registerRun(run) 

329 for dsType in dsTypes: 

330 if dsType == "packages": 

331 # Version is intentionally inconsistent. 

332 # Dict is convertible to Packages if Packages is installed. 

333 data: Any = {"python": "9.9.99"} 

334 butler.put(data, dsType, run=run) 

335 else: 

336 if dsType.endswith("_config"): 

337 # find a config from matching task name or make a new one 

338 taskLabel, _, _ = dsType.rpartition("_") 

339 taskDef = taskDefMap.get(taskLabel) 

340 if taskDef is not None: 

341 data = taskDef.config 

342 else: 

343 data = AddTaskConfig() 

344 elif dsType.endswith("_metadata"): 

345 data = _TASK_FULL_METADATA_TYPE() 

346 elif dsType.endswith("_log"): 

347 data = ButlerLogRecords.from_records([]) 

348 else: 

349 data = numpy.array([0.0, 1.0, 2.0, 5.0]) 

350 butler.put(data, dsType, run=run, instrument=instrumentName, detector=0) 

351 

352 

def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Pipeline | None = None,
    butler: Butler | None = None,
    root: str | None = None,
    callPopulateButler: bool = True,
    run: str = "test",
    instrument: str | None = None,
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: dict[str | None, list[str]] | None = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
    makeDatastoreRecords: bool = False,
    bind: Mapping[str, Any] | None = None,
    metadata: MutableMapping[str, Any] | None = None,
) -> tuple[Butler, QuantumGraph]:
    """Make a simple `QuantumGraph` for tests.

    Makes a simple pipeline of `AddTask` tasks, sets up an in-memory
    registry and butler, fills them with minimal data, and generates a
    `QuantumGraph` from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph, only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    callPopulateButler : `bool`, optional
        If `True`, insert datasets into the butler prior to building a
        graph. If `False`, the ``butler`` argument must not be `None` and
        must be pre-populated. Defaults to `True`.
    root : `str`, optional
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    run : `str`, optional
        Name of the RUN collection to add to the butler, only used if
        ``butler`` is `None`.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default). Only used if ``pipeline`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. See
        :ref:`daf_butler_ordered_collection_searches`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.
    datasetTypes : `dict` [ `str` or `None`, `list` [ `str` ] ], optional
        Dictionary whose keys are collection names (`None` means the
        ``butler.run`` collection) and values are lists of dataset type
        names. By default a single dataset of type "add_dataset0" is added
        to the ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`, optional
        The query constraint variant that should be used to constrain the
        query based on dataset existence, defaults to
        `DatasetQueryConstraintVariant.ALL`.
    makeDatastoreRecords : `bool`, optional
        If `True` then add datastore records to the generated quanta.
    bind : `~collections.abc.Mapping`, optional
        Mapping containing literal values that should be injected into the
        ``userQuery`` expression, keyed by the identifiers they replace.
    metadata : `~collections.abc.Mapping`, optional
        Optional graph metadata.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
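
    Examples
    --------
    A minimal sketch (the repository path is illustrative)::

        butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/repo")
        # qgraph now contains three quanta, one per AddTask instance.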

430 """ 

431 if pipeline is None: 

432 pipeline = makeSimplePipeline(nQuanta=nQuanta, instrument=instrument) 

433 

434 if butler is None: 

435 if root is None: 

436 raise ValueError("Must provide `root` when `butler` is None") 

437 if callPopulateButler is False: 

438 raise ValueError("populateButler can only be False when butler is supplied as an argument") 

439 butler = makeSimpleButler(root, run=run, inMemory=inMemory) 

440 

441 if callPopulateButler: 

442 populateButler(pipeline, butler, datasetTypes=datasetTypes) 

443 

444 # Make the graph 

445 _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn) 

446 builder = GraphBuilder( 

447 registry=butler.registry, 

448 skipExistingIn=skipExistingIn, 

449 datastore=butler.datastore if makeDatastoreRecords else None, 

450 ) 

451 if not run: 

452 assert butler.run is not None, "Butler must have run defined" 

453 run = butler.run 

454 _LOG.debug( 

455 "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r bind=%s", 

456 butler.collections, 

457 run, 

458 userQuery, 

459 bind, 

460 ) 

461 if not metadata: 

462 metadata = {} 

463 metadata["output_run"] = run 

464 

465 qgraph = builder.makeGraph( 

466 pipeline, 

467 collections=butler.collections, 

468 run=run, 

469 userQuery=userQuery, 

470 datasetQueryConstraint=datasetQueryConstraint, 

471 bind=bind, 

472 metadata=metadata, 

473 ) 

474 

475 return butler, qgraph