Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Unit tests for `lsst.pipe.base.tests`, a library for testing 

23PipelineTask subclasses. 

24""" 

25 

26import shutil 

27import tempfile 

28import unittest 

29 

30import lsst.utils.tests 

31import lsst.pex.config 

32import lsst.daf.butler 

33import lsst.daf.butler.tests as butlerTests 

34 

35from lsst.pipe.base import Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, connectionTypes 

36from lsst.pipe.base.testUtils import runTestQuantum, makeQuantum, assertValidOutput 

37 

38 

class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
    """Connections for `VisitTask`.

    Two single-dataset inputs (``a``, ``b``) and two single-dataset outputs
    (``outA``, ``outB``), all identified by ``instrument`` and ``visit``.
    """

    # Inputs: one dataset each per (instrument, visit).
    a = connectionTypes.Input(
        name="VisitA",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )
    b = connectionTypes.Input(
        name="VisitB",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )

    # Outputs: one dataset each per (instrument, visit).
    outA = connectionTypes.Output(
        name="VisitOutA",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )
    outB = connectionTypes.Output(
        name="VisitOutB",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )

64 

65 

class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
    """Connections for `PatchTask`.

    ``a`` and ``out`` hold one dataset per patch of the quantum's tract;
    ``b`` is a single per-tract prerequisite that is removed from the
    connections when ``config.doUseB`` is `False`.
    """

    a = connectionTypes.Input(
        name="PatchA",
        dimensions={"skymap", "tract", "patch"},
        storageClass="StructuredData",
        multiple=True,
    )
    b = connectionTypes.PrerequisiteInput(
        name="PatchB",
        dimensions={"skymap", "tract"},
        storageClass="StructuredData",
        multiple=False,
    )
    out = connectionTypes.Output(
        name="PatchOut",
        dimensions={"skymap", "tract", "patch"},
        storageClass="StructuredData",
        multiple=True,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Keep the optional prerequisite only when the config asks for it.
        if config.doUseB:
            return
        self.prerequisiteInputs.remove("b")

91 

92 

class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
    """Connections for `SkyPixTask`, declared with the generic ``skypix``
    dimension (the tests store the datasets under ``htm7``).
    """

    a = connectionTypes.Input(
        dimensions={"skypix"},
        name="PixA",
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        dimensions={"skypix"},
        name="PixOut",
        storageClass="StructuredData",
    )

104 

105 

class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
    """Configuration for `VisitTask`; defines no fields of its own."""

108 

109 

class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
    """Configuration for `PatchTask`."""

    # Previously documented with an empty string, which left config help
    # output blank for this switch.
    doUseB = lsst.pex.config.Field(
        dtype=bool,
        default=True,
        doc="Read the per-tract prerequisite input 'b' and append its data to each "
            "patch dataset of 'a'? If False, 'b' is removed from the connections and "
            "'a' is passed through unchanged.",
    )

112 

113 

class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
    """Configuration for `SkyPixTask`; defines no fields of its own."""

116 

117 

class VisitTask(PipelineTask):
    """Toy per-visit task that combines two `MetricsExample` inputs."""

    ConfigClass = VisitConfig
    _DefaultName = "visit"

    def run(self, a, b):
        """Combine the ``data`` of ``a`` and ``b``.

        Parameters
        ----------
        a, b : `lsst.daf.butler.tests.MetricsExample`
            Inputs whose ``data`` attributes are combined.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``outA`` wraps ``a.data + b.data``; ``outB`` wraps
            ``a.data * max(b.data)`` (list repetition when ``data`` is a list).
        """
        summed = butlerTests.MetricsExample(data=(a.data + b.data))
        scaled = butlerTests.MetricsExample(data=(a.data * max(b.data)))
        return Struct(outA=summed, outB=scaled)

126 

127 

class PatchTask(PipelineTask):
    """Toy per-tract task operating on a list of per-patch inputs."""

    ConfigClass = PatchConfig
    _DefaultName = "patch"

    def run(self, a, b=None):
        """Optionally append ``b.data`` to each element of ``a``.

        Parameters
        ----------
        a : `list` [`lsst.daf.butler.tests.MetricsExample`]
            One input per patch.
        b : `lsst.daf.butler.tests.MetricsExample`, optional
            Per-tract input; only read when ``config.doUseB`` is `True`.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``out`` is ``a`` itself when ``doUseB`` is `False`, otherwise a
            new list of `MetricsExample` built from ``patchIn.data + b.data``.
        """
        if not self.config.doUseB:
            return Struct(out=a)
        combined = [butlerTests.MetricsExample(data=(patchIn.data + b.data)) for patchIn in a]
        return Struct(out=combined)

138 

139 

class SkyPixTask(PipelineTask):
    """Toy task that copies its skypix-indexed input to its output."""

    ConfigClass = SkyPixConfig
    _DefaultName = "skypix"

    def run(self, a):
        """Return ``a`` unchanged as the ``out`` field of a `Struct`."""
        result = Struct(out=a)
        return result

146 

147 

class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
    """Tests for `makeQuantum`, `runTestQuantum`, and `assertValidOutput`.

    One butler repository is shared by every test in this class; ``setUp``
    gives each test its own collection created with
    `lsst.daf.butler.tests.makeTestCollection`.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Repository should be re-created for each test case, but
        # this has a prohibitive run-time cost at present
        cls.root = tempfile.mkdtemp()
        try:
            cls._makeRepo()
        except Exception:
            # unittest does not call tearDownClass when setUpClass fails, so
            # remove the temporary directory here rather than leaking it.
            shutil.rmtree(cls.root, ignore_errors=True)
            raise

    @classmethod
    def _makeRepo(cls):
        """Create ``cls.repo`` under ``cls.root`` and register every dataset
        type used by the test tasks.
        """
        dataIds = {
            "instrument": ["notACam"],
            "physical_filter": ["k2020"],  # needed for expandUniqueId(visit)
            "visit": [101, 102],
            "skymap": ["sky"],
            "tract": [42],
            "patch": [0, 1],
        }
        cls.repo = butlerTests.makeTestRepo(cls.root, dataIds)
        butlerTests.registerMetricsExample(cls.repo)

        for typeName in {"VisitA", "VisitB", "VisitOutA", "VisitOutB"}:
            butlerTests.addDatasetType(cls.repo, typeName, {"instrument", "visit"}, "StructuredData")
        for typeName in {"PatchA", "PatchOut"}:
            butlerTests.addDatasetType(cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData")
        butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
        for typeName in {"PixA", "PixOut"}:
            # The connections declare "skypix"; the repository uses htm7.
            butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.root, ignore_errors=True)
        super().tearDownClass()

    def setUp(self):
        super().setUp()
        self.butler = butlerTests.makeTestCollection(self.repo)

    def _makeVisitTestData(self, dataId):
        """Create dummy datasets suitable for VisitTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create.

        Returns
        -------
        datasets : `dict` [`str`, `list`]
            A dictionary keyed by dataset type. Its values are the list of
            integers used to create each dataset. The datasets stored in the
            butler are `lsst.daf.butler.tests.MetricsExample` objects with
            these lists as their ``data`` argument, but the lists are easier
            to manipulate in test code.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataId)
        return {"VisitA": inA, "VisitB": inB, }

    def _makePatchTestData(self, dataId):
        """Create dummy datasets suitable for PatchTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create. Any patch ID is
            overridden to create multiple datasets.

        Returns
        -------
        datasets : `dict` [`str`, `list` [`tuple` [data ID, `list`]]]
            A dictionary keyed by dataset type. Its values are the data ID
            of each dataset and the list of integers used to create each. The
            datasets stored in the butler are
            `lsst.daf.butler.tests.MetricsExample` objects with these lists as
            their ``data`` argument, but the lists are easier to manipulate
            in test code. ``PatchA`` entries are in ascending patch order.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        datasets = {"PatchA": [], "PatchB": []}
        # Iterate a tuple, not a set, so the returned order is guaranteed;
        # several tests rely on it matching the order of task inputs.
        for patch in (0, 1):
            self.butler.put(butlerTests.MetricsExample(data=(inA + [patch])), "PatchA", dataId, patch=patch)
            datasets["PatchA"].append((dict(dataId, patch=patch), inA + [patch]))
        self.butler.put(butlerTests.MetricsExample(data=inB), "PatchB", dataId)
        datasets["PatchB"].append((dataId, inB))
        return datasets

    def testMakeQuantumNoSuchDatatype(self):
        config = VisitConfig()
        config.connections.a = "Visit"
        task = VisitTask(config=config)

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

    def testMakeQuantumInvalidDimension(self):
        config = VisitConfig()
        config.connections.a = "PatchA"
        task = VisitTask(config=config)
        dataIdV = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        dataIdP = butlerTests.expandUniqueId(self.butler, {"patch": 0})

        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataIdV)
        self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdV, {
                "a": dataIdP,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })

    def testMakeQuantumMissingMultiple(self):
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": dict(dataId, patch=0),
                "b": dataId,
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumExtraMultiple(self):
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": [dict(dataId, patch=patch) for patch in (0, 1)],
                "b": [dataId],
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumMissingDataId(self):
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

    def testMakeQuantumCorruptedDataId(self):
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            # fourth argument should be a mapping keyed by connection name
            makeQuantum(task, self.butler, dataId, dataId)

    def testRunTestQuantumVisitWithRun(self):
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # The real task.run executed, so its outputs must be in the butler.
        self.assertTrue(self.butler.datasetExists("VisitOutA", dataId))
        self.assertEqual(self.butler.get("VisitOutA", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] + data["VisitB"])))
        self.assertTrue(self.butler.datasetExists("VisitOutB", dataId))
        self.assertEqual(self.butler.get("VisitOutB", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] * max(data["VisitB"]))))

    def testRunTestQuantumPatchWithRun(self):
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # The real task.run executed, so each per-patch output must exist.
        inB = data["PatchB"][0][1]
        for dataset in data["PatchA"]:
            patchId = dataset[0]
            self.assertTrue(self.butler.datasetExists("PatchOut", patchId))
            self.assertEqual(self.butler.get("PatchOut", patchId),
                             butlerTests.MetricsExample(data=(dataset[1] + inB)))

    def testRunTestQuantumVisitMockRun(self):
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # The mock records the inputs that runQuantum passed to task.run.
        run.assert_called_once_with(a=butlerTests.MetricsExample(data=data["VisitA"]),
                                    b=butlerTests.MetricsExample(data=data["VisitB"]))

    def testRunTestQuantumPatchMockRun(self):
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # The mock records the inputs that runQuantum passed to task.run.
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]],
            b=butlerTests.MetricsExample(data=data["PatchB"][0][1])
        )

    def testRunTestQuantumPatchOptionalInput(self):
        config = PatchConfig()
        config.doUseB = False
        task = PatchTask(config=config)

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # With doUseB=False the prerequisite "b" must not be passed to run.
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]]
        )

    def testAssertValidOutputPass(self):
        task = VisitTask()

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        # should not throw
        assertValidOutput(task, result)

    def testAssertValidOutputMissing(self):
        task = VisitTask()

        def run(a, b):
            return Struct(outA=a)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputSingle(self):
        task = PatchTask()

        def run(a, b):
            return Struct(out=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run([inA], inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputMultiple(self):
        task = VisitTask()

        def run(a, b):
            return Struct(outA=[a], outB=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testSkypixHandling(self):
        task = SkyPixTask()

        dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
        data = butlerTests.MetricsExample(data=[1, 2, 3])
        self.butler.put(data, "PixA", dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # PixA dataset should have been retrieved by runTestQuantum
        run.assert_called_once_with(a=data)

476 

477 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Module-level check inherited unchanged from
    `lsst.utils.tests.MemoryTestCase` (presumably resource-leak detection;
    see that class for details).
    """

480 

481 

def setup_module(module):
    """Initialize `lsst.utils.tests` before this module's tests run.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    testSupport = lsst.utils.tests
    testSupport.init()

484 

485 

# Allow running this file directly as a script, outside pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()