Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

26 

27import itertools 

28import logging 

29import numpy 

30 

31from lsst.daf.butler import Butler, Config, DatasetType 

32import lsst.daf.butler.tests as butlerTests 

33import lsst.pex.config as pexConfig 

34from lsst.utils import doImport 

35from ... import base as pipeBase 

36from .. import connectionTypes as cT 

37 

38_LOG = logging.getLogger(__name__) 

39 

40 

41# SimpleInstrument has an instrument-like API as needed for unit testing, but 

42# can not explicitly depend on Instrument because pipe_base does not explicitly 

43# depend on obs_base. 

class SimpleInstrument:
    """Minimal instrument-like stand-in for unit tests.

    Provides just enough of an instrument-style API for testing; pipe_base
    cannot depend on the real Instrument class because it does not depend
    on obs_base.
    """

    @staticmethod
    def getName():
        """Return the fixed instrument name used in tests."""
        return "INSTRU"

    def applyConfigOverrides(self, name, config):
        """Do nothing; real instruments would apply config overrides here."""
        pass

52 

53 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask: one regular input, two regular outputs,
    plus one init output.

    Dataset type names are parameterized by the ``in_tmpl`` and ``out_tmpl``
    templates so that tests can chain multiple AddTask instances together
    (see ``makeSimplePipeline``).
    """
    # Array read by AddTask.run().
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    # First output: input plus the configured addend.
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    # Second output: input plus twice the configured addend.
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    # Init output written once at task construction (AddTask.initout).
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")

75 

76 

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.
    """
    # Value added to the input array by AddTask.run().
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)

82 

83 

class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):
        """Add the configured addend to ``input`` once and twice.

        Returns a Struct with ``output`` (input + addend) and ``output2``
        (input + 2 * addend).
        """
        factory = self.taskFactory
        if factory:
            # Bookkeeping for unit tests: fail on a predetermined
            # invocation, otherwise count successful executions.
            if factory.countExec == factory.stopAt:
                raise RuntimeError("pretend something bad happened")
            factory.countExec += 1

        addend = self.config.addend
        self.metadata.add("add", addend)
        output = input + addend
        output2 = output + addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)

111 

112 

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Execution count at which AddTask.run raises; the default of -1
        never matches, so nothing is raised.
    """
    def __init__(self, stopAt=-1):
        # Incremented by AddTask on each successful run().
        self.countExec = 0
        # AddTask raises an exception at this call to run().
        self.stopAt = stopAt

    def loadTaskClass(self, taskName):
        """Return the task class and name for the one known mock task."""
        if taskName != "AddTask":
            return None
        return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        """Construct a task, building and overriding config if not given."""
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None)
        # Link back so AddTask.run can update our bookkeeping counters.
        task.taskFactory = self
        return task

135 

136 

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object
    """
    # The "packages" dataset type does not depend on the task, so build it
    # once instead of rebuilding it on every loop iteration.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset type; its name depends on the task label.
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # registerDatasetType is a no-op if the type already exists and
            # is consistent, and raises if it is inconsistent.  Components
            # cannot be registered directly and must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)

168 

169 

def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function. It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` function by calling this first and passing the result
    to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline or
        if no instrument should be added then an empty string or `None`, by
        default None

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Chain the tasks via data dependencies so they execute in a well
    # defined order: task N reads dataset N and writes dataset N+1.
    for lvl in range(nQuanta):
        label = f"task{lvl}"
        pipeline.addTask(AddTask, label)
        pipeline.addConfigOverride(label, "connections.in_tmpl", f"{lvl}")
        pipeline.addConfigOverride(label, "connections.out_tmpl", f"{lvl+1}")
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline

202 

203 

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make simple QuantumGraph for tests.

    Makes a simple pipeline of chained AddTask instances (unless one is
    supplied), sets up an in-memory registry and butler, fills them with
    minimal data, and generates a QuantumGraph from all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph; only used if ``pipeline`` is `None`.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then an ``nQuanta``-task pipeline is made with `AddTask`
        and default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None, in which case it must be provided.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist; `False` by default.
    inMemory : `bool`, optional
        If true make in-memory repository.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance

    Raises
    ------
    ValueError
        Raised if both ``butler`` and ``root`` are `None`.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            # Use an on-disk SQLite registry and POSIX datastore under root.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

        # Resolve the instrument name: import it if given as a string,
        # otherwise fall back to a fixed placeholder name.
        instrument = pipeline.getInstrument()
        if instrument is not None:
            if isinstance(instrument, str):
                instrument = doImport(instrument)
            instrumentName = instrument.getName()
        else:
            instrumentName = "INSTR"

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name=instrumentName))
        butler.registry.insertDimensionData("detector", dict(instrument=instrumentName, id=0,
                                                             full_name="det0"))

        # Add inputs to butler; "add_dataset0" is the input to the first
        # task made by makeSimplePipeline.
        data = numpy.array([0., 1., 2., 5.])
        butler.put(data, "add_dataset0", instrument=instrumentName, detector=0)

    # Make the graph
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=[butler.run],
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph

286 return butler, qgraph