Coverage for tests/test_pipelineTask.py: 30%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

122 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30import lsst.utils.tests 

31from lsst.daf.butler import DataCoordinate, DatasetRef, DimensionUniverse, Quantum 

32 

33 

class ButlerMock:
    """Minimal in-memory stand-in for a Butler, usable only by this test module.

    Stored objects live in ``self.datasets``: a dict keyed first by dataset
    type name, then by data ID.
    """

    def __init__(self):
        self.datasets = {}
        # Exposes just ``registry.dimensions``; the tests in this module use it
        # to build dataset types. Extend if more registry access is needed.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def getDirect(self, ref):
        """Return the object stored for ``ref``, or `None` if nothing was stored."""
        byDataId = self.datasets.get(ref.datasetType.name)
        if not byDataId:
            return None
        return byDataId.get(ref.dataId)

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under ``dsRef``'s dataset type name and data ID.

        ``dsRef.datasetType`` may be either a plain string or an object with a
        ``name`` attribute. ``producer`` is accepted but ignored.
        """
        dataId = dsRef.dataId
        datasetType = dsRef.datasetType
        typeName = datasetType if isinstance(datasetType, str) else datasetType.name
        self.datasets.setdefault(typeName, {})[dataId] = inMemoryDataset

55 

56 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the toy addition tasks: one "Catalog" input, one "Catalog" output."""

    # Both connections use the same data ID dimensions, so an output dataset
    # can be matched one-to-one with the input it was computed from.
    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

70 

71 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the example addition tasks (default addend is 3)."""

    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")

74 

75 

76# example task which overrides run() method 

class AddTask(pipeBase.PipelineTask):
    """Example task that overrides only ``run()``.

    Adds ``config.addend`` to its input and records the addend in the task
    metadata under the key ``"add"``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        # The parameter is named ``input`` (shadowing the builtin), presumably
        # to match the ``input`` connection name -- do not rename it.
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

85 

86 

# example task which overrides runQuantum() method

class AddTask2(pipeBase.PipelineTask):
    """Example task that overrides ``runQuantum()`` instead of ``run()``.

    It reads its inputs and writes its outputs through ``butlerQC`` itself,
    adding ``config.addend`` to the ``input`` dataset and recording the addend
    in the task metadata under the key ``"add"``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        addend = self.config.addend
        self.metadata.add("add", addend)
        total = butlerQC.get(inputRefs)["input"] + addend
        butlerQC.put(pipeBase.Struct(output=total), outputRefs)

97 

98 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    # Quantum ``i`` receives ``_INPUT_BASE + i`` as its input value so every
    # quantum's output is distinguishable.
    _INPUT_BASE = 100

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a DatasetRef of ``dstype`` for visit ``visitId``.

        All other data ID values are fixed placeholders.
        """
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )

    def _makeQuanta(self, config):
        """Create one Quantum per visit in ``range(100)``.

        Each quantum has a single input ref (the ``input`` connection) and a
        single output ref (the ``output`` connection) for the same data ID.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(100):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quanta.append(
                Quantum(inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]})
            )
        return quanta

    def _putInputs(self, butler, connections, quanta):
        """Store ``_INPUT_BASE + i`` in ``butler`` as the input of the i-th quantum."""
        dstype = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype.name][0]
            # ButlerMock.put reads only ``datasetType`` and ``dataId``, so a
            # Struct carrying those two attributes stands in for a DatasetRef.
            butler.put(self._INPUT_BASE + i, pipeBase.Struct(datasetType=dstype.name, dataId=ref.dataId))

    def _runQuanta(self, task, connections, butler, quanta):
        """Call ``task.runQuantum`` on every quantum against ``butler``."""
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

    def _checkOutputs(self, butler, outputName, quanta, totalAddend):
        """Assert quantum ``i`` produced ``_INPUT_BASE + i + totalAddend``."""
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], self._INPUT_BASE + i + totalAddend)

    def testRunQuantum(self):
        """Test runQuantum() for AddTask, which overrides only run()."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        quanta = self._makeQuanta(task.config)
        self._putInputs(butler, connections, quanta)
        self._runQuanta(task, connections, butler, quanta)

        # Derive the expectation from the config rather than hard-coding the
        # default addend, so a change of default does not break the test.
        self._checkOutputs(butler, connections.output.name, quanta, task.config.addend)

    def testChain2(self):
        """Test a two-task chain: AddTask's output is AddTask2's input."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # Only the first task's inputs are pre-populated; the second task
        # consumes what the first one writes.
        self._putInputs(butler, connections1, quanta1)

        self._runQuanta(task1, connections1, butler, quanta1)
        self._runQuanta(task2, connections2, butler, quanta2)

        addend1 = task1.config.addend
        addend2 = task2.config.addend
        self._checkOutputs(butler, task1.config.connections.output, quanta1, addend1)
        self._checkOutputs(butler, task2.config.connections.output, quanta2, addend1 + addend2)

211 

212 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``lsst.utils.tests.MemoryTestCase``; nothing is added here."""

215 

216 

def setup_module(module):
    """Module-level setup hook (pytest xunit style): call ``lsst.utils.tests.init()``.

    ``module`` is the module object being set up; it is not used.
    """
    lsst.utils.tests.init()

219 

220 

if __name__ == "__main__":
    # When run as a script, call lsst.utils.tests.init() (the same call
    # setup_module makes under pytest) and then hand over to unittest.
    lsst.utils.tests.init()
    unittest.main()