Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Bunch of common classes and methods for use in unit tests. 

23""" 

24 

25__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"] 

26 

27import itertools 

28import logging 

29import numpy 

30 

31from lsst.daf.butler import (Butler, Config, DatasetType, CollectionSearch) 

32import lsst.daf.butler.tests as butlerTests 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35from lsst.pipe.base import connectionTypes as cT 

36 

37_LOG = logging.getLogger(__name__) 

38 

39 

class AddTaskConnections(pipeBase.PipelineTaskConnections,
                         dimensions=("instrument", "detector"),
                         defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}):
    """Connections for AddTask: one regular input, two regular outputs,
    plus one init output.

    Dataset type names are parameterized by the ``in_tmpl``/``out_tmpl``
    templates so chains of AddTask instances can be wired together.
    """
    input = cT.Input(
        doc="Input dataset type for this task",
        name="add_dataset{in_tmpl}",
        storageClass="NumpyArray",
        dimensions=["instrument", "detector"],
    )
    output = cT.Output(
        doc="Output dataset type for this task",
        name="add_dataset{out_tmpl}",
        storageClass="NumpyArray",
        dimensions=["instrument", "detector"],
    )
    output2 = cT.Output(
        doc="Output dataset type for this task",
        name="add2_dataset{out_tmpl}",
        storageClass="NumpyArray",
        dimensions=["instrument", "detector"],
    )
    initout = cT.InitOutput(
        doc="Init Output dataset type for this task",
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
    )

61 

62 

class AddTaskConfig(pipeBase.PipelineTaskConfig,
                    pipelineConnections=AddTaskConnections):
    """Config for AddTask."""

    # Constant that AddTask.run adds to its input array (twice for output2).
    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")

68 

69 

class AddTask(pipeBase.PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = AddTaskConfig
    _DefaultName = "add_task"

    # InitOutputs for this task.
    initout = numpy.array([999])

    # Factory that makes instances; set by AddTaskFactoryMock.makeTask so
    # the task can report execution progress back to unit tests.
    taskFactory = None

    def run(self, input):
        # ``input`` shadows the builtin deliberately: the name must match
        # the connection defined in AddTaskConnections.
        factory = self.taskFactory
        if factory:
            # Bookkeeping for unit tests: fail on the configured call, and
            # count successful executions.
            if factory.stopAt == factory.countExec:
                raise RuntimeError("pretend something bad happened")
            factory.countExec += 1

        addend = self.config.addend
        self.metadata.add("add", addend)
        output = input + addend
        output2 = output + addend
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return pipeBase.Struct(output=output, output2=output2)

97 

98 

class AddTaskFactoryMock(pipeBase.TaskFactory):
    """Special task factory that instantiates AddTask.

    It also defines some bookkeeping variables used by AddTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Execution count at which AddTask.run raises; -1 means never.
    """

    def __init__(self, stopAt=-1):
        # Incremented by AddTask after each successful run().
        self.countExec = 0
        # AddTask raises an exception at this call to run().
        self.stopAt = stopAt

    def loadTaskClass(self, taskName):
        # Only "AddTask" is known; any other name implicitly yields None.
        if taskName != "AddTask":
            return None
        return AddTask, "AddTask"

    def makeTask(self, taskClass, config, overrides, butler):
        if config is None:
            config = taskClass.ConfigClass()
            # NOTE(review): overrides applied only to a freshly created
            # default config, matching upstream ctrl_mpexec — the scraped
            # source's indentation is ambiguous here.
            if overrides:
                overrides.applyTo(config)
        task = taskClass(config=config, initInputs=None)
        task.taskFactory = self
        return task

121 

122 

def registerDatasetTypes(registry, pipeline):
    """Register all dataset types used by tasks in a registry.

    Copied and modified from `PreExecInit.initializeDatasetTypes`.

    Parameters
    ----------
    registry : `~lsst.daf.butler.Registry`
        Registry instance.
    pipeline : `typing.Iterable` of `TaskDef`
        Iterable of TaskDef instances, likely the output of the method
        toExpandedPipeline on a `~lsst.pipe.base.Pipeline` object
    """
    # The "packages" dataset type does not depend on the task, so build it
    # once instead of rebuilding it on every loop iteration.
    packagesDatasetType = DatasetType("packages", {},
                                      storageClass="Packages",
                                      universe=registry.dimensions)
    for taskDef in pipeline:
        # Per-task config dataset type plus everything the task connections
        # declare.
        configDatasetType = DatasetType(taskDef.configDatasetName, {},
                                        storageClass="Config",
                                        universe=registry.dimensions)
        datasetTypes = pipeBase.TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
        for datasetType in itertools.chain(datasetTypes.initInputs, datasetTypes.initOutputs,
                                           datasetTypes.inputs, datasetTypes.outputs,
                                           datasetTypes.prerequisites,
                                           [configDatasetType, packagesDatasetType]):
            _LOG.info("Registering %s with registry", datasetType)
            # this is a no-op if it already exists and is consistent,
            # and it raises if it is inconsistent. But components must be
            # skipped
            if not datasetType.isComponent():
                registry.registerDatasetType(datasetType)

154 

155 

def makeSimpleQGraph(nQuanta=5, pipeline=None, butler=None, root=None, skipExisting=False, inMemory=True):
    """Make simple QuantumGraph for tests.

    Makes simple one-task pipeline with AddTask, sets up in-memory
    registry and butler, fills them with minimal data, and generates
    QuantumGraph with all of that.

    Parameters
    ----------
    nQuanta : `int`
        Number of quanta in a graph.
    pipeline : `~lsst.pipe.base.Pipeline`
        If `None` then one-task pipeline is made with `AddTask` and
        default `AddTaskConfig`.
    butler : `~lsst.daf.butler.Butler`, optional
        Data butler instance, this should be an instance returned from a
        previous call to this method.
    root : `str`
        Path or URI to the root location of the new repository. Only used if
        ``butler`` is None.
    skipExisting : `bool`, optional
        If `True`, a Quantum is not created if all its outputs already
        exist.  Default is `False`.
    inMemory : `bool`, optional
        If true make in-memory repository.

    Returns
    -------
    butler : `~lsst.daf.butler.Butler`
        Butler instance
    qgraph : `~lsst.pipe.base.QuantumGraph`
        Quantum graph instance
    """

    if pipeline is None:
        pipeline = pipeBase.Pipeline("test pipeline")
        # make a bunch of tasks that execute in well defined order (via data
        # dependencies): task N reads add_datasetN and writes add_datasetN+1
        for lvl in range(nQuanta):
            pipeline.addTask(AddTask, f"task{lvl}")
            pipeline.addConfigOverride(f"task{lvl}", "connections.in_tmpl", f"{lvl}")
            pipeline.addConfigOverride(f"task{lvl}", "connections.out_tmpl", f"{lvl+1}")

    if butler is None:

        if root is None:
            raise ValueError("Must provide `root` when `butler` is None")

        config = Config()
        if not inMemory:
            # Persistent SQLite registry and POSIX datastore under ``root``.
            config["registry", "db"] = f"sqlite:///{root}/gen3.sqlite"
            config["datastore", "cls"] = "lsst.daf.butler.datastores.posixDatastore.PosixDatastore"
        repo = butlerTests.makeTestRepo(root, {}, config=config)
        collection = "test"
        butler = Butler(butler=repo, run=collection)

    # Add dataset types to registry
    registerDatasetTypes(butler.registry, pipeline.toExpandedPipeline())

    # Add all needed dimensions to registry
    butler.registry.insertDimensionData("instrument", dict(name="INSTR"))
    butler.registry.insertDimensionData("detector", dict(instrument="INSTR", id=0, full_name="det0"))

    # Add inputs to butler; "add_dataset0" is the first task's input.
    data = numpy.array([0., 1., 2., 5.])
    butler.put(data, "add_dataset0", instrument="INSTR", detector=0)

    # Make the graph
    builder = pipeBase.GraphBuilder(registry=butler.registry, skipExisting=skipExisting)
    qgraph = builder.makeGraph(
        pipeline,
        collections=CollectionSearch.fromExpression(butler.run),
        run=butler.run,
        userQuery=""
    )

    return butler, qgraph