Coverage for tests/test_pipelineTask.py: 19%

181 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-06 02:51 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27from typing import Any 

28 

29import lsst.pex.config as pexConfig 

30import lsst.pipe.base as pipeBase 

31import lsst.utils.logging 

32import lsst.utils.tests 

33from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34 

35 

class ButlerMock:
    """Mock version of butler, only usable for this test.

    Datasets are kept in memory in ``datasets``, a two-level mapping keyed
    first by dataset type name and then by data ID.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        # Only the ``dimensions`` attribute of a registry is provided.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``.

        Parameters
        ----------
        ref : `DatasetRef`
            Reference whose ``datasetType.name`` and ``dataId`` select the
            stored object.

        Returns
        -------
        obj : `Any`
            The stored object, or `None` when nothing was stored for this
            dataset type name or data ID.
        """
        dsdata = self.datasets.get(ref.datasetType.name)
        if dsdata:
            return dsdata.get(ref.dataId)
        return None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None) -> None:
        """Store ``inMemoryDataset`` under the name and data ID of ``dsRef``.

        Parameters
        ----------
        inMemoryDataset : `Any`
            Object to store; replaces any object already stored for the
            same dataset type name and data ID.
        dsRef : `DatasetRef`
            Reference supplying the dataset type name and data ID.
        producer : `Any`, optional
            Accepted but not used by this mock.
        """
        key = dsRef.dataId
        name = dsRef.datasetType.name
        dsdata = self.datasets.setdefault(name, {})
        dsdata[key] = inMemoryDataset

54 

55 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the add tasks: one ``Catalog`` input and one
    ``Catalog`` output, both with the same dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )

69 

70 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the add tasks; connections come from
    `AddConnections`.
    """

    addend = pexConfig.Field[int](doc="amount to add", default=3)

73 

74 

# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    """Task that adds ``config.addend`` to its input."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input``.

        Parameters
        ----------
        input : `int`
            Value to add to. The parameter name matches the ``input``
            connection of `AddConnections`.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``output`` attribute is ``input + config.addend``.
        """
        # Record the addend in the task metadata under the key "add".
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        return pipeBase.Struct(output=output)

84 

85 

# example task which overrides runQuantum() method
class AddTask2(pipeBase.PipelineTask):
    """Task that adds ``config.addend`` to its input, doing its own butler
    I/O in `runQuantum` instead of implementing ``run``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the input, add the configured addend and write the output.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context used to read the input and write the output.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references; ``butlerQC.get`` on it returns a mapping
            that is indexed here by the connection name ``"input"``.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references that the resulting struct is written to.
        """
        # Record the addend in the task metadata under the key "add".
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

98 

99 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Create a dataset reference of the given type for one visit.

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the returned reference.
        visitId : `int`
            Value of the ``visit`` dimension; all other dimension values
            are fixed placeholders.
        universe : `DimensionUniverse`
            Universe used to standardize the data ID.

        Returns
        -------
        ref : `DatasetRef`
            Reference in the ``"test"`` run.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create set of Quanta.

        Parameters
        ----------
        config : `lsst.pipe.base.PipelineTaskConfig`
            Task configuration whose connections define the input and
            output dataset types.
        nquanta : `int`, optional
            Number of quanta to create; visit IDs run from 0 to
            ``nquanta - 1``.

        Returns
        -------
        quanta : `list` [`Quantum`]
            One quantum per visit, each with one input and one output ref.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quanta.append(
                Quantum(
                    inputs={inputRef.datasetType: [inputRef]},
                    outputs={outputRef.datasetType: [outputRef]},
                )
            )

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _checkButlerQCGet(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Check the different ways of getting datasets from ``butlerQC``.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context of a quantum that has already been run.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references of that quantum.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references of that quantum.
        """
        periodic = lsst.utils.logging.PeriodicLogger
        saved_interval = periodic.LOGGING_INTERVAL
        # Force the periodic logger to issue messages. The interval is a
        # class-level setting, so restore it even if an assertion fails.
        periodic.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            periodic.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            Selects the full or limited butler variant. The body does not
            read this flag; both variants run against `ButlerMock`.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                self._checkButlerQCGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        Parameters
        ----------
        full_butler : `bool`
            Selects the full or limited butler variant. The body does not
            read this flag; both variants run against `ButlerMock`.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Feed the output of the first task into the second one.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self) -> None:
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

309 

310 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`."""

313 

314 

def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module
        The module being set up; not used.
    """
    lsst.utils.tests.init()

317 

318 

if __name__ == "__main__":
    # Same initialization as setup_module(), for direct script execution.
    lsst.utils.tests.init()
    unittest.main()