Coverage for python/lsst/pipe/base/tests/simpleQGraph.py: 26%


132 statements  

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Bunch of common classes and methods for use in unit tests.
"""

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]

import itertools
import logging
import numpy

from lsst.base import Packages
from lsst.daf.base import PropertySet
from lsst.daf.butler import Butler, ButlerURI, Config, DatasetType
import lsst.daf.butler.tests as butlerTests
from lsst.daf.butler.core.logging import ButlerLogRecords
import lsst.pex.config as pexConfig
from lsst.utils import doImport
from ... import base as pipeBase
from .. import connectionTypes as cT

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# cannot explicitly depend on Instrument because pipe_base does not explicitly
# depend on obs_base.
class SimpleInstrument:

    def __init__(self, *args, **kwargs):
        pass

    @staticmethod
    def getName():
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        pass


class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask: one input and two outputs,
    plus one init output.
    """
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")


class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.
    """
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing; it has some extras useful for
    specific unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):

        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
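

def _demoAddTaskRun():
    # Illustrative usage sketch, not part of the original module. The
    # constructor call mirrors the one in AddTaskFactoryMock.makeTask below;
    # with the default addend of 3, run() adds 3 and then 6 to the input.
    task = AddTask(config=AddTaskConfig(), initInputs=None, name="add_task")
    result = task.run(numpy.array([1, 2, 3]))
    assert (result.output == numpy.array([4, 5, 6])).all()
    assert (result.output2 == numpy.array([7, 8, 9])).all()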


class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """
    def __init__(self, stopAt=-1):
        self.countExec = 0  # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def makeTask(self, taskClass, name, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
        if overrides:
            overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None, name=name)
        task.taskFactory = self
        return task
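

def _demoTaskFactory():
    # Illustrative usage sketch, not part of the original module: a factory
    # configured to fail on the third call to run(). The unused makeTask
    # arguments (config, overrides, butler) are passed as None here.
    factory = AddTaskFactoryMock(stopAt=2)
    task = factory.makeTask(AddTask, "task0", None, None, None)
    task.run(numpy.array([0.0]))  # countExec: 0 -> 1
    task.run(numpy.array([0.0]))  # countExec: 1 -> 2
    # A third run() call would raise RuntimeError (countExec == stopAt).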


def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # Registration is a no-op if the dataset type already exists and
            # is consistent, and it raises if it is inconsistent. Components
            # must be skipped, however.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)
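

def _demoRegisterDatasetTypes():
    # Illustrative usage sketch, not part of the original module; it follows
    # the same pattern populateButler() uses below. "/tmp/demo_repo" is a
    # placeholder path.
    butler = makeSimpleButler("/tmp/demo_repo", inMemory=True)
    pipeline = makeSimplePipeline(2)
    registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())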


def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Make a bunch of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", lvl)
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", lvl + 1)
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
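

def _demoSimplePipeline():
    # Illustrative usage sketch, not part of the original module: a
    # three-task pipeline whose tasks chain through add_dataset0 ->
    # add_dataset1 -> ... -> add_dataset3, with SimpleInstrument attached
    # via its importable name.
    pipeline = makeSimplePipeline(
        3, instrument="lsst.pipe.base.tests.simpleQGraph.SimpleInstrument")
    return pipeline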


def makeSimpleButler(root: str, run: str = "test", inMemory: bool = True) -> Butler:
    """Create new data butler instance.

    Parameters
    ----------
    root : `str`
        Path or URI to the root location of the new repository.
    run : `str`, optional
        Run collection name.
    inMemory : `bool`, optional
        If `True`, make an in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    """
    root = ButlerURI(root, forceDirectory=True)
    if not root.isLocal:
        raise ValueError(f"Only works with local root, not {root}")
    config = Config()
    if not inMemory:
        config["registry", "db"] = f"sqlite:///{root.ospath}/gen3.sqlite"
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
    repo = butlerTests.makeTestRepo(root, {}, config=config)
    butler = Butler(butler=repo, run=run)
    return butler
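

def _demoSimpleButler():
    # Illustrative usage sketch, not part of the original module: a
    # throwaway repository rooted in a temporary directory.
    import tempfile
    with tempfile.TemporaryDirectory() as root:
        butler = makeSimpleButler(root, run="demo", inMemory=True)
        assert butler.run == "demo"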


def populateButler(pipeline, butler, datasetTypes=None):
    """Populate data butler with data needed for test.

    Initializes data butler with a bunch of items:
    - registers dataset types which are defined by the pipeline
    - creates dimension data for (instrument, detector)
    - adds datasets based on the ``datasetTypes`` dictionary; if the
      dictionary is missing then a single dataset with type "add_dataset0"
      is added

    All datasets added to butler have ``dataId={instrument=instrument,
    detector=0}``, where ``instrument`` is extracted from the pipeline;
    "INSTR" is used if the pipeline has no instrument definition. The type
    of each dataset is guessed from its dataset type name (this assumes the
    pipeline is made of `AddTask` tasks).

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline instance.
    butler : `~lsst.daf.butler.Butler`
        Data butler instance.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.
    """

    # Add dataset types to registry
    taskDefs = list(pipeline.toExpandedPipeline())
    registerDatasetTypes(butler.registry, taskDefs)

    instrument = pipeline.getInstrument()
    if instrument is not None:
        if isinstance(instrument, str):
            instrument = doImport(instrument)
        instrumentName = instrument.getName()
    else:
        instrumentName = "INSTR"

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
    butler.registry.insertDimensionData("detector", dict(instrument=instrumentName,
                                                         id=0, full_name="det0"))

    taskDefMap = dict((taskDef.label, taskDef) for taskDef in taskDefs)
    # Add inputs to butler
    if not datasetTypes:
        datasetTypes = {None: ["add_dataset0"]}
    for run, dsTypes in datasetTypes.items():
        if run is not None:
            butler.registry.registerRun(run)
        for dsType in dsTypes:
            if dsType == "packages":
                # Version is intentionally inconsistent
                data = Packages({"python": "9.9.99"})
                butler.put(data, dsType, run=run)
            else:
                if dsType.endswith("_config"):
                    # Find a config from the matching task name or make a
                    # new one.
                    taskLabel, _, _ = dsType.rpartition("_")
                    taskDef = taskDefMap.get(taskLabel)
                    if taskDef is not None:
                        data = taskDef.config
                    else:
                        data = AddTaskConfig()
                elif dsType.endswith("_metadata"):
                    data = PropertySet()
                elif dsType.endswith("_log"):
                    data = ButlerLogRecords.from_records([])
                else:
                    data = numpy.array([0., 1., 2., 5.])
                butler.put(data, dsType, run=run, instrument=instrumentName, detector=0)
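

def _demoPopulateButler():
    # Illustrative usage sketch, not part of the original module: the default
    # single "add_dataset0" input plus a config dataset in an extra run
    # collection; "extra_run" is a hypothetical collection name.
    import tempfile
    with tempfile.TemporaryDirectory() as root:
        butler = makeSimpleButler(root)
        pipeline = makeSimplePipeline(2)
        populateButler(pipeline, butler,
                       datasetTypes={None: ["add_dataset0"],
                                     "extra_run": ["task0_config"]})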


def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, run="test",
                     skipExistingIn=None, inMemory=True, userQuery="",
                     datasetTypes=None):
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of AddTask tasks, sets up an in-memory registry
    and butler, fills them with minimal data, and generates a QuantumGraph
    from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then a pipeline is made with `AddTask` and default
        `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; if `None` then a new data butler is created
        by calling `makeSimpleButler`.
    root : `str`
        Path or URI to the root location of the new repository. Only used
        if ``butler`` is `None`.
    run : `str`, optional
        Name of the RUN collection to add to butler; only used if
        ``butler`` is `None`.
    skipExistingIn
        Expressions representing the collections to search for existing
        output datasets that should be skipped. May be any of the types
        accepted by `lsst.daf.butler.CollectionSearch.fromExpression`.
    inMemory : `bool`, optional
        If `True` make an in-memory repository; only used if ``butler``
        is `None`.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``; by default an empty
        string.
    datasetTypes : `dict` [ `str`, `list` ], optional
        Dictionary whose keys are collection names and values are lists of
        dataset type names. By default a single dataset of type
        "add_dataset0" is added to a ``butler.run`` collection.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:
        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")
        butler = makeSimpleButler(root, run=run, inMemory=inMemory)

    populateButler(pipeline, butler, datasetTypes=datasetTypes)

    # Make the graph
    _LOG.debug("Instantiating GraphBuilder, skipExistingIn=%s", skipExistingIn)
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExistingIn=skipExistingIn)
    _LOG.debug("Calling GraphBuilder.makeGraph, collections=%r, run=%r, userQuery=%r",
               butler.collections, run or butler.run, userQuery)
    qgraph = builder.makeGraph(
        pipeline,
        collections=butler.collections,
        run=run or butler.run,
        userQuery=userQuery
    )

    return butler, qgraph
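

def _demoSimpleQGraph():
    # Illustrative usage sketch, not part of the original module: build a
    # small graph in a temporary repository. One quantum per task is
    # expected for the default chained pipeline.
    import tempfile
    with tempfile.TemporaryDirectory() as root:
        butler, qgraph = makeSimpleQGraph(nQuanta=3, root=root)
        assert len(qgraph) == 3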