Coverage for tests/test_pipelineTask.py: 18%

190 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-27 02:47 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26import warnings 

27from types import SimpleNamespace 

28from typing import Any 

29 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32import lsst.utils.logging 

33import lsst.utils.tests 

34from lsst.daf.butler import ( 

35 DataCoordinate, 

36 DatasetRef, 

37 DatasetType, 

38 DimensionUniverse, 

39 Quantum, 

40 UnresolvedRefWarning, 

41) 

42 

43 

class ButlerMock:
    """Mock version of butler, only usable for this test.

    Stored objects live in ``self.datasets``, a nested mapping keyed first
    by dataset type name and then by data ID.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if nothing is.

        The ref must be resolved (``ref.id`` is not `None`).
        """
        assert ref.id is not None
        by_data_id = self.datasets.get(ref.datasetType.name)
        if not by_data_id:
            # Unknown dataset type (or no entries yet) for this name.
            return None
        return by_data_id.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the type name and data ID of
        ``dsRef``, replacing any previous object. ``producer`` is unused.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset

64 

65 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test add tasks: one ``Catalog`` input and one
    ``Catalog`` output sharing the same dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )

79 

80 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by `AddTask` and `AddTask2`."""

    # Value added to each input by the tasks (default 3).
    addend = pexConfig.Field[int](doc="amount to add", default=3)

83 

84 

# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    """Example task that implements only `run`, relying on the base class
    `runQuantum` to read inputs and write outputs.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Return a struct whose ``output`` is ``input + config.addend``.

        The addend is also recorded in the task metadata under ``"add"``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

94 

95 

# example task which overrides runQuantum() method
class AddTask2(pipeBase.PipelineTask):
    """Example task that implements `runQuantum` directly instead of `run`.

    It reads the ``input`` connection through the quantum context, adds
    ``config.addend`` and writes the sum to the ``output`` connection.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read inputs, add ``config.addend`` and write the output.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context used to get inputs and put outputs.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input refs, as produced by ``buildDatasetRefs``.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output refs, as produced by ``buildDatasetRefs``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

108 

109 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(
        self, dstype: DatasetType, visitId: int, universe: DimensionUniverse, resolve: bool = False
    ) -> DatasetRef:
        """Make a dataset ref for ``dstype`` with a fixed test data ID.

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the new ref.
        visitId : `int`
            Visit number for the data ID; the other dimensions are fixed
            test values.
        universe : `DimensionUniverse`
            Universe used to standardize the data ID.
        resolve : `bool`, optional
            If `True` make a ref in the ``"test"`` run, otherwise make a ref
            without a run (suppressing `UnresolvedRefWarning`).

        Returns
        -------
        ref : `DatasetRef`
            The new reference.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        if resolve:
            return DatasetRef(datasetType=dstype, dataId=dataId, run="test")
        # Constructing a ref without a run emits UnresolvedRefWarning; the
        # full-butler tests still use such refs, so silence it here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UnresolvedRefWarning)
            return DatasetRef(datasetType=dstype, dataId=dataId)

    def _makeQuanta(
        self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100, resolve_outputs: bool = False
    ) -> list[Quantum]:
        """Create one quantum per visit in ``range(nquanta)``.

        Input refs are always resolved; output refs are resolved only when
        ``resolve_outputs`` is `True`.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe, resolve=True)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe, resolve=resolve_outputs)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            If `True` use `ButlerQuantumContext.from_full`, otherwise
            `ButlerQuantumContext.from_limited` (with resolved output refs).
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True

                # Force the periodic logger to issue messages. The previous
                # interval is restored even if a get fails, so the global
                # setting does not leak into other tests.
                periodic_logger = lsst.utils.logging.PeriodicLogger
                saved_interval = periodic_logger.LOGGING_INTERVAL
                periodic_logger.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    periodic_logger.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + task.config.addend)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        ``AddTask`` reads the seeded inputs and writes its output dataset;
        ``AddTask2`` reads that dataset and writes ``add_output_2``.

        Parameters
        ----------
        full_butler : `bool`
            If `True` use `ButlerQuantumContext.from_full`, otherwise
            `ButlerQuantumContext.from_limited` (with resolved output refs).
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config, resolve_outputs=not full_butler)
        quanta2 = self._makeQuanta(task2.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend + config2.addend)

    def testButlerQC(self):
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

337 

338 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the memory checks inherited from `lsst.utils.tests.MemoryTestCase`."""

341 

342 

def setup_module(module):
    """Module-level setup hook (pytest xunit-style convention).

    Calls `lsst.utils.tests.init`; ``module`` is the test module and is
    not otherwise used.
    """
    lsst.utils.tests.init()

345 

346 

if __name__ == "__main__":
    # When run as a script, initialise lsst.utils.tests before handing
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()