Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27 

28import lsst.utils.tests 

29from lsst.daf.butler import DatasetRef, Quantum, DimensionUniverse, DataCoordinate 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32 

33 

class ButlerMock:
    """Minimal in-memory stand-in for a butler, only usable for this test.

    Stored objects live in ``self.datasets``, a two-level mapping of
    dataset type name -> data ID -> stored object.
    """

    def __init__(self):
        self.datasets = {}
        # Only ``registry.dimensions`` is provided; the tests use it to build
        # dataset types.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, datasetRefOrType, dataId=None):
        """Return the stored object, or `None` if nothing is stored.

        ``datasetRefOrType`` is either a `DatasetRef` (whose own data ID
        replaces ``dataId``) or a dataset type name used together with
        ``dataId``.
        """
        if not isinstance(datasetRefOrType, DatasetRef):
            typeName = datasetRefOrType
        else:
            dataId = datasetRefOrType.dataId
            typeName = datasetRefOrType.datasetType.name
        byDataId = self.datasets.get(typeName)
        return byDataId.get(dataId) if byDataId else None

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under ``dsRef``'s type name and data ID.

        ``dsRef.datasetType`` may be a plain name string or an object with a
        ``name`` attribute. ``producer`` is accepted for signature
        compatibility and ignored.
        """
        dataId = dsRef.dataId
        datasetType = dsRef.datasetType
        typeName = datasetType if isinstance(datasetType, str) else datasetType.name
        self.datasets.setdefault(typeName, {})[dataId] = inMemoryDataset

61 

62 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the example "add" tasks.

    One input and one output connection, both with storage class ``Catalog``
    and identical dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

74 

75 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the example "add" tasks."""

    # Integer the tasks add to their input value.
    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")

78 

79 

# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    """Example task that implements only `run`.

    It does not define `runQuantum`; the tests call the one inherited from
    `PipelineTask`.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        """Add ``config.addend`` to ``input``.

        The parameter is named ``input`` (shadowing the builtin) because it
        must match the name of the ``input`` connection.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``output`` field holds ``input + config.addend``.
        """
        addend = self.config.addend
        # Record the addend used so it can be inspected in task metadata.
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

89 

90 

# example task which overrides runQuantum() directly, doing its own butler
# I/O instead of implementing run()
class AddTask2(pipeBase.PipelineTask):
    """Example task that implements `runQuantum` directly.

    It reads its inputs through the quantum butler context, adds
    ``config.addend`` to the ``input`` value, and writes the sum as the
    ``output`` dataset.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Process a single quantum.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler wrapper for this quantum (the tests build it from a butler
            and a quantum).
        inputRefs, outputRefs
            Dataset references for the input and output connections, as
            returned by the connections' ``buildDatasetRefs``.
        """
        # Record the addend used so it can be inspected in task metadata.
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs['input'] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

101 

102 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask.
    """

    # Value stored as input for the first quantum; quantum ``i`` gets
    # ``_inputStart + i``.
    _inputStart = 100

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a DatasetRef of type ``dstype`` whose data ID differs between
        calls only in ``visit``.
        """
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter='a',
                band='b',
                instrument='TestInstrument',
                universe=universe
            )
        )

    def _makeQuanta(self, config, nQuanta=100):
        """Create a list of quanta, one per visit in ``range(nQuanta)``.

        Each quantum has one predicted input of the config's ``input``
        dataset type and one output of its ``output`` dataset type.
        """
        universe = DimensionUniverse()
        run = "run1"
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nQuanta):
            quantum = Quantum(run=run)
            quantum.addPredictedInput(self._makeDSRefVisit(dstype0, visit, universe))
            quantum.addOutput(self._makeDSRefVisit(dstype1, visit, universe))
            quanta.append(quantum)

        return quanta

    def _putInputs(self, butler, connections, quanta):
        """Store ``_inputStart + i`` as the input of the i-th quantum."""
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.predictedInputs[dstype0.name][0]
            # A Struct stands in for a DatasetRef: ButlerMock.put only reads
            # its ``datasetType`` and ``dataId`` attributes.
            butler.put(self._inputStart + i,
                       pipeBase.Struct(datasetType=dstype0.name, dataId=ref.dataId))

    def _runQuanta(self, task, butler, connections, quanta):
        """Call ``task.runQuantum`` on every quantum in ``quanta``."""
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

    def _checkOutputs(self, butler, outputName, quanta, totalAddend):
        """Assert that quantum ``i`` produced ``_inputStart + i + totalAddend``
        and that there is exactly one output per quantum.
        """
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], self._inputStart + i + totalAddend)

    def testRunQuantum(self):
        """Test for AddTask.runQuantum() implementation.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        self._putInputs(butler, connections, quanta)

        # run task on each quantum
        self._runQuanta(task, butler, connections, quanta)

        # look at the output produced by the task
        self._checkOutputs(butler, connections.output.name, quanta, task.config.addend)

    def testChain2(self):
        """Test for two-task chain.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Feed task1's output into task2.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler; only task1's inputs are pre-populated,
        # task2 reads what task1 writes
        self._putInputs(butler, connections1, quanta1)

        # run each task on each of its quanta
        self._runQuanta(task1, butler, connections1, quanta1)
        self._runQuanta(task2, butler, connections2, quanta2)

        # look at the outputs produced by both tasks
        self._checkOutputs(butler, task1.config.connections.output, quanta1,
                           task1.config.addend)
        self._checkOutputs(butler, task2.config.connections.output, quanta2,
                           task1.config.addend + task2.config.addend)

218 

219 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Runs the checks inherited from `lsst.utils.tests.MemoryTestCase`;
    no extra tests are defined here.
    """

222 

223 

def setup_module(module):
    """pytest-style module setup hook: initialize `lsst.utils.tests`.

    ``module`` is part of the hook signature and is not used.
    """
    lsst.utils.tests.init()

226 

227 

if __name__ == "__main__":
    # Same initialization setup_module performs under pytest, then hand
    # control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()