# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

import itertools
import logging
import numpy

from lsst.daf.butler import Butler, Config, DatasetType
import lsst.daf.butler.tests as butlerTests
import lsst.pex.config as pexConfig
from ... import base as pipeBase
from .. import connectionTypes as cT

_LOG = logging.getLogger(__name__)


# SimpleInstrument has an instrument-like API as needed for unit testing, but
# cannot explicitly depend on Instrument because pipe_base does not explicitly
# depend on obs_base.
class SimpleInstrument:

    @staticmethod
    def getName():
        return "SimpleInstrument"

    def applyConfigOverrides(self, name, config):
        pass


class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask: one input, two outputs, and one init output.
    """
    input = cT.Input(name="add_dataset{in_tmpl}",
                     dimensions=["instrument", "detector"],
                     storageClass="NumpyArray",
                     doc="Input dataset type for this task")
    output = cT.Output(name="add_dataset{out_tmpl}",
                       dimensions=["instrument", "detector"],
                       storageClass="NumpyArray",
                       doc="Output dataset type for this task")
    output2 = cT.Output(name="add2_dataset{out_tmpl}",
                        dimensions=["instrument", "detector"],
                        storageClass="NumpyArray",
                        doc="Output dataset type for this task")
    initout = cT.InitOutput(name="add_init_output{out_tmpl}",
                            storageClass="NumpyArray",
                            doc="Init Output dataset type for this task")
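
# Note on the name templates above: with the default templates the input
# dataset type name resolves to "add_dataset_in" and the outputs to
# "add_dataset_out" / "add2_dataset_out".  makeSimplePipeline below overrides
# "in_tmpl" / "out_tmpl" with consecutive integers, so task0 reads
# "add_dataset0" (the dataset that makeSimpleQGraph puts into the butler) and
# writes "add_dataset1", chaining the tasks together via data dependencies.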

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask.
    """
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, with some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory = None
    """Factory that makes instances"""

    def run(self, input):

        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec += 1

        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        output2 = output + self.config.addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)
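
# A minimal illustrative sketch (not part of the original test helpers) of how
# AddTask behaves when driven directly; the function name and values below are
# examples only.
def _exampleAddTaskUsage():
    config = AddTaskConfig()
    config.addend = 10
    task = AddTask(config=config)
    # taskFactory is None here, so the bookkeeping branch in run() is skipped;
    # output = input + 10 and output2 = input + 20.
    result = task.run(numpy.array([1, 2, 3]))
    return result.output, result.output2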

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.
    """
    def __init__(self, stopAt=-1):
        self.countExec = 0    # incremented by AddTask
        self.stopAt = stopAt  # AddTask raises exception at this call to run()

    def loadTaskClass(self, taskName):
        if taskName == "AddTask":
            return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None)
        task.taskFactory = self
        return task
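
# Illustrative sketch (assumed usage, not part of the original module): the
# factory's bookkeeping makes AddTask.run() raise once countExec reaches
# stopAt, which a unit test can use to simulate a failing quantum.
def _exampleFactoryUsage():
    factory = AddTaskFactoryMock(stopAt=2)
    task = factory.makeTask(AddTask, None, None, None)
    task.run(numpy.array([0.0]))      # countExec: 0 -> 1
    task.run(numpy.array([0.0]))      # countExec: 1 -> 2
    try:
        task.run(numpy.array([0.0]))  # stopAt == countExec, so this raises
    except RuntimeError:
        pass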

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        `toExpandedPipeline` on a `~lsst.pipe.base.Pipeline` object.
    """
    for taskDef in pipeline:
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        packagesDatasetType = DatasetType("packages", {},
                                          storageClass="Packages",
                                          universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # This is a no-op if the dataset type already exists and is
            # consistent, and it raises if it is inconsistent.  Components
            # must be skipped.
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)


def makeSimplePipeline(nQuanta, instrument=None):
    """Make a simple Pipeline for tests.

    This is called by ``makeSimpleQGraph`` if no pipeline is passed to that
    function.  It can also be used to customize the pipeline used by
    ``makeSimpleQGraph`` by calling this first and passing the result to it.

    Parameters
    ----------
    nQuanta : `int`
        The number of quanta to add to the pipeline.
    instrument : `str` or `None`, optional
        The importable name of an instrument to be added to the pipeline.
        If no instrument should be added, pass an empty string or `None`
        (the default).

    Returns
    -------
    pipeline : `~lsst.pipe.base.Pipeline`
        The created pipeline object.
    """
    pipeline = pipeBase.Pipeline("test pipeline")
    # Make a bunch of tasks that execute in a well-defined order (via data
    # dependencies).
    for lvl in range(nQuanta):
        pipeline.addTask(AddTask, f"task{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", f"{lvl}")
        pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", f"{lvl+1}")
    if instrument:
        pipeline.addInstrument(instrument)
    return pipeline
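
# Illustrative sketch (assumed usage): the returned pipeline can be customized
# before handing it to makeSimpleQGraph.  The "task2" label and addend value
# below are examples only, and ``root`` is a caller-supplied repository path.
def _exampleCustomizedPipeline(root):
    pipeline = makeSimplePipeline(nQuanta=3)
    # Task labels are "task0", "task1", ...; bump the addend of one of them.
    pipeline.addConfigOverride("task2", "addend", 100)
    return makeSimpleQGraph(pipeline=pipeline, root=root)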

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True,
                     userQuery=""):
    """Make a simple QuantumGraph for tests.

    Makes a simple pipeline with AddTask, sets up an in-memory registry and
    butler, fills them with minimal data, and generates a QuantumGraph from
    all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`, optional
        If `None` then a one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance; this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository.  Only used if
        ``butler`` is None.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already exist.
        Default is `False`.
    inMemory : `bool`, optional
        If `True` (default), make an in-memory repository.
    userQuery : `str`, optional
        The user query to pass to ``makeGraph``, by default an empty string.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance.
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance.
    """

    if pipeline is None:
        pipeline = makeSimplePipeline(nQuanta=nQuanta)

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

        # Add dataset types to registry
        registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

        # Add all needed dimensions to registry
        butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
        butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))

        # Add inputs to butler
        data = numpy.array([0., 1., 2., 5.])
        butler.put(data, "add_dataset0", instrument="INSTR", detector=0)

    # Make the graph
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=[butler.run],
        run=butler.run,
        userQuery=userQuery
    )

    return butler, qgraph
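
# Illustrative sketch (assumed usage in a unit test): build a small repository
# and graph in a caller-supplied temporary directory; ``tmpdir`` is a
# hypothetical path provided by the test, not something this module defines.
def _exampleMakeSimpleQGraph(tmpdir):
    # Creates the repo under tmpdir, registers the dataset types, inserts the
    # "add_dataset0" input, and builds the graph for a 3-task pipeline.
    butler, qgraph = makeSimpleQGraph(nQuanta=3, root=tmpdir)
    return butler, qgraph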