Coverage for tests/test_pipelineTask.py: 16%

222 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-25 09:14 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTask. 

23""" 

24 

25import copy 

26import pickle 

27import unittest 

28from typing import Any 

29 

30import astropy.units as u 

31import lsst.pex.config as pexConfig 

32import lsst.pipe.base as pipeBase 

33import lsst.utils.logging 

34import lsst.utils.tests 

35from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

36 

37 

class ButlerMock:
    """Mock version of butler, only usable for this test.

    Datasets are held in memory in a two-level mapping keyed first by
    dataset type name and then by data ID.
    """

    def __init__(self) -> None:
        # dataset type name -> {data ID -> stored object}
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``.

        Parameters
        ----------
        ref : `DatasetRef`
            Reference whose ``datasetType.name`` and ``dataId`` select the
            stored object.

        Returns
        -------
        obj : `~typing.Any`
            The stored object, or `None` when nothing was stored for this
            dataset type or data ID.
        """
        dsdata = self.datasets.get(ref.datasetType.name)
        if not dsdata:
            return None
        return dsdata.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None) -> None:
        """Store ``inMemoryDataset`` under the given reference.

        Parameters
        ----------
        inMemoryDataset : `~typing.Any`
            Object to store.
        dsRef : `DatasetRef`
            Reference whose ``datasetType.name`` and ``dataId`` form the key.
        producer : `~typing.Any`, optional
            Accepted but not used by this mock.
        """
        key = dsRef.dataId
        name = dsRef.datasetType.name
        self.datasets.setdefault(name, {})[key] = inMemoryDataset

56 

57 

class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections used by the example add tasks.

    Declares a single ``Catalog`` input and a single ``Catalog`` output,
    both with the same dimensions.
    """

    # Dataset read by the task.
    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

    # Dataset written by the task.
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )

71 

72 

class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the example add tasks."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")

75 

76 

77# example task which overrides run() method 

class AddTask(pipeBase.PipelineTask):
    """Example task that overrides ``run()`` to add ``config.addend``."""

    _DefaultName = "add_task"
    ConfigClass = AddConfig

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input``.

        Parameters
        ----------
        input : `int`
            Value to add to.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``output`` attribute holds the sum.
        """
        addend = self.config.addend
        # Record the addend in the task metadata under the "add" key.
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)

86 

87 

# example task which overrides runQuantum() method
class AddTask2(pipeBase.PipelineTask):
    """Example task that overrides ``runQuantum()`` instead of ``run()``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.QuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the input, add the configured addend and store the result.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
            Context used to read the inputs and write the output.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references, as returned by ``buildDatasetRefs``.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references, as returned by ``buildDatasetRefs``.
        """
        self.metadata.add("add", self.config.addend)
        # get() on the input references returns a mapping keyed by
        # connection name; "input" is the only connection read here.
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)

100 

101 

class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Build a dataset reference of type ``dstype`` for one visit.

        All data ID values other than the visit are fixed test values and
        the reference is placed in the run ``"test"``.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create set of Quanta.

        One quantum is made per visit in ``range(nquanta)``, each with a
        single input reference and a single output reference.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def _checkQuantumContextGet(
        self,
        butlerQC: pipeBase.QuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Exercise the different ways of calling ``butlerQC.get()``."""
        periodic = lsst.utils.logging.PeriodicLogger
        saved_interval = periodic.LOGGING_INTERVAL
        # Force the periodic logger to issue messages.
        periodic.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            # Always put the class-level interval back, even when a check
            # above fails, so the override cannot leak into other tests.
            periodic.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        NOTE(review): ``full_butler`` is not read anywhere in this method,
        so the "full" and "limited" variants currently run identical code.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for index, quantum in enumerate(quanta):
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if index == 0:
                self._checkQuantumContextGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + task.config.addend)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        NOTE(review): ``full_butler`` is not read anywhere in this method,
        so the "full" and "limited" variants currently run identical code.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Feed the output of the first task into the second one.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        addend1 = task1.config.addend
        addend2 = task2.config.addend

        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + addend1)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + addend1 + addend2)

    def testButlerQC(self) -> None:
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self) -> None:
        """Test construction, pickling and validation of
        ExecutionResources.
        """
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # A deep copy is expected to be the very same object.
        self.assertIs(res, copy.deepcopy(res))

        # Attributes cannot be reassigned after construction.
        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Invalid constructor arguments.
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")

364 

365 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory test case; everything is inherited from
    `lsst.utils.tests.MemoryTestCase`.
    """

368 

369 

def setup_module(module):
    """Initialize the LSST test utilities; ``module`` itself is not used."""
    lsst.utils.tests.init()

372 

373 

# Script entry point: initialize the LSST test utilities, then hand over to
# the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()