Coverage for tests/test_pipelineTask.py: 19%

192 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-18 01:59 -0800

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import unittest 

26from types import SimpleNamespace 

27from typing import Any 

28 

29import lsst.pex.config as pexConfig 

30import lsst.pipe.base as pipeBase 

31import lsst.utils.logging 

32import lsst.utils.tests 

33from lsst.daf.butler import ( 

34 DataCoordinate, 

35 DatasetIdFactory, 

36 DatasetRef, 

37 DatasetType, 

38 DimensionUniverse, 

39 Quantum, 

40) 

41 

42 

class ButlerMock:
    """Minimal in-memory stand-in for a Butler, only usable for this test.

    Stored objects live in ``datasets``, a nested mapping of dataset type
    name -> data ID -> object. Dataset IDs are checked (resolved vs.
    unresolved) but never used as storage keys.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def getDirect(self, ref: DatasetRef) -> Any:
        # The "direct" read path only accepts resolved references.
        assert ref.id is not None
        by_data_id = self.datasets.get(ref.datasetType.name)
        # Unknown (or empty) dataset type, or unknown data ID -> None.
        return by_data_id.get(ref.dataId) if by_data_id else None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        # The regular write path only accepts unresolved references;
        # ``producer`` is accepted for signature compatibility and ignored.
        assert dsRef.id is None
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset

    def putDirect(self, obj: Any, ref: DatasetRef):
        # The "direct" write path only accepts resolved references; storage
        # is delegated to ``put`` using the unresolved form of the ref.
        assert ref.id is not None
        self.put(obj, ref.unresolved())

70 

71 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the "add" test tasks: a single ``input`` and a single
    ``output`` catalog sharing the same dataset dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

85 

86 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the "add" test tasks."""

    # Integer added to the input value by the tasks.
    addend = pexConfig.Field[int](default=3, doc="amount to add")

89 

90 

class AddTask(pipeBase.PipelineTask):
    """Example task which overrides only the ``run()`` method.

    ``runQuantum`` is inherited from the base class. ``run`` records the
    configured addend in the task metadata under ``"add"`` and returns the
    input plus that addend as ``output``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

100 

101 

class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides the ``runQuantum()`` method directly,
    performing its own reads and writes through the quantum context instead
    of implementing ``run()``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read ``input``, add the configured addend, and write ``output``.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler wrapper bound to the quantum being executed.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input dataset references, as returned by ``buildDatasetRefs``.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output dataset references, as returned by ``buildDatasetRefs``.
        """
        self.metadata.add("add", self.config.addend)
        # ``get`` on a quantized connection returns a dict keyed by
        # connection name.
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

114 

115 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    datasetIdFactory = DatasetIdFactory()

    def _resolve_ref(self, ref: DatasetRef) -> DatasetRef:
        """Resolve ``ref`` with the shared dataset ID factory, passing
        ``"test"`` as the run.
        """
        return self.datasetIdFactory.resolveRef(ref, "test")

    def _makeDSRefVisit(
        self, dstype: DatasetType, visitId: int, universe: DimensionUniverse, resolve: bool = False
    ) -> DatasetRef:
        """Make a reference to ``dstype`` for visit ``visitId``.

        All other data ID values are fixed test values. The reference is
        resolved only when ``resolve`` is true.
        """
        ref = DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )
        if resolve:
            ref = self._resolve_ref(ref)
        return ref

    def _makeQuanta(
        self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100, resolve_outputs: bool = False
    ) -> list[Quantum]:
        """Create one quantum per visit ID in ``range(nquanta)``.

        Each quantum has a single resolved input ref and a single output ref,
        built from the ``input``/``output`` connections of ``config``. Output
        refs are resolved only when ``resolve_outputs`` is true.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe, resolve=True)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe, resolve=resolve_outputs)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            If `True`, use ``ButlerQuantumContext.from_full`` with unresolved
            output refs. Otherwise use ``ButlerQuantumContext.from_limited``
            with resolved output refs.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # Make all quanta. Outputs are resolved only for the limited butler;
        # the mock's ``put`` requires unresolved refs and ``putDirect``
        # requires resolved ones (presumably used by full/limited contexts
        # respectively).
        quanta = self._makeQuanta(task.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.putDirect(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            if full_butler:
                butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)
            else:
                butlerQC = pipeBase.ButlerQuantumContext.from_limited(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True

                # Force the periodic logger to issue messages. This is a
                # class-level setting, so always restore the previous value,
                # even if an assertion fails, to avoid affecting other tests.
                periodic_logger = lsst.utils.logging.PeriodicLogger
                saved_interval = periodic_logger.LOGGING_INTERVAL
                periodic_logger.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    periodic_logger.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        # Previously passed ``full_butler=True``, so the limited path was
        # never exercised by this test.
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        ``AddTask`` (addend 3) writes ``add_output``, which ``AddTask2``
        (addend 200) reads and turns into ``add_output_2``.

        Parameters
        ----------
        full_butler : `bool`
            If `True`, use ``ButlerQuantumContext.from_full`` with unresolved
            output refs. Otherwise use ``ButlerQuantumContext.from_limited``
            with resolved output refs.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config, resolve_outputs=not full_butler)
        quanta2 = self._makeQuanta(task2.config, resolve_outputs=not full_butler)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.putDirect(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """

        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.putDirect(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

346 

347 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Runs the checks provided by `lsst.utils.tests.MemoryTestCase` as-is."""

350 

351 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Presumably invoked by pytest as its module-level setup hook; the
    ``module`` argument is not used.
    """
    lsst.utils.tests.init()

354 

355 

if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()