Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27 

28import lsst.utils.tests 

29from lsst.daf.butler import DatasetRef, Quantum, DimensionUniverse, DataCoordinate 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32 

33 

class ButlerMock:
    """Mock version of butler, only usable for this test.

    Datasets are held in memory in a two-level dictionary: the outer key is
    the dataset type name, the inner key is the data ID.
    """

    def __init__(self):
        # dataset type name -> {dataId: in-memory dataset}
        self.datasets = {}
        # Minimal stand-in for a registry; only ``dimensions`` is provided.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    @staticmethod
    def _datasetTypeName(datasetType):
        """Return the name of ``datasetType``.

        ``datasetType`` may be a plain string (as used by the tests, which
        pass a `Struct` in place of a real ref) or an object exposing a
        ``name`` attribute.
        """
        if isinstance(datasetType, str):
            return datasetType
        return datasetType.name

    def getDirect(self, ref):
        """Return the dataset stored for ``ref``, or `None` if there is none.

        Parameters
        ----------
        ref
            Object with ``datasetType`` (str or object with ``name``) and
            ``dataId`` attributes.
        """
        dsdata = self.datasets.get(self._datasetTypeName(ref.datasetType))
        if dsdata is None:
            return None
        return dsdata.get(ref.dataId)

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under the type name and data ID of ``dsRef``.

        Parameters
        ----------
        inMemoryDataset
            Value to store; an existing value for the same key is replaced.
        dsRef
            Object with ``datasetType`` (str or object with ``name``) and
            ``dataId`` attributes.
        producer
            Ignored; accepted only for call compatibility.
        """
        name = self._datasetTypeName(dsRef.datasetType)
        self.datasets.setdefault(name, {})[dsRef.dataId] = inMemoryDataset

55 

56 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections used by the addition test tasks: one Catalog in, one Catalog out."""

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

68 

69 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the addition test tasks."""

    addend = pexConfig.Field(
        dtype=int,
        default=3,
        doc="amount to add",
    )

72 

73 

# Example task that implements only run(); it relies on the inherited
# runQuantum() to fetch its input and store its output.
class AddTask(pipeBase.PipelineTask):
    """Add ``config.addend`` to the input and return it as ``output``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        """Add the configured addend to ``input``.

        Parameters
        ----------
        input
            Value to add to (an int in these tests). The name matches the
            ``input`` connection and must not change.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with a single ``output`` field.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

83 

84 

# Example task that overrides runQuantum() directly: it reads its inputs from
# and writes its outputs through the butler itself, and does not implement run().
class AddTask2(pipeBase.PipelineTask):
    """Same arithmetic as `AddTask`, implemented in ``runQuantum``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read ``input``, add ``config.addend``, and write ``output``.

        Parameters
        ----------
        butlerQC
            Butler quantum context used for both reading and writing.
        inputRefs
            References to the inputs of this quantum.
        outputRefs
            References to the outputs of this quantum.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs['input'] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

95 

96 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask
    """

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a DatasetRef of type ``dstype`` for visit ``visitId``.

        All other data ID values are fixed test constants.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter='a',
            band='b',
            instrument='TestInstrument',
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId)

    def _makeQuanta(self, config):
        """Create 100 quanta, one per visit, each with one input and one output.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(100):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(inputs={inputRef.datasetType: [inputRef]},
                              outputs={outputRef.datasetType: [outputRef]})
            quanta.append(quantum)

        return quanta

    def _putInputs(self, butler, connections, quanta):
        """Store ``100 + i`` in ``butler`` as the input of the i-th quantum.
        """
        dstype = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype.name][0]
            butler.put(100 + i, pipeBase.Struct(datasetType=dstype.name, dataId=ref.dataId))

    def _runQuanta(self, task, connections, butler, quanta):
        """Call ``task.runQuantum`` once for every quantum in ``quanta``.
        """
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

    def _checkOutputs(self, butler, outputName, quanta, offset):
        """Assert one output per quantum, the i-th equal to ``100 + i + offset``.
        """
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + offset)

    def testRunQuantum(self):
        """Test for AddTask.runQuantum() implementation.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        quanta = self._makeQuanta(task.config)
        self._putInputs(butler, connections, quanta)

        # run the task on each quantum
        self._runQuanta(task, connections, butler, quanta)

        # each output is its input plus the configured addend
        self._checkOutputs(butler, connections.output.name, quanta, task.config.addend)

    def testChain2(self):
        """Test for two-task chain.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # task2 consumes task1's output
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # only the first task's inputs need seeding
        self._putInputs(butler, connections1, quanta1)

        # run each task on each of its quanta, in chain order
        self._runQuanta(task1, connections1, butler, quanta1)
        self._runQuanta(task2, connections2, butler, quanta2)

        addend1 = task1.config.addend
        addend2 = task2.config.addend
        self._checkOutputs(butler, task1.config.connections.output, quanta1, addend1)
        self._checkOutputs(butler, task2.config.connections.output, quanta2, addend1 + addend2)

212 

213 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit all checks from `lsst.utils.tests.MemoryTestCase`; adds nothing of its own."""

216 

217 

def setup_module(module):
    """pytest-style module setup hook: initialize the LSST test utilities.

    ``module`` is unused; it is accepted to match the hook's signature.
    """
    testUtils = lsst.utils.tests
    testUtils.init()

220 

221 

if __name__ == "__main__":
    # When run as a script, initialize the LSST test utilities (as
    # setup_module does under pytest) before handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()