Coverage for tests/test_pipelineTask.py: 19%

188 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-19 04:01 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27from typing import Any 

28 

29import lsst.pex.config as pexConfig 

30import lsst.pipe.base as pipeBase 

31import lsst.utils.logging 

32import lsst.utils.tests 

33from lsst.daf.butler import ( 

34 DataCoordinate, 

35 DatasetIdFactory, 

36 DatasetRef, 

37 DatasetType, 

38 DimensionUniverse, 

39 Quantum, 

40) 

41 

42 

class ButlerMock:
    """Minimal in-memory stand-in for a butler, usable only by this test.

    Stored objects live in ``datasets``, keyed first by dataset type name and
    then by data ID.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        # Only ``registry.dimensions`` is provided; no other registry API exists here.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` when nothing is stored."""
        # A resolved ref (one that carries an ID) is required.
        assert ref.id is not None
        stored = self.datasets.get(ref.datasetType.name)
        return stored.get(ref.dataId) if stored else None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the dataset type name and data ID of ``dsRef``.

        ``producer`` is accepted but not used.
        """
        per_type = self.datasets.setdefault(dsRef.datasetType.name, {})
        per_type[dsRef.dataId] = inMemoryDataset

63 

64 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the example "add" tasks: one input and one output.

    Both connections use the "Catalog" storage class and the same five
    dimensions; only the dataset type names differ.
    """

    # Dataset read by the task; its default dataset type name is "add_input".
    input = pipeBase.connectionTypes.Input(
        name="add_input",
        doc="Input dataset type for this task",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

    # Dataset written by the task; its default dataset type name is "add_output".
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        doc="Output dataset type for this task",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

78 

79 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the example "add" tasks."""

    # Integer added to each input value; defaults to 3.
    addend = pexConfig.Field[int](default=3, doc="amount to add")

82 

83 

# Example task that overrides only the run() method.
class AddTask(pipeBase.PipelineTask):
    """Example task whose ``run`` adds the configured addend to its input."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add ``config.addend`` to ``input``.

        The addend is also recorded in the task metadata under the key "add".

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct with a single ``output`` attribute holding the sum.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

93 

94 

# Example task that overrides the runQuantum() method.
class AddTask2(pipeBase.PipelineTask):
    """Example task that does its own butler I/O in ``runQuantum``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the input, add ``config.addend`` to it, and write the result.

        The addend is also recorded in the task metadata under the key "add".

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context used to read the input and write the output.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references; ``butlerQC.get`` on this is indexed by
            connection name below, so it is the full set of input refs and
            not a single `DatasetRef`.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references, passed to ``butlerQC.put``.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        # The value is looked up by connection name ("input").
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

107 

108 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask, run against `ButlerMock`."""

    # Factory used by _resolve_ref to produce resolved dataset refs.
    datasetIdFactory = DatasetIdFactory()

    def _resolve_ref(self, ref: DatasetRef) -> DatasetRef:
        """Return ``ref`` resolved by the shared ID factory.

        "test" is passed as the second argument (presumably the run name).
        """
        return self.datasetIdFactory.resolveRef(ref, "test")

    def _makeDSRefVisit(
        self, dstype: DatasetType, visitId: int, universe: DimensionUniverse, resolve: bool = False
    ) -> DatasetRef:
        """Build a `DatasetRef` of type ``dstype`` for visit ``visitId``.

        All other data ID values are fixed test constants. The ref is
        resolved via `_resolve_ref` only when ``resolve`` is `True`.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        ref = DatasetRef(datasetType=dstype, dataId=dataId)
        return self._resolve_ref(ref) if resolve else ref

    def _makeQuanta(
        self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100, resolve_outputs: bool = False
    ) -> list[Quantum]:
        """Create ``nquanta`` quanta, one per visit number ``0..nquanta-1``.

        Each quantum has one input ref (always resolved) and one output ref
        (resolved only when ``resolve_outputs`` is `True`).
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe, resolve=True)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe, resolve=resolve_outputs)
            quanta.append(
                Quantum(
                    inputs={inputRef.datasetType: [inputRef]},
                    outputs={outputRef.datasetType: [outputRef]},
                )
            )
        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        ``full_butler`` selects ``ButlerQuantumContext.from_full`` or
        ``from_limited``; output refs are pre-resolved only in the limited
        case.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quantum
        checked_get = False
        for quantum in quanta:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                self._checkQuantumContextGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def _checkQuantumContextGet(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Check ``butlerQC.get`` with the supported argument forms."""
        periodic_logger = lsst.utils.logging.PeriodicLogger
        saved_interval = periodic_logger.LOGGING_INTERVAL
        # Force the periodic logger to issue messages. The previous interval
        # is restored in ``finally`` so that a failing assertion here does not
        # leave the override in place for later tests.
        periodic_logger.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            periodic_logger.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The second task (`AddTask2`, addend 200) reads the dataset type that
        the first task (`AddTask`, default addend 3) writes.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config, resolve_outputs=not full_butler)
        quanta2 = self._makeQuanta(task2.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run each task on its quanta; task1 must finish first because task2
        # reads what task1 wrote
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the first task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        # look at the output produced by the second task
        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

339 

340 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Adds nothing of its own; all behaviour is inherited from the base class."""

343 

344 

def setup_module(module):
    """Initialise the LSST test utilities.

    ``module`` is accepted but not used. Presumably this is called by the
    test runner as the module-level setup hook.
    """
    lsst.utils.tests.init()

347 

348 

if __name__ == "__main__":
    # Run as a script: initialise the LSST test utilities, then hand over
    # to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()