Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 30%

152 statements  

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Bunch of common classes and methods for use in unit tests."""

from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
import numpy

try:
    from lsst.base import Packages
except ImportError:
    Packages = None
from lsst.daf.butler import Butler, Config, DatasetType
from lsst.daf.butler.core.logging import ButlerLogRecords
from lsst.resources import ResourcePath
from lsst.utils import doImportType

from .. import connectionTypes as cT
from ..config import PipelineTaskConfig
from ..connections import PipelineTaskConnections
from ..graph import QuantumGraph
from ..graphBuilder import DatasetQueryConstraintVariant as DSQVariant
from ..graphBuilder import GraphBuilder
from ..pipeline import Pipeline, TaskDatasetTypes, TaskDef
from ..pipelineTask import PipelineTask
from ..struct import Struct
from ..task import _TASK_FULL_METADATA_TYPE
from ..taskFactory import TaskFactory

if TYPE_CHECKING:
    from lsst.daf.butler import Registry

    from ..configOverrides import ConfigOverrides

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# can not explicitly depend on Instrument because pipe_base does not
# explicitly depend on obs_base.
class SimpleInstrument:
    def __init__(self, *args: Any, **kwargs: Any):
        pass

    @staticmethod
    def getName() -> str:
        return "INSTRU"

    def applyConfigOverrides(self, name: str, config: pexConfig.Config) -> None:
        pass

class AddTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for AddTask: one input, two outputs, and one init
    output.
    """

    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )
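# With the default templates above, the connection names expand to concrete
# dataset type names, e.g. "add_dataset{in_tmpl}" becomes "add_dataset_in".
# makeSimplePipeline below overrides the templates per task; a hypothetical
# sketch of doing the same on a config instance:
#
#     config = AddTaskConfig()
#     config.connections.in_tmpl = "0"   # input reads "add_dataset0"
#     config.connections.out_tmpl = "1"  # outputs write "add_dataset1", "add2_dataset1"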

class AddTaskConfig(PipelineTaskConfig, pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: Optional[AddTaskFactoryMock] = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:  # type: ignore
        if self.taskFactory:
            # Do some bookkeeping.
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)
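# A minimal sketch of exercising AddTask directly (the input value is
# arbitrary; with the default addend of 3, output = input + 3 and
# output2 = input + 6):
#
#     task = AddTask(config=AddTaskConfig(), name="task0")
#     result = task.run(input=numpy.array([1]))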

class AddTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises an exception at this call to run()

    def makeTask(
        self,
        taskClass: Type[PipelineTask],
        name: Optional[str],
        config: Optional[PipelineTaskConfig],
        overrides: Optional[ConfigOverrides],
        butler: Optional[Butler],
    ) -> PipelineTask:
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self  # type: ignore
        return task
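# Usage sketch: a factory whose tasks fail on a chosen run() call, handy for
# testing error handling (the stopAt value is arbitrary):
#
#     taskFactory = AddTaskFactoryMock(stopAt=2)
#     task = taskFactory.makeTask(AddTask, "task0", None, None, None)
#     # The first two task.run() calls succeed; the third raises RuntimeError
#     # because countExec is compared to stopAt before being incremented.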

def registerDatasetTypes(registry: Registry, pipeline: Union[Pipeline, Iterable[TaskDef]]) -> None:
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        ``toExpandedPipeline`` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(
            taskDef.configDatasetName, {}, storageClass="Config", universe=registry.dimensions
        )
        storageClass = "Packages" if Packages is not None else "StructuredDataDict"
        packagesDatasetType = DatasetType(
            "packages", {}, storageClass=storageClass, universe=registry.dimensions
        )
        datasetTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(
            datasetTypes.initInputs,
            datasetTypes.initOutputs,
            datasetTypes.inputs,
            datasetTypes.outputs,
            datasetTypes.prerequisites,
            [configDatasetType, packagesDatasetType],
        ):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if it already exists and is consistent, and it
            # raises if it is inconsistent. But components must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
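# Usage sketch, mirroring how populateButler below calls this (assumes a
# butler and pipeline already exist in the test):
#
#     registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())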

def makeSimplePipeline(nQuanta: int, instrument: Optional[str] = None) -> Pipeline:
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline,
        or `None` (the default) or an empty string if no instrument should
        be added.

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = Pipeline("test pipeline")
    # Make a bunch of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
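# For example, makeSimplePipeline(3) yields tasks task0, task1, task2 chained
# through the connection templates: task0 reads "add_dataset0" and writes
# "add_dataset1", which task1 then reads, and so on:
#
#     pipeline = makeSimplePipeline(3)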

def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create a new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root_path = ResourcePath(root, forceDirectory=True)
    if not root_path.isLocal:
        raise ValueError(f"Only works with local root, not {root_path}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root_path.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(str(root_path), {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler
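# Usage sketch (the repository path is hypothetical; with inMemory=False the
# registry is instead backed by a SQLite file under the root):
#
#     butler = makeSimpleButler("/tmp/test-repo", run="test", inMemory=True)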

def populateButler(
    pipeline: Pipeline, butler: Butler, datasetTypes: Optional[Dict[Optional[str], List[str]]] = None
) -> None:
    """Populate data butler with data needed for test.

    Initializes data butler with a number of items:

    - registers dataset types which are defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to the butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline is missing an instrument definition.
    The type of each dataset is guessed from the dataset type name (this
    assumes that the pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    """
    # Add dataset types to registry.
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
        instrumentName = instrument_class.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry.
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler.
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent.
                # Dict is convertible to Packages if Packages is installed.
                data: Any = {"python": "9.9.99"}
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = _TASK_FULL_METADATA_TYPE()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0.0, 1.0, 2.0, 5.0])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
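# Usage sketch putting the default input dataset into ``butler.run`` plus an
# extra input into a named RUN collection (names here are arbitrary):
#
#     populateButler(
#         pipeline,
#         butler,
#         datasetTypes={None: ["add_dataset0"], "extra_run": ["add_dataset1"]},
#     )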

def makeSimpleQGraph(
    nQuanta: int = 5,
    pipeline: Optional[Pipeline] = None,
    butler: Optional[Butler] = None,
    root: Optional[str] = None,
    callPopulateButler: bool = True,
    run: str = "test",
    skipExistingIn: Any = None,
    inMemory: bool = True,
    userQuery: str = "",
    datasetTypes: Optional[Dict[Optional[str], List[str]]] = None,
    datasetQueryConstraint: DSQVariant = DSQVariant.ALL,
) -> Tuple[Butler, QuantumGraph]:
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline of AddTask tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    callPopulateButler : `bool`, optional
        If `True` (the default), insert datasets into the butler prior to
        building a graph. If `False`, the ``butler`` argument must not be
        `None` and must be pre-populated.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    run : `str`, optional
        Name of the RUN collection to add to butler; only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository; only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``; by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to the ``butler.run`` collection.
    datasetQueryConstraint : `DatasetQueryConstraintVariant`
        The query constraint variant that should be used to constrain the
        query based on dataset existence; defaults to
        `DatasetQueryConstraintVariant.ALL`.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """
    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        if callPopulateButler is False:
            raise ValueError("callPopulateButler can only be False when butler is supplied as an argument")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    if callPopulateButler:
        populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph.
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug(
        "Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
        butler.collections,
        run or butler.run,
        userQuery,
    )
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery,
        datasetQueryConstraint=datasetQueryConstraint,
    )

    return butler, qgraph
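
# End-to-end usage sketch (the repository path is hypothetical; the resulting
# graph should contain one quantum per AddTask in the chain):
#
#     butler, qgraph = makeSimpleQGraph(nQuanta=3, root="/tmp/test-repo")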