Coverage for tests/test_pipelineTask.py: 15%

221 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-18 10:50 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28"""Simple unit test for PipelineTask. 

29""" 

30 

31import copy 

32import pickle 

33import unittest 

34from typing import Any 

35 

36import astropy.units as u 

37import lsst.pex.config as pexConfig 

38import lsst.pipe.base as pipeBase 

39import lsst.utils.logging 

40import lsst.utils.tests 

41from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

42 

43 

class ButlerMock:
    """In-memory stand-in for a butler, only usable for this test.

    Datasets are kept in ``self.datasets``, a two-level mapping keyed first by
    dataset type name and then by data ID.
    """

    def __init__(self) -> None:
        # Outer key: dataset type name; inner key: the ref's data ID.
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``.

        Parameters
        ----------
        ref : `DatasetRef`
            Reference whose dataset type name and data ID select the object.

        Returns
        -------
        dataset : `~typing.Any`
            The stored object, or `None` when nothing has been stored under
            that dataset type name and data ID.
        """
        # An unknown (or empty) dataset type falls through to None.
        return self.datasets.get(ref.datasetType.name, {}).get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the name and data ID of ``dsRef``.

        Parameters
        ----------
        inMemoryDataset : `~typing.Any`
            Object to store; any existing entry for the same key is replaced.
        dsRef : `DatasetRef`
            Reference supplying the dataset type name and data ID.
        producer : `~typing.Any`, optional
            Accepted but not used by this mock.
        """
        per_type = self.datasets.setdefault(dsRef.datasetType.name, {})
        per_type[dsRef.dataId] = inMemoryDataset

62 

63 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the AddTask: one input and one output dataset.

    Both connections use the ``Catalog`` storage class and share the same
    dataset dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

79 

80 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Config for the AddTask, bound to `AddConnections`."""

    # Integer added to the input by the tasks in this file; defaults to 3.
    addend = pexConfig.Field[int](default=3, doc="amount to add")

85 

86 

class AddTask(pipeBase.PipelineTask):
    """Example task which overrides run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input``.

        Parameters
        ----------
        input : `int`
            Value to add to.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``output`` attribute is ``input + config.addend``.
        """
        addend = self.config.addend
        # Record the addend in the task metadata under the "add" key.
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

97 

98 

class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides runQuantum() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        """Read the input, add the configured addend, and write the output.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
            Context used to read the inputs and write the output.
        inputRefs
            References handed to ``butlerQC.get``; the result is indexed by
            the ``"input"`` key.
        outputRefs
            References handed to ``butlerQC.put`` along with the result.

        Notes
        -----
        NOTE(review): both ref arguments are annotated as `DatasetRef`, but
        the tests in this file pass the objects returned by
        ``connections.buildDatasetRefs(quantum)`` - confirm the intended
        annotation.
        """
        addend = self.config.addend
        # Record the addend in the task metadata under the "add" key.
        self.metadata.add("add", addend)
        loaded = butlerQC.get(inputRefs)
        total = loaded["input"] + addend
        butlerQC.put(pipeBase.Struct(output=total), outputRefs)

112 

113 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask.

    The tests run `AddTask` and `AddTask2` over quanta built by
    `_makeQuanta`, using `ButlerMock` as the data store, and also check
    ``QuantumContext.get`` and ``ExecutionResources`` directly.
    """

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Build a dataset ref for ``dstype`` whose data ID varies only in visit.

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the returned ref.
        visitId : `int`
            Value used for the ``visit`` dimension.
        universe : `DimensionUniverse`
            Universe used to standardize the data ID.

        Returns
        -------
        ref : `DatasetRef`
            Ref in run ``"test"``; all other data ID values are fixed.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a list of quanta, one per visit in ``range(nquanta)``.

        Parameters
        ----------
        config : `lsst.pipe.base.PipelineTaskConfig`
            Config whose connections define the input and output dataset
            types.
        nquanta : `int`, optional
            Number of quanta to create.

        Returns
        -------
        quanta : `list` [`Quantum`]
            Each quantum has exactly one input ref and one output ref.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        inputType = connections.input.makeDatasetType(universe)
        outputType = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(inputType, visit, universe)
            outputRef = self._makeDSRefVisit(outputType, visit, universe)
            quanta.append(
                Quantum(
                    inputs={inputRef.datasetType: [inputRef]},
                    outputs={outputRef.datasetType: [outputRef]},
                )
            )
        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _checkQuantumContextGet(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: Any, outputRefs: Any
    ) -> None:
        """Check the different ways of getting datasets from ``butlerQC``.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
            Context for a quantum whose input has already been stored.
        inputRefs, outputRefs
            The pair returned by ``connections.buildDatasetRefs`` for the
            same quantum.
        """
        periodic = lsst.utils.logging.PeriodicLogger
        saved_interval = periodic.LOGGING_INTERVAL
        # Force the periodic logger to issue messages.  The attribute is
        # class-level (shared by every test in the process), so it is
        # restored in ``finally`` even if a get or an assertion fails.
        periodic.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            periodic.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            Not used in the body; both callers exercise the same
            `ButlerMock`-based path.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                self._checkQuantumContextGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            # Input was 100 + i and the default addend is 3.
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The output of an `AddTask` (addend 3) is wired to the input of an
        `AddTask2` (addend 200).

        Parameters
        ----------
        full_butler : `bool`
            Not used in the body; both callers exercise the same
            `ButlerMock`-based path.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Chain the tasks: task2 reads what task1 writes.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta; all of task1 must finish before task2
        # because task2 reads task1's outputs.
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        # A bare number for max_mem compares equal to that many bytes.
        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        """Test construction, pickling, copying and validation of
        ExecutionResources.
        """
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # An empty string means no memory limit.
        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Deep copy is expected to hand back the very same object.
        self.assertIs(res, copy.deepcopy(res))

        # Attributes cannot be reassigned after construction.
        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Invalid arguments are rejected at construction time.
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")

374 

375 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests.

    All behaviour is inherited from ``lsst.utils.tests.MemoryTestCase``;
    nothing is overridden here.
    """

378 

379 

def setup_module(module):
    """Configure pytest.

    Parameters
    ----------
    module
        Passed in by the caller; not used here.
    """
    lsst.utils.tests.init()

383 

384 

if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then hand over
    # to the unittest command-line runner.
    lsst.utils.tests.init()
    unittest.main()