Coverage for tests/test_pipelineTask.py: 18%

177 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-11 02:00 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from typing import Any 

27 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30import lsst.utils.logging 

31import lsst.utils.tests 

32from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

33 

34 

class ButlerMock:
    """Minimal in-memory stand-in for a butler, only usable for this test.

    Stored objects live in ``datasets``, a two-level mapping keyed first by
    dataset type name and then by data ID.
    """

    def __init__(self) -> None:
        self.dimensions = DimensionUniverse()
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if there is none."""
        per_type = self.datasets.get(ref.datasetType.name)
        # A missing or empty per-type mapping both mean "nothing stored".
        return per_type.get(ref.dataId) if per_type else None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the type name and data ID of
        ``dsRef``, replacing any previous value.

        ``producer`` is accepted but not used.
        """
        data_id = dsRef.dataId
        per_type = self.datasets.setdefault(dsRef.datasetType.name, {})
        per_type[data_id] = inMemoryDataset

53 

54 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the add tasks: one input and one output dataset,
    both using the "Catalog" storage class and the same dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

68 

69 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the add tasks; holds the integer ``addend``."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")

72 

73 

# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    """Task whose ``run`` adds the configured addend to its input."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Record the addend in the task metadata and return
        ``input + addend`` as the ``output`` field of a `Struct`.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

83 

84 

# Example task which overrides the runQuantum() method.
class AddTask2(pipeBase.PipelineTask):
    """Task that does its own butler I/O in ``runQuantum`` rather than
    implementing ``run``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the input, add the configured addend and store the result.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context used to read the inputs and write the output.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            References to the inputs; passed as a whole to ``butlerQC.get``,
            whose result is indexed by the connection name ``"input"``.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            References to the outputs; passed as a whole to ``butlerQC.put``.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

97 

98 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Build a dataset ref of type ``dstype`` for visit ``visitId``.

        Every data ID value other than the visit is fixed, and the ref is
        always created in the "test" run.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a list of ``nquanta`` quanta, one per visit number.

        Each quantum holds exactly one input ref and one output ref, whose
        dataset types come from the ``input`` and ``output`` connections of
        ``config``.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        NOTE(review): ``full_butler`` is not read anywhere in this method, so
        the "Full" and "Limited" tests currently run exactly the same code.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True

                # Force the periodic logger to issue messages.  This is a
                # class-level setting, so remember the previous value and
                # restore it even if one of the checks below fails.
                saved_interval = lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            # 100 + i was stored as input; 3 is the default addend.
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The second task reads the dataset type that the first task writes.

        NOTE(review): ``full_butler`` is not read anywhere in this method, so
        the "Full" and "Limited" tests currently run exactly the same code.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self) -> None:
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

308 

309 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory test case inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

312 

313 

def setup_module(module):
    """Module-level setup hook: call `lsst.utils.tests.init`.

    ``module`` is accepted but not used.
    """
    lsst.utils.tests.init()

316 

317 

if __name__ == "__main__":
    # When run as a script, call the same init() as setup_module before
    # handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()