Coverage for tests/test_pipelineTask.py: 19%

145 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-08 14:36 -0800

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30import lsst.utils.logging 

31import lsst.utils.tests 

32from lsst.daf.butler import DataCoordinate, DatasetRef, DimensionUniverse, Quantum 

33 

34 

class ButlerMock:
    """In-memory stand-in for a Butler, sufficient only for this test module.

    Stored objects live in ``self.datasets``, a nested dict keyed first by
    dataset type name and then by data ID.
    """

    def __init__(self):
        self.datasets = dict()
        # Tests read ``butler.registry.dimensions`` to build dataset types.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under ``dsRef``.

        ``dsRef.datasetType`` may be either a dataset type name (`str`) or an
        object exposing a ``name`` attribute. ``producer`` is accepted for
        call compatibility but ignored.
        """
        dataId = dsRef.dataId
        datasetType = dsRef.datasetType
        typeName = datasetType if isinstance(datasetType, str) else datasetType.name
        self.datasets.setdefault(typeName, {})[dataId] = inMemoryDataset

    def getDirect(self, ref):
        """Return the object stored for ``ref``, or `None` if nothing is stored."""
        byDataId = self.datasets.get(ref.datasetType.name)
        return byDataId.get(ref.dataId) if byDataId else None

56 

57 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test addition tasks: one input, one output.

    Both datasets share the same dimensions and the ``Catalog`` storage class.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

71 

72 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the test addition tasks, wired to `AddConnections`."""

    addend = pexConfig.Field(
        dtype=int,
        default=3,
        doc="amount to add",
    )

75 

76 

# Example task that overrides only the run() method; butler I/O is left to
# the inherited runQuantum().
class AddTask(pipeBase.PipelineTask):
    """Add ``config.addend`` to its input, implemented as a ``run`` override."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        """Return ``input + config.addend`` as the ``output`` of a Struct.

        The addend is also recorded in the task metadata under ``"add"``.
        The parameter is named ``input`` to match the connection name.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

86 

87 

# Example task that overrides the runQuantum() method and performs its own
# butler I/O through the quantum context.
class AddTask2(pipeBase.PipelineTask):
    """Add ``config.addend`` to its input by overriding ``runQuantum``.

    Unlike `AddTask`, this task reads inputs and writes outputs directly via
    ``butlerQC`` instead of implementing ``run``. It shares the default task
    name with `AddTask`.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read ``input``, add ``config.addend``, and write ``output``.

        Parameters
        ----------
        butlerQC : `ButlerQuantumContext`
            Quantum-scoped butler used for reading and writing.
        inputRefs, outputRefs
            Dataset references as returned by
            ``connections.buildDatasetRefs(quantum)``.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        # The input connection is named ``input``; only one value per quantum.
        output = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=output), outputRefs)

97 butlerQC.put(pipeBase.Struct(output=outputs), outputRefs) 

98 

99 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a DatasetRef for ``dstype`` whose data ID varies only in visit.

        All other data ID values (detector, filter, band, instrument) are
        fixed test constants.
        """
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )

    def _makeQuanta(self, config, nQuanta=100):
        """Create a list of Quanta, one per visit.

        Parameters
        ----------
        config : `AddConfig`
            Task configuration used to build the connections.
        nQuanta : `int`, optional
            Number of quanta to create, for visits ``0 .. nQuanta - 1``.

        Returns
        -------
        quanta : `list` [`Quantum`]
            Each quantum has exactly one input ref and one output ref, both
            for the same data ID.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nQuanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quanta.append(
                Quantum(
                    inputs={inputRef.datasetType: [inputRef]},
                    outputs={outputRef.datasetType: [outputRef]},
                )
            )
        return quanta

    def _checkButlerQCGet(self, butlerQC, inputRefs, outputRefs):
        """Exercise the different call forms accepted by ``butlerQC.get``."""
        PeriodicLogger = lsst.utils.logging.PeriodicLogger
        # Force the periodic logger to issue messages. The interval is a
        # class attribute, so always restore the previous value, even when an
        # assertion fails, to avoid leaking state into other tests.
        savedInterval = PeriodicLogger.LOGGING_INTERVAL
        PeriodicLogger.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            PeriodicLogger.LOGGING_INTERVAL = savedInterval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def testRunQuantum(self):
        """Test for AddTask.runQuantum() implementation."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, pipeBase.Struct(datasetType=dstype0.name, dataId=ref.dataId))

        # Run the task on each quantum. The different ways of getting
        # datasets only need to be checked once, on the first quantum.
        for i, quantum in enumerate(quanta):
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)
            if i == 0:
                self._checkButlerQCGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + task.config.addend)

    def testChain2(self):
        """Test for two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Feed the first task's output into the second task.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, pipeBase.Struct(datasetType=dstype0.name, dataId=ref.dataId))

        # run each task on each of its quanta, first task first
        for task, connections, quanta in (
            (task1, connections1, quanta1),
            (task2, connections2, quanta2),
        ):
            for quantum in quanta:
                butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
                inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
                task.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the first task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend)

        # and the output of the second task, which adds on top of the first
        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend + config2.addend)

246 

247 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``lsst.utils.tests.MemoryTestCase``.

    NOTE(review): presumably these are resource/leak checks performed after
    the other tests in this module -- confirm against lsst.utils.
    """

250 

251 

def setup_module(module):
    """Initialize the lsst.utils test machinery before this module's tests.

    ``module`` is accepted to match the module-level setup hook signature
    (presumably pytest's xunit-style ``setup_module``) but is not used.
    """
    lsst.utils.tests.init()

254 

255 

if __name__ == "__main__":
    # Direct execution: initialize the lsst test utilities first, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()