Coverage for tests / test_pipelineTask.py: 16%

222 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:49 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28"""Simple unit test for PipelineTask.""" 

29 

30import copy 

31import pickle 

32import unittest 

33from typing import Any 

34 

35import astropy.units as u 

36 

37import lsst.pex.config as pexConfig 

38import lsst.pipe.base as pipeBase 

39import lsst.utils.logging 

40import lsst.utils.tests 

41from lsst.daf.butler import ( 

42 DataCoordinate, 

43 DatasetProvenance, 

44 DatasetRef, 

45 DatasetType, 

46 DimensionUniverse, 

47 Quantum, 

48) 

49 

50 

class ButlerMock:
    """Mock version of butler, only usable for this test."""

    def __init__(self) -> None:
        # Datasets are indexed first by dataset type name, then by data ID.
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the stored dataset for ``ref``, or `None` if absent."""
        stored = self.datasets.get(ref.datasetType.name)
        return stored.get(ref.dataId) if stored else None

    def put(
        self,
        inMemoryDataset: Any,
        dsRef: DatasetRef,
        provenance: DatasetProvenance | None = None,
        producer: Any = None,
    ):
        """Store ``inMemoryDataset`` under the type name and data ID of
        ``dsRef``; ``provenance`` and ``producer`` are accepted for API
        compatibility but ignored by this mock.
        """
        per_type = self.datasets.setdefault(dsRef.datasetType.name, {})
        per_type[dsRef.dataId] = inMemoryDataset

75 

76 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the AddTask."""

    # Single input catalog per quantum; the default dataset type name
    # "add_input" is what the tests look up in the pipeline graph.
    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    # Output mirrors the input dimensions; tests retarget its name via
    # config overrides to chain two tasks together.
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )

92 

93 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Config for the AddTask."""

    # Constant added to the input value by the task's run() method; the
    # tests below rely on the default of 3 when checking outputs.
    addend = pexConfig.Field[int](doc="amount to add", default=3)

98 

99 

class AddTask(pipeBase.PipelineTask):
    """Example task which overrides run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured constant to ``input`` and return it as the
        ``output`` field of a `~lsst.pipe.base.Struct`.
        """
        addend = self.config.addend
        # Record the addend in task metadata so tests can inspect it.
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

110 

111 

class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides runQuantum() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task_2"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        """Fetch the input via the quantum context, add the configured
        constant, and write the result back through the context.
        """
        # Record the addend in task metadata, mirroring AddTask.run().
        self.metadata.add("add", self.config.addend)
        retrieved = butlerQC.get(inputRefs)
        result = retrieved["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=result), outputRefs)

125 

126 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int) -> DatasetRef:
        """Return a DatasetRef of ``dstype`` for the given visit.

        All non-visit dimension values are fixed placeholders; only the
        visit varies between refs built for a test.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            dimensions=dstype.dimensions,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(
        self,
        task_node: pipeBase.pipeline_graph.TaskNode,
        pipeline_graph: pipeBase.PipelineGraph,
        nquanta: int = 100,
    ) -> list[Quantum]:
        """Create set of Quanta, one per visit, covering all input and
        output connections of ``task_node``.
        """
        quanta = []
        for visit in range(nquanta):
            quantum = Quantum(
                inputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.inputs.values()
                },
                outputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.outputs.values()
                },
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            NOTE(review): currently unused in the body; both callers
            exercise the same code path.  Kept for interface stability —
            confirm whether a full/limited distinction still needs separate
            coverage here.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input
        output_name = task.config.connections.output

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta = self._makeQuanta(task_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

                # Restore the default interval before any assertion below
                # can fail, so later tests are not affected.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        dsdata = butler.datasets[output_name]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[output_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        Parameters
        ----------
        full_butler : `bool`
            NOTE(review): currently unused in the body; kept for interface
            stability — confirm whether the distinction is still needed.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        overall_input_name = config1.connections.input
        intermediate_name = config1.connections.output
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Chain task2's input onto task1's output.
        config2.connections.input = intermediate_name
        overall_output_name = "add_output_2"
        config2.connections.output = overall_output_name
        task2 = AddTask2(config=config2)

        pipeline_graph = pipeBase.PipelineGraph()
        task1_node = pipeline_graph.add_task(None, type(task1), task1.config)
        task2_node = pipeline_graph.add_task(None, type(task2), task2.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta1 = self._makeQuanta(task1_node, pipeline_graph)
        quanta2 = self._makeQuanta(task2_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[overall_input_name].dataset_type
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta; task1 must finish before task2 starts
        # because task2 reads task1's outputs.
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task1_node.get_connections().buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task2_node.get_connections().buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        dsdata = butler.datasets[intermediate_name]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[intermediate_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        dsdata = butler.datasets[overall_output_name]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[overall_output_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), config=task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make one quantum
        (quantum,) = self._makeQuanta(task_node, pipeline_graph, 1)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        """Test construction, pickling, and validation of
        ExecutionResources.
        """
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Invalid constructor arguments must raise.  (A duplicated
        # max_mem=1*u.m check was removed here.)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")

410 

411 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests.

    The base class supplies the actual test methods; no additional body
    is needed here.
    """

414 

415 

def setup_module(module):
    """Configure pytest.

    Parameters
    ----------
    module : `module`
        The module being set up; unused, but required by the pytest
        ``setup_module`` hook signature.
    """
    lsst.utils.tests.init()

419 

420 

if __name__ == "__main__":
    # Allow running this test file directly with python.
    lsst.utils.tests.init()
    unittest.main()