Coverage for tests/test_pipelineTask.py: 15%

223 statements  

coverage.py v7.5.1, created at 2024-05-11 03:31 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28"""Simple unit test for PipelineTask. 

29""" 

30 

31import copy 

32import pickle 

33import unittest 

34from typing import Any 

35 

36import astropy.units as u 

37import lsst.pex.config as pexConfig 

38import lsst.pipe.base as pipeBase 

39import lsst.utils.logging 

40import lsst.utils.tests 

41from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

42 

43 

class ButlerMock:
    """Mock version of butler, only usable for this test."""

    def __init__(self) -> None:
        # Maps dataset type name -> {dataId: stored in-memory object}.
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` when absent."""
        stored = self.datasets.get(ref.datasetType.name)
        return stored.get(ref.dataId) if stored else None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` keyed by the ref's dataset type and dataId.

        ``producer`` is accepted for API compatibility and is not used.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset

62 

63 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the AddTask."""

    # Single scalar input: one catalog per (instrument, visit, detector,
    # physical_filter, band) data ID.
    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    # Single scalar output with the same dimensions as the input.
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )

79 

80 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Config for the AddTask."""

    # Constant added to each input value by the task; tests assume 3 unless
    # overridden (e.g. _testChain2 sets 200 for the second task).
    addend = pexConfig.Field[int](doc="amount to add", default=3)

85 

86 

class AddTask(pipeBase.PipelineTask):
    """Example task which overrides run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input`` and return it as a Struct."""
        # Record the addend in task metadata so tests can inspect it.
        self.metadata.add("add", self.config.addend)
        return pipeBase.Struct(output=input + self.config.addend)

97 

98 

class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides runQuantum() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task_2"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        """Fetch the input via the quantum context, add the addend, and put
        the result back through the context.
        """
        # Record the addend in task metadata, mirroring AddTask.run().
        self.metadata.add("add", self.config.addend)
        fetched = butlerQC.get(inputRefs)
        result = fetched["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=result), outputRefs)

112 

113 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int) -> DatasetRef:
        """Make a resolved DatasetRef for ``dstype`` at the given visit.

        All non-visit dimension values are fixed placeholders; only the
        dimensions actually present in ``dstype.dimensions`` are kept by
        ``DataCoordinate.standardize``.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            dimensions=dstype.dimensions,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(
        self,
        task_node: pipeBase.pipeline_graph.TaskNode,
        pipeline_graph: pipeBase.PipelineGraph,
        nquanta: int = 100,
    ) -> list[Quantum]:
        """Create a set of ``nquanta`` Quanta for ``task_node``, one per
        visit, with one input ref and one output ref per connection edge.
        """
        quanta = []
        for visit in range(nquanta):
            quantum = Quantum(
                inputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.inputs.values()
                },
                outputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.outputs.values()
                },
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        NOTE(review): ``full_butler`` is currently unused in this body — both
        callers exercise the same code path; confirm whether the distinction
        is still needed.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input
        output_name = task.config.connections.output

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta = self._makeQuanta(task_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

                # Restore the default interval so later tests are unaffected.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        dsdata = butler.datasets[output_name]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[output_name][0]
            # 100 + i input plus the default addend of 3.
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        NOTE(review): ``full_butler`` is currently unused in this body — both
        callers exercise the same code path; confirm whether the distinction
        is still needed.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        overall_input_name = config1.connections.input
        intermediate_name = config1.connections.output
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Wire task2's input to task1's output to form the chain.
        config2.connections.input = intermediate_name
        overall_output_name = "add_output_2"
        config2.connections.output = overall_output_name
        task2 = AddTask2(config=config2)

        pipeline_graph = pipeBase.PipelineGraph()
        task1_node = pipeline_graph.add_task(None, type(task1), task1.config)
        task2_node = pipeline_graph.add_task(None, type(task2), task2.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta1 = self._makeQuanta(task1_node, pipeline_graph)
        quanta2 = self._makeQuanta(task2_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[overall_input_name].dataset_type
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task1_node.get_connections().buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task2_node.get_connections().buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        dsdata = butler.datasets[intermediate_name]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[intermediate_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        dsdata = butler.datasets[overall_output_name]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[overall_output_name][0]
            # Input plus task1's addend (3) plus task2's addend (200).
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), config=task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make one quantum
        (quantum,) = self._makeQuanta(task_node, pipeline_graph, 1)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        # A bare number is interpreted as bytes.
        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        """Test construction, pickling, immutability, and validation of
        ExecutionResources.
        """
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Empty string means "no limit".
        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Frozen object: deep copy returns the same instance.
        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Non-memory units, negative core counts, and unconvertible strings
        # must all be rejected.
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")

397 

398 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests."""

401 

402 

def setup_module(module):
    """Configure pytest.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up (unused; required by the pytest hook
        signature).
    """
    lsst.utils.tests.init()

406 

407 

if __name__ == "__main__":
    # Initialize LSST test support (e.g. memory/file-leak tracking) before
    # running the suite directly.
    lsst.utils.tests.init()
    unittest.main()