Coverage for tests/test_pipelineTask.py: 15%

221 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-23 10:31 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import copy 

26import pickle 

27import unittest 

28from typing import Any 

29 

30import astropy.units as u 

31import lsst.pex.config as pexConfig 

32import lsst.pipe.base as pipeBase 

33import lsst.utils.logging 

34import lsst.utils.tests 

35from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

36 

37 

class ButlerMock:
    """In-memory stand-in for a butler, sufficient only for this test.

    Datasets live in a two-level mapping: dataset type name first, then
    data ID.
    """

    def __init__(self) -> None:
        self.dimensions = DimensionUniverse()
        # {dataset type name: {data ID: stored object}}
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if there is none."""
        per_type = self.datasets.get(ref.datasetType.name)
        return per_type.get(ref.dataId) if per_type else None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the type name and data ID of
        ``dsRef``; ``producer`` is accepted but not used.
        """
        data_id = dsRef.dataId
        self.datasets.setdefault(dsRef.datasetType.name, {})[data_id] = inMemoryDataset

56 

57 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connection definitions used by the AddTask classes: one input
    dataset and one output dataset, both with the same dimensions and
    storage class.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

73 

74 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the AddTask classes."""

    # Integer added to the input value; defaults to 3.
    addend = pexConfig.Field[int](default=3, doc="amount to add")

79 

80 

class AddTask(pipeBase.PipelineTask):
    """Example task that implements only ``run()``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Record the configured addend in the task metadata and return
        ``input`` plus that addend as the ``output`` field of a Struct.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

91 

92 

class AddTask2(pipeBase.PipelineTask):
    """Example task that implements ``runQuantum()`` directly."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        """Fetch the inputs through ``butlerQC``, add the configured addend
        to the ``"input"`` entry and store the result as ``output``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        fetched = butlerQC.get(inputRefs)
        result = pipeBase.Struct(output=fetched["input"] + addend)
        butlerQC.put(result, outputRefs)

106 

107 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask.

    Quanta are built by hand, executed through `pipeBase.QuantumContext`
    against the in-memory ``ButlerMock``, and the stored results are
    compared with the expected sums.
    """

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Build a dataset ref in run ``"test"`` for one visit.

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the returned ref.
        visitId : `int`
            Value of the ``visit`` dimension; every other dimension value
            is fixed.
        universe : `DimensionUniverse`
            Universe passed to `DataCoordinate.standardize`.

        Returns
        -------
        ref : `DatasetRef`
            Ref combining ``dstype`` with the standardized data ID.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a list of quanta, one per visit in ``range(nquanta)``.

        Each quantum has exactly one input ref and one output ref, whose
        dataset types come from the ``input`` and ``output`` connections
        of ``config``.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        NOTE(review): ``full_butler`` is never read in this method, so the
        "Full" and "Limited" callers currently execute identical code.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True

                # Force the periodic logger to issue messages. The interval
                # is a class attribute, so the previous value is restored
                # in ``finally``: a failing assertion must not leak the
                # zero interval into tests that run afterwards.
                periodic = lsst.utils.logging.PeriodicLogger
                saved_interval = periodic.LOGGING_INTERVAL
                periodic.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    periodic.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            # 100 + i was stored as input; 3 is the default addend.
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The second task reads the dataset type written by the first and
        adds 200 on top of the first task's default addend of 3.

        NOTE(review): ``full_butler`` is never read in this method, so the
        "Full" and "Limited" callers currently execute identical code.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Chain the tasks: task2 consumes what task1 produces.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self) -> None:
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self) -> None:
        """Test ExecutionResources construction, pickling, copying,
        attribute assignment and rejection of invalid arguments.
        """
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # deepcopy is expected to hand back the very same object.
        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Invalid arguments; each distinct case is checked once.
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")

368 

369 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """File-leak checks; every test method is inherited from the base
    class.
    """

372 

373 

def setup_module(module):
    """Pytest module-level setup: initialise the LSST test utilities.

    ``module`` is accepted but not used.
    """
    lsst.utils.tests.init()

377 

378 

if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then let
    # unittest discover and run the tests in this module.
    lsst.utils.tests.init()
    unittest.main()