Coverage for tests/test_pipelineTask.py: 18%

185 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-05 03:30 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27from typing import Any 

28 

29import lsst.pex.config as pexConfig 

30import lsst.pipe.base as pipeBase 

31import lsst.utils.logging 

32import lsst.utils.tests 

33from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34 

35 

class ButlerMock:
    """Minimal in-memory butler stand-in, only usable for this test.

    Stored objects live in ``self.datasets``, a nested mapping keyed first by
    dataset type name and then by data ID.
    """

    def __init__(self) -> None:
        # The fake registry exposes nothing but ``dimensions``.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if nothing is stored.

        The ref must be resolved (``ref.id`` is not `None`).
        """
        # Requires resolved ref.
        assert ref.id is not None
        by_data_id = self.datasets.get(ref.datasetType.name)
        return by_data_id.get(ref.dataId) if by_data_id else None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the dataset type name and data ID
        of ``dsRef``.

        ``producer`` is accepted but not used.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset

56 

57 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections used by the add tasks in this test: one input dataset and
    one output dataset, both with storage class "Catalog".
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

71 

72 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the add tasks: a single integer ``addend``."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")

75 

76 

# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    """Task that adds the configured ``addend`` to its input."""

    _DefaultName = "add_task"
    ConfigClass = AddConfig

    def run(self, input: int) -> pipeBase.Struct:
        """Record the addend in the task metadata and return a `Struct`
        whose ``output`` is ``input`` plus the addend.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

86 

87 

# Example task which overrides the runQuantum() method.
class AddTask2(pipeBase.PipelineTask):
    """Task that adds the configured ``addend`` to its input, doing its own
    butler I/O in ``runQuantum`` instead of implementing ``run``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the input, add the configured addend and write the result.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context object used to read the inputs and write the outputs.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references; passed as a whole to ``butlerQC.get``, whose
            result is indexed by the connection name ``"input"``.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references; passed as a whole to ``butlerQC.put``.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        # The Struct attribute name must match the "output" connection name.
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

100 

101 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Build a `DatasetRef` of type ``dstype`` for visit ``visitId``.

        All other data ID values are fixed test constants and the run is
        always ``"test"``.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create ``nquanta`` quanta, one per visit in ``range(nquanta)``.

        Each quantum has exactly one input ref and one output ref, built from
        the ``input`` and ``output`` connections of ``config``.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _checkButlerQCGet(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Exercise ``butlerQC.get`` with a connection object, a list of refs,
        a single ref, an unsupported argument and an output ref.
        """
        periodic_logger = lsst.utils.logging.PeriodicLogger
        saved_interval = periodic_logger.LOGGING_INTERVAL
        # Force the periodic logger to issue messages.
        periodic_logger.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            # LOGGING_INTERVAL is a class attribute shared by every test in
            # the process, so put the previous value back even when one of
            # the checks above fails.
            periodic_logger.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                self._checkButlerQCGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            # Inputs were 100 + i and the default addend is 3.
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Feed the output of the first task into the second one.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            # First task adds 3, second task adds 200.
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self) -> None:
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

320 

321 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """`lsst.utils.tests.MemoryTestCase` subclass with no additions."""

324 

325 

def setup_module(module):
    """Initialise the LSST test utilities by calling
    ``lsst.utils.tests.init()``.

    ``module`` is not used. NOTE(review): presumably this is the pytest
    module-level setup hook -- confirm against the test runner in use.
    """
    lsst.utils.tests.init()

328 

329 

# Script entry point: initialise the LSST test utilities, then hand control
# to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()