Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Unit tests for `lsst.pipe.base.testUtils`, a library for testing
PipelineTask subclasses.
"""

25 

26import shutil 

27import tempfile 

28import unittest 

29 

30import lsst.utils.tests 

31import lsst.pex.config 

32import lsst.daf.butler 

33import lsst.daf.butler.tests as butlerTests 

34 

35from lsst.pipe.base import Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, connectionTypes 

36from lsst.pipe.base.testUtils import runTestQuantum, makeQuantum, assertValidOutput 

37 

38 

class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
    """Connections for `VisitTask`: two single visit-level inputs and two
    single visit-level outputs, all stored as ``StructuredData``.
    """

    # --- inputs: one dataset each, keyed by (instrument, visit) ---
    a = connectionTypes.Input(
        name="VisitA",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )
    b = connectionTypes.Input(
        name="VisitB",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )

    # --- outputs: one dataset each, same dimensions as the inputs ---
    outA = connectionTypes.Output(
        name="VisitOutA",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )
    outB = connectionTypes.Output(
        name="VisitOutB",
        dimensions={"instrument", "visit"},
        storageClass="StructuredData",
        multiple=False,
    )

64 

65 

class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
    """Connections for `PatchTask`.

    The quantum is per (skymap, tract). Input ``a`` and output ``out`` are
    per-patch and therefore ``multiple``. ``b`` is a single tract-level
    prerequisite that is dropped when ``config.doUseB`` is false.
    """

    a = connectionTypes.Input(
        name="PatchA",
        dimensions={"skymap", "tract", "patch"},
        storageClass="StructuredData",
        multiple=True,
    )
    b = connectionTypes.PrerequisiteInput(
        name="PatchB",
        dimensions={"skymap", "tract"},
        storageClass="StructuredData",
        multiple=False,
    )
    out = connectionTypes.Output(
        name="PatchOut",
        dimensions={"skymap", "tract", "patch"},
        storageClass="StructuredData",
        multiple=True,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Keep the prerequisite only when the task is configured to use it.
        if config.doUseB:
            return
        self.prerequisiteInputs.remove("b")

91 

92 

class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
    """Connections for `SkyPixTask`, declared with the generic ``skypix``
    dimension rather than a concrete pixelization level.
    """

    a = connectionTypes.Input(
        name="PixA",
        dimensions={"skypix"},
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        name="PixOut",
        dimensions={"skypix"},
        storageClass="StructuredData",
    )

104 

105 

class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
    """Configuration for `VisitTask`; defines no fields beyond the connections."""

108 

109 

class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
    """Configuration for `PatchTask`."""

    # Toggles the optional prerequisite ``b``: PatchConnections removes it
    # when False, and PatchTask.run then passes ``a`` through unchanged.
    doUseB = lsst.pex.config.Field(
        dtype=bool,
        default=True,
        doc="Use the prerequisite input ``b``? If False, ``b`` is removed from the "
            "connections and the ``a`` inputs are passed through unchanged.",
    )

112 

113 

class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
    """Configuration for `SkyPixTask`; defines no fields beyond the connections."""

116 

117 

class VisitTask(PipelineTask):
    """Minimal visit-level task that combines two `MetricsExample` inputs."""

    _DefaultName = "visit"
    ConfigClass = VisitConfig

    def run(self, a, b):
        """Combine the ``data`` of two `MetricsExample` objects.

        Parameters
        ----------
        a, b : `lsst.daf.butler.tests.MetricsExample`
            Inputs whose ``data`` attributes are combined.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``outA`` wraps ``a.data + b.data``. ``outB`` wraps
            ``a.data * max(b.data)``; for list data that is ``a.data``
            repeated ``max(b.data)`` times.
        """
        aData = a.data
        bData = b.data
        summed = butlerTests.MetricsExample(data=(aData + bData))
        scaled = butlerTests.MetricsExample(data=(aData * max(bData)))
        return Struct(outA=summed, outB=scaled)

126 

127 

class PatchTask(PipelineTask):
    """Minimal tract-level task with an optional prerequisite input."""

    _DefaultName = "patch"
    ConfigClass = PatchConfig

    def run(self, a, b=None):
        """Combine each patch-level input with the tract-level ``b``.

        Parameters
        ----------
        a : iterable of `lsst.daf.butler.tests.MetricsExample`
            One input per patch.
        b : `lsst.daf.butler.tests.MetricsExample`, optional
            Only read when ``self.config.doUseB`` is true.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``out`` is ``a`` itself when ``doUseB`` is false. Otherwise it is
            a list holding one `MetricsExample` per element of ``a``, with
            ``data`` equal to ``element.data + b.data``.
        """
        if not self.config.doUseB:
            # Nothing to combine with: pass the inputs straight through.
            return Struct(out=a)
        combined = [butlerTests.MetricsExample(data=(patchInput.data + b.data)) for patchInput in a]
        return Struct(out=combined)

138 

139 

class SkyPixTask(PipelineTask):
    """Minimal sky-pixel task that copies its input to its output."""

    _DefaultName = "skypix"
    ConfigClass = SkyPixConfig

    def run(self, a):
        """Return ``a`` unchanged as the ``out`` output."""
        passthrough = a
        return Struct(out=passthrough)

146 

147 

class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
    """Tests for `makeQuantum`, `runTestQuantum` and `assertValidOutput`
    from `lsst.pipe.base.testUtils`.

    All tests in this class share one butler repository, created in
    `setUpClass`. Each test works in a fresh collection created by `setUp`.
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared test repository and register every dataset type
        used by the tasks in this module.
        """
        super().setUpClass()
        # Repository should be re-created for each test case, but
        # this has a prohibitive run-time cost at present
        cls.root = tempfile.mkdtemp()
        try:
            dataIds = {
                "instrument": ["notACam"],
                "physical_filter": ["k2020"],  # needed for expandUniqueId(visit)
                "visit": [101, 102],
                "skymap": ["sky"],
                "tract": [42],
                "patch": [0, 1],
            }
            cls.repo = butlerTests.makeTestRepo(cls.root, dataIds)
            butlerTests.registerMetricsExample(cls.repo)

            for typeName in {"VisitA", "VisitB", "VisitOutA", "VisitOutB"}:
                butlerTests.addDatasetType(cls.repo, typeName, {"instrument", "visit"}, "StructuredData")
            for typeName in {"PatchA", "PatchOut"}:
                butlerTests.addDatasetType(cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData")
            butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
            for typeName in {"PixA", "PixOut"}:
                butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here rather than leaking it.
            shutil.rmtree(cls.root, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Delete the shared repository directory."""
        shutil.rmtree(cls.root, ignore_errors=True)
        super().tearDownClass()

    def setUp(self):
        """Give each test its own collection in the shared repository."""
        super().setUp()
        self.butler = butlerTests.makeTestCollection(self.repo)

    def _makeVisitTestData(self, dataId):
        """Create dummy datasets suitable for VisitTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create.

        Returns
        -------
        datasets : `dict` [`str`, `list`]
            A dictionary keyed by dataset type. Its values are the list of
            integers used to create each dataset. The datasets stored in the
            butler are `lsst.daf.butler.tests.MetricsExample` objects with
            these lists as their ``data`` argument, but the lists are easier
            to manipulate in test code.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataId)
        return {"VisitA": inA, "VisitB": inB}

    def _makePatchTestData(self, dataId):
        """Create dummy datasets suitable for PatchTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create. Any patch ID is
            overridden to create multiple datasets.

        Returns
        -------
        datasets : `dict` [`str`, `list` [`tuple` [data ID, `list`]]]
            A dictionary keyed by dataset type. Its values are the data ID
            of each dataset and the list of integers used to create each. The
            datasets stored in the butler are
            `lsst.daf.butler.tests.MetricsExample` objects with these lists as
            their ``data`` argument, but the lists are easier to manipulate
            in test code.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        datasets = {"PatchA": [], "PatchB": []}
        # Iterate a tuple, not a set: callers rely on the order of
        # datasets["PatchA"] matching the order the task receives its inputs.
        for patch in (0, 1):
            self.butler.put(butlerTests.MetricsExample(data=(inA + [patch])), "PatchA", dataId, patch=patch)
            datasets["PatchA"].append((dict(dataId, patch=patch), inA + [patch]))
        self.butler.put(butlerTests.MetricsExample(data=inB), "PatchB", dataId)
        datasets["PatchB"].append((dataId, inB))
        return datasets

    def testMakeQuantumNoSuchDatatype(self):
        """makeQuantum rejects a connection naming an unregistered dataset type."""
        config = VisitConfig()
        config.connections.a = "Visit"
        task = VisitTask(config=config)

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

    def testMakeQuantumInvalidDimension(self):
        """makeQuantum rejects a patch data ID for a visit-dimensioned connection."""
        config = VisitConfig()
        config.connections.a = "PatchA"
        task = VisitTask(config=config)
        dataIdV = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        dataIdP = butlerTests.expandUniqueId(self.butler, {"patch": 0})

        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataIdV)
        self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdV, {
                "a": dataIdP,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })

    def testMakeQuantumMissingMultiple(self):
        """makeQuantum rejects a single data ID for a ``multiple=True`` input."""
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": dict(dataId, patch=0),
                "b": dataId,
                "out": [dict(dataId, patch=patch) for patch in {0, 1}],
            })

    def testMakeQuantumExtraMultiple(self):
        """makeQuantum rejects a list of data IDs for a ``multiple=False`` input."""
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": [dict(dataId, patch=patch) for patch in {0, 1}],
                "b": [dataId],
                "out": [dict(dataId, patch=patch) for patch in {0, 1}],
            })

    def testMakeQuantumMissingDataId(self):
        """makeQuantum rejects a mapping that omits an input or an output."""
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

    def testMakeQuantumCorruptedDataId(self):
        """makeQuantum rejects a bare data ID in place of the per-connection mapping."""
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            # fourth argument should be a mapping keyed by component name
            makeQuantum(task, self.butler, dataId, dataId)

    def testRunTestQuantumVisitWithRun(self):
        """runTestQuantum with a real run writes VisitTask's outputs."""
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        self.assertTrue(self.butler.datasetExists("VisitOutA", dataId))
        self.assertEqual(self.butler.get("VisitOutA", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] + data["VisitB"])))
        self.assertTrue(self.butler.datasetExists("VisitOutB", dataId))
        self.assertEqual(self.butler.get("VisitOutB", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] * max(data["VisitB"]))))

    def testRunTestQuantumPatchWithRun(self):
        """runTestQuantum with a real run writes one PatchOut per patch."""
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        inB = data["PatchB"][0][1]
        for dataset in data["PatchA"]:
            patchId = dataset[0]
            self.assertTrue(self.butler.datasetExists("PatchOut", patchId))
            self.assertEqual(self.butler.get("PatchOut", patchId),
                             butlerTests.MetricsExample(data=(dataset[1] + inB)))

    def testRunTestQuantumVisitMockRun(self):
        """With mockRun=True, the returned mock records VisitTask.run's inputs."""
        task = VisitTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"visit": 102})
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(a=butlerTests.MetricsExample(data=data["VisitA"]),
                                    b=butlerTests.MetricsExample(data=data["VisitB"]))

    def testRunTestQuantumPatchMockRun(self):
        """With mockRun=True, the returned mock records PatchTask.run's inputs."""
        task = PatchTask()

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]],
            b=butlerTests.MetricsExample(data=data["PatchB"][0][1])
        )

    def testRunTestQuantumPatchOptionalInput(self):
        """With doUseB=False, run is called without ``b``."""
        config = PatchConfig()
        config.doUseB = False
        task = PatchTask(config=config)

        dataId = butlerTests.expandUniqueId(self.butler, {"tract": 42})
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]]
        )

    def testAssertValidOutputPass(self):
        """assertValidOutput accepts VisitTask's genuine output."""
        task = VisitTask()

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        # should not throw
        assertValidOutput(task, result)

    def testAssertValidOutputMissing(self):
        """assertValidOutput fails when a declared output is absent."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=a)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputSingle(self):
        """assertValidOutput fails when a ``multiple=True`` output is a single object."""
        task = PatchTask()

        def run(a, b):
            return Struct(out=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run([inA], inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputMultiple(self):
        """assertValidOutput fails when a ``multiple=False`` output is a list."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=[a], outB=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testSkypixHandling(self):
        """A ``skypix`` connection works with data stored under ``htm7``."""
        task = SkyPixTask()

        dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
        data = butlerTests.MetricsExample(data=[1, 2, 3])
        self.butler.put(data, "PixA", dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # PixA dataset should have been retrieved by runTestQuantum
        run.assert_called_once_with(a=data)

481 

482 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Hook this module into `lsst.utils.tests.MemoryTestCase`'s checks.

    NOTE(review): the checks themselves live entirely in the base class;
    this subclass only makes them collectable from this module.
    """

485 

486 

def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

489 

490 

if __name__ == "__main__":
    # Support running this file directly, in addition to via pytest.
    lsst.utils.tests.init()
    unittest.main()