Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Unit tests for `lsst.pipe.base.testUtils`, a library for testing
PipelineTask subclasses.
"""

25 

26import shutil 

27import tempfile 

28import unittest 

29 

30import lsst.utils.tests 

31import lsst.pex.config 

32import lsst.daf.butler 

33import lsst.daf.butler.tests as butlerTests 

34 

35from lsst.pipe.base import Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, connectionTypes 

36from lsst.pipe.base.testUtils import runTestQuantum, makeQuantum, assertValidOutput, \ 

37 assertValidInitOutput, getInitInputs 

38 

39 

class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
    """Connections for `VisitTask`.

    Declares two single-valued per-visit inputs (``a``, ``b``), two
    single-valued per-visit outputs (``outA``, ``outB``), an optional
    init-input (``initIn``) and a multi-valued init-output (``initOut``).
    """

    initIn = connectionTypes.InitInput(
        storageClass="StructuredData",
        name="VisitInitIn",
        multiple=False,
    )
    a = connectionTypes.Input(
        dimensions={"instrument", "visit"},
        name="VisitA",
        storageClass="StructuredData",
        multiple=False,
    )
    b = connectionTypes.Input(
        dimensions={"instrument", "visit"},
        name="VisitB",
        storageClass="StructuredData",
        multiple=False,
    )
    initOut = connectionTypes.InitOutput(
        storageClass="StructuredData",
        name="VisitInitOut",
        multiple=True,
    )
    outA = connectionTypes.Output(
        dimensions={"instrument", "visit"},
        name="VisitOutA",
        storageClass="StructuredData",
        multiple=False,
    )
    outB = connectionTypes.Output(
        dimensions={"instrument", "visit"},
        name="VisitOutB",
        storageClass="StructuredData",
        multiple=False,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # ``initIn`` is declared unconditionally above; keep it only when the
        # config opts in via ``doUseInitIn``.
        if config.doUseInitIn:
            return
        self.initInputs.remove("initIn")

81 

82 

class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
    """Connections for `PatchTask`.

    Declares a multi-valued per-patch input (``a``), an optional per-tract
    prerequisite input (``b``), two init-outputs (``initOutA``,
    ``initOutB``) and a multi-valued per-patch output (``out``).
    """

    a = connectionTypes.Input(
        dimensions={"skymap", "tract", "patch"},
        name="PatchA",
        storageClass="StructuredData",
        multiple=True,
    )
    b = connectionTypes.PrerequisiteInput(
        dimensions={"skymap", "tract"},
        name="PatchB",
        storageClass="StructuredData",
        multiple=False,
    )
    initOutA = connectionTypes.InitOutput(
        storageClass="StructuredData",
        name="PatchInitOutA",
        multiple=False,
    )
    initOutB = connectionTypes.InitOutput(
        storageClass="StructuredData",
        name="PatchInitOutB",
        multiple=False,
    )
    out = connectionTypes.Output(
        dimensions={"skymap", "tract", "patch"},
        name="PatchOut",
        storageClass="StructuredData",
        multiple=True,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # The prerequisite ``b`` is only kept when ``doUseB`` is enabled.
        if config.doUseB:
            return
        self.prerequisiteInputs.remove("b")

118 

119 

class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
    """Connections for `SkyPixTask`: one input and one output, both
    declared on the generic ``skypix`` dimension.
    """

    a = connectionTypes.Input(
        dimensions={"skypix"},
        name="PixA",
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        dimensions={"skypix"},
        name="PixOut",
        storageClass="StructuredData",
    )

131 

132 

class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
    """Configuration for `VisitTask`."""

    # When False, VisitConnections.__init__ removes the ``initIn`` connection.
    doUseInitIn = lsst.pex.config.Field(
        default=False,
        dtype=bool,
        doc="Whether to declare the ``initIn`` init-input connection.",
    )

135 

136 

class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
    """Configuration for `PatchTask`."""

    # When False, PatchConnections.__init__ removes the prerequisite ``b`` and
    # PatchTask.run passes its ``a`` inputs through unchanged.
    doUseB = lsst.pex.config.Field(
        default=True,
        dtype=bool,
        doc="Whether to declare the prerequisite input ``b`` and combine it with each ``a``.",
    )

139 

140 

class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
    """Configuration for `SkyPixTask`; it adds no fields of its own."""

143 

144 

class VisitTask(PipelineTask):
    """Toy per-visit task that combines two structured inputs.

    The constructor builds the two objects making up the ``initOut``
    init-output, which `VisitConnections` declares with ``multiple=True``.
    """

    ConfigClass = VisitConfig
    _DefaultName = "visit"

    def __init__(self, initInputs=None, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.initOut = [butlerTests.MetricsExample(data=values) for values in ([1, 2], [3, 4])]

    def run(self, a, b):
        """Combine ``a`` and ``b`` into the two visit outputs.

        ``outA`` wraps ``a.data + b.data``; ``outB`` wraps
        ``a.data * max(b.data)`` (for list data, ``a.data`` repeated
        ``max(b.data)`` times).
        """
        summed = a.data + b.data
        scaled = a.data * max(b.data)
        return Struct(
            outA=butlerTests.MetricsExample(data=summed),
            outB=butlerTests.MetricsExample(data=scaled),
        )

160 

161 

class PatchTask(PipelineTask):
    """Toy per-tract task that processes a list of per-patch inputs.

    The constructor sets the two single-valued init-outputs declared by
    `PatchConnections`.
    """

    ConfigClass = PatchConfig
    _DefaultName = "patch"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initOutA = butlerTests.MetricsExample(data=[1, 2, 4])
        self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

    def run(self, a, b=None):
        """Produce one output per element of ``a``.

        With ``doUseB`` disabled the ``a`` inputs are returned as-is;
        otherwise each output wraps ``element.data + b.data``.
        """
        if not self.config.doUseB:
            return Struct(out=a)
        combined = [butlerTests.MetricsExample(data=(patchInput.data + b.data)) for patchInput in a]
        return Struct(out=combined)

177 

178 

class SkyPixTask(PipelineTask):
    """Toy sky-pixel task that forwards its single input to its output."""

    _DefaultName = "skypix"
    ConfigClass = SkyPixConfig

    def run(self, a):
        """Return ``a`` unchanged as the ``out`` output."""
        result = Struct(out=a)
        return result

185 

186 

class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
    """Tests for `makeQuantum`, `runTestQuantum`, `assertValidOutput`,
    `assertValidInitOutput` and `getInitInputs`.

    A single butler repository is shared by every test in the class; each
    test writes into its own collection created in `setUp`.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Repository should be re-created for each test case, but
        # this has a prohibitive run-time cost at present
        cls.root = tempfile.mkdtemp()
        try:
            cls._makeRepo()
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the temporary directory here instead of leaking it.
            shutil.rmtree(cls.root, ignore_errors=True)
            raise

    @classmethod
    def _makeRepo(cls):
        """Create the shared repository at ``cls.root`` and register the
        data ID values and dataset types used by the tasks in this module.
        """
        cls.repo = butlerTests.makeTestRepo(cls.root)
        for dimension, value in (
            ("instrument", "notACam"),
            ("visit", 101),
            ("visit", 102),
            ("skymap", "sky"),
            ("tract", 42),
            ("patch", 0),
            ("patch", 1),
        ):
            butlerTests.addDataIdValue(cls.repo, dimension, value)
        butlerTests.registerMetricsExample(cls.repo)

        for typeName in {"VisitA", "VisitB", "VisitOutA", "VisitOutB"}:
            butlerTests.addDatasetType(cls.repo, typeName, {"instrument", "visit"}, "StructuredData")
        for typeName in {"PatchA", "PatchOut"}:
            butlerTests.addDatasetType(cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData")
        butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
        for typeName in {"PixA", "PixOut"}:
            butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")
        butlerTests.addDatasetType(cls.repo, "VisitInitIn", set(), "StructuredData")

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.root, ignore_errors=True)
        super().tearDownClass()

    def setUp(self):
        super().setUp()
        self.butler = butlerTests.makeTestCollection(self.repo)

    def _makeVisitTestData(self, dataId):
        """Create dummy datasets suitable for VisitTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create.

        Returns
        -------
        datasets : `dict` [`str`, `list`]
            A dictionary keyed by dataset type. Its values are the list of
            integers used to create each dataset. The datasets stored in the
            butler are `lsst.daf.butler.tests.MetricsExample` objects with
            these lists as their ``data`` argument, but the lists are easier
            to manipulate in test code.
        """
        inInit = [4, 2]
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataId)
        # VisitInitIn has no dimensions, so its data ID is empty.
        self.butler.put(butlerTests.MetricsExample(data=inInit), "VisitInitIn", set())
        return {"VisitA": inA, "VisitB": inB, "VisitInitIn": inInit}

    def _makePatchTestData(self, dataId):
        """Create dummy datasets suitable for PatchTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create. Any patch ID is
            overridden to create multiple datasets.

        Returns
        -------
        datasets : `dict` [`str`, `list` [`tuple` [data ID, `list`]]]
            A dictionary keyed by dataset type. Its values are the data ID
            of each dataset and the list of integers used to create each. The
            datasets stored in the butler are
            `lsst.daf.butler.tests.MetricsExample` objects with these lists as
            their ``data`` argument, but the lists are easier to manipulate
            in test code. ``PatchA`` entries are in ascending patch order.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        datasets = {"PatchA": [], "PatchB": []}
        # Iterate over an ordered tuple (not a set): callers rely on the order
        # of datasets["PatchA"] matching the order of the inputs they build.
        for patch in (0, 1):
            self.butler.put(butlerTests.MetricsExample(data=(inA + [patch])), "PatchA", dataId, patch=patch)
            datasets["PatchA"].append((dict(dataId, patch=patch), inA + [patch]))
        self.butler.put(butlerTests.MetricsExample(data=inB), "PatchB", dataId)
        datasets["PatchB"].append((dataId, inB))
        return datasets

    def testMakeQuantumNoSuchDatatype(self):
        """makeQuantum rejects a connection whose dataset type is not registered."""
        config = VisitConfig()
        config.connections.a = "Visit"
        task = VisitTask(config=config)

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

    def testMakeQuantumInvalidDimension(self):
        """makeQuantum rejects a connection redirected to a dataset type with other dimensions."""
        config = VisitConfig()
        config.connections.a = "PatchA"
        task = VisitTask(config=config)
        dataIdV = {"instrument": "notACam", "visit": 102}
        dataIdP = {"skymap": "sky", "tract": 42, "patch": 0}

        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataIdV)
        self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdV, {
                "a": dataIdP,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })

    def testMakeQuantumMissingMultiple(self):
        """makeQuantum rejects a single data ID for a ``multiple=True`` connection."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": dict(dataId, patch=0),
                "b": dataId,
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumExtraMultiple(self):
        """makeQuantum rejects a list of data IDs for a ``multiple=False`` connection."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": [dict(dataId, patch=patch) for patch in (0, 1)],
                "b": [dataId],
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumMissingDataId(self):
        """makeQuantum raises when an input or output connection has no data ID."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

    def testMakeQuantumCorruptedDataId(self):
        """makeQuantum raises when given a bare data ID instead of a mapping."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            # fourth argument should be a mapping keyed by component name
            makeQuantum(task, self.butler, dataId, dataId)

    def testRunTestQuantumVisitWithRun(self):
        """runTestQuantum with ``mockRun=False`` writes VisitTask outputs to the butler."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        self.assertTrue(self.butler.datasetExists("VisitOutA", dataId))
        self.assertEqual(self.butler.get("VisitOutA", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] + data["VisitB"])))
        self.assertTrue(self.butler.datasetExists("VisitOutB", dataId))
        self.assertEqual(self.butler.get("VisitOutB", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] * max(data["VisitB"]))))

    def testRunTestQuantumPatchWithRun(self):
        """runTestQuantum with ``mockRun=False`` writes one PatchOut per patch."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        inB = data["PatchB"][0][1]
        for dataset in data["PatchA"]:
            patchId = dataset[0]
            self.assertTrue(self.butler.datasetExists("PatchOut", patchId))
            self.assertEqual(self.butler.get("PatchOut", patchId),
                             butlerTests.MetricsExample(data=(dataset[1] + inB)))

    def testRunTestQuantumVisitMockRun(self):
        """runTestQuantum with ``mockRun=True`` passes the stored inputs to ``run``."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(a=butlerTests.MetricsExample(data=data["VisitA"]),
                                    b=butlerTests.MetricsExample(data=data["VisitB"]))

    def testRunTestQuantumPatchMockRun(self):
        """runTestQuantum with ``mockRun=True`` passes a list for a multiple input."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]],
            b=butlerTests.MetricsExample(data=data["PatchB"][0][1])
        )

    def testRunTestQuantumPatchOptionalInput(self):
        """A prerequisite removed by the config is not passed to ``run``."""
        config = PatchConfig()
        config.doUseB = False
        task = PatchTask(config=config)

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]]
        )

    def testAssertValidOutputPass(self):
        """assertValidOutput accepts the real VisitTask output."""
        task = VisitTask()

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        # should not throw
        assertValidOutput(task, result)

    def testAssertValidOutputMissing(self):
        """assertValidOutput fails when an output connection is absent."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=a)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputSingle(self):
        """assertValidOutput fails when a ``multiple=True`` output is a single object."""
        task = PatchTask()

        def run(a, b):
            return Struct(out=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run([inA], inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputMultiple(self):
        """assertValidOutput fails when a ``multiple=False`` output is a list."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=[a], outB=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidInitOutputPass(self):
        """assertValidInitOutput accepts the real VisitTask and PatchTask."""
        task = VisitTask()
        # should not throw
        assertValidInitOutput(task)

        task = PatchTask()
        # should not throw
        assertValidInitOutput(task)

    def testAssertValidInitOutputMissing(self):
        """assertValidInitOutput fails when an init-output attribute is never set."""
        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                pass  # do not set fields

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputSingle(self):
        """assertValidInitOutput fails when a ``multiple=True`` init-output is a single object."""
        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                self.initOut = butlerTests.MetricsExample(data=[1, 2])

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputMultiple(self):
        """assertValidInitOutput fails when a ``multiple=False`` init-output is a list."""
        class BadPatchTask(PatchTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass PatchTask constructor
                self.initOutA = [butlerTests.MetricsExample(data=[1, 2, 4])]
                self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

        task = BadPatchTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testGetInitInputs(self):
        """getInitInputs returns ``initIn`` only when ``doUseInitIn`` is set."""
        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        self.assertEqual(getInitInputs(self.butler, VisitConfig()), {})

        config = VisitConfig()
        config.doUseInitIn = True
        self.assertEqual(getInitInputs(self.butler, config),
                         {"initIn": butlerTests.MetricsExample(data=data["VisitInitIn"])})

    def testSkypixHandling(self):
        """Connections declared on ``skypix`` resolve against a concrete ``htm7`` data ID."""
        task = SkyPixTask()

        dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
        data = butlerTests.MetricsExample(data=[1, 2, 3])
        self.butler.put(data, "PixA", dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # PixA dataset should have been retrieved by runTestQuantum
        run.assert_called_once_with(a=data)

576 

577 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the memory/resource checks from `lsst.utils.tests.MemoryTestCase`
    for this module; it adds nothing of its own.
    """

580 

581 

def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    The ``module`` argument is accepted for the hook signature but unused.
    """
    lsst.utils.tests.init()

584 

585 

# Allow running this file directly as a script, outside a test runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()