Coverage for tests/test_testUtils.py: 25%
316 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-26 02:35 -0700
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-26 02:35 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Unit tests for `lsst.pipe.base.tests`, a library for testing
23PipelineTask subclasses.
24"""
26import shutil
27import tempfile
28import unittest
30import lsst.daf.butler
31import lsst.daf.butler.tests as butlerTests
32import lsst.pex.config
33import lsst.utils.tests
34from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct, connectionTypes
35from lsst.pipe.base.testUtils import (
36 assertValidInitOutput,
37 assertValidOutput,
38 getInitInputs,
39 lintConnections,
40 makeQuantum,
41 runTestQuantum,
42)
class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
    """Connections for `VisitTask`.

    Declares two single-valued inputs and two single-valued outputs keyed by
    ``{"instrument", "visit"}``, one init-input, and one init-output flagged
    ``multiple=True``. The init-input is removed at construction time unless
    ``config.doUseInitIn`` is true.
    """

    initIn = connectionTypes.InitInput(
        storageClass="StructuredData",
        name="VisitInitIn",
        multiple=False,
    )
    a = connectionTypes.Input(
        storageClass="StructuredData",
        name="VisitA",
        dimensions={"instrument", "visit"},
        multiple=False,
    )
    b = connectionTypes.Input(
        storageClass="StructuredData",
        name="VisitB",
        dimensions={"instrument", "visit"},
        multiple=False,
    )
    initOut = connectionTypes.InitOutput(
        storageClass="StructuredData",
        name="VisitInitOut",
        multiple=True,
    )
    outA = connectionTypes.Output(
        storageClass="StructuredData",
        name="VisitOutA",
        dimensions={"instrument", "visit"},
        multiple=False,
    )
    outB = connectionTypes.Output(
        storageClass="StructuredData",
        name="VisitOutB",
        dimensions={"instrument", "visit"},
        multiple=False,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Keep the init-input only when the config explicitly asks for it.
        if config.doUseInitIn:
            return
        self.initInputs.remove("initIn")
class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
    """Connections for `PatchTask`.

    The quantum is keyed by ``{"skymap", "tract"}``; ``a`` and ``out`` are
    ``multiple=True`` connections that additionally carry the ``patch``
    dimension. The prerequisite input ``b`` is removed at construction time
    when ``config.doUseB`` is false.
    """

    a = connectionTypes.Input(
        storageClass="StructuredData",
        name="PatchA",
        dimensions={"skymap", "tract", "patch"},
        multiple=True,
    )
    b = connectionTypes.PrerequisiteInput(
        storageClass="StructuredData",
        name="PatchB",
        dimensions={"skymap", "tract"},
        multiple=False,
    )
    initOutA = connectionTypes.InitOutput(
        storageClass="StructuredData",
        name="PatchInitOutA",
        multiple=False,
    )
    initOutB = connectionTypes.InitOutput(
        storageClass="StructuredData",
        name="PatchInitOutB",
        multiple=False,
    )
    out = connectionTypes.Output(
        storageClass="StructuredData",
        name="PatchOut",
        dimensions={"skymap", "tract", "patch"},
        multiple=True,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # The prerequisite input is only kept when the config enables it.
        if config.doUseB:
            return
        self.prerequisiteInputs.remove("b")
class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
    """Connections for `SkyPixTask`: one input and one output, both declared
    with the ``skypix`` dimension.
    """

    a = connectionTypes.Input(
        storageClass="StructuredData",
        name="PixA",
        dimensions={"skypix"},
    )
    out = connectionTypes.Output(
        storageClass="StructuredData",
        name="PixOut",
        dimensions={"skypix"},
    )
class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
    """Configuration for `VisitTask`."""

    # When true, VisitConnections keeps its ``initIn`` init-input.
    doUseInitIn = lsst.pex.config.Field(dtype=bool, default=False, doc="")
class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
    """Configuration for `PatchTask`."""

    # When false, PatchConnections drops its ``b`` prerequisite input and
    # PatchTask.run returns its ``a`` input unchanged.
    doUseB = lsst.pex.config.Field(dtype=bool, default=True, doc="")
class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
    """Configuration for `SkyPixTask`; adds no fields of its own."""
class VisitTask(PipelineTask):
    """Toy task that combines two per-visit inputs into two outputs."""

    ConfigClass = VisitConfig
    _DefaultName = "visit"

    def __init__(self, initInputs=None, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        # ``initOut`` is declared ``multiple=True``, so it holds a list.
        self.initOut = [butlerTests.MetricsExample(data=values) for values in ([1, 2], [3, 4])]

    def run(self, a, b):
        """Combine the ``data`` payloads of ``a`` and ``b``.

        Parameters
        ----------
        a, b : `lsst.daf.butler.tests.MetricsExample`
            Input datasets; only their ``data`` attribute is read.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``outA`` wraps ``a.data + b.data``; ``outB`` wraps
            ``a.data * max(b.data)``.
        """
        summed = a.data + b.data
        scaled = a.data * max(b.data)
        return Struct(
            outA=butlerTests.MetricsExample(data=summed),
            outB=butlerTests.MetricsExample(data=scaled),
        )
class PatchTask(PipelineTask):
    """Toy task with a multiple-valued input and an optional second input."""

    ConfigClass = PatchConfig
    _DefaultName = "patch"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Both init-outputs are declared single-valued.
        self.initOutA = butlerTests.MetricsExample(data=[1, 2, 4])
        self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

    def run(self, a, b=None):
        """Append ``b.data`` to the data of each element of ``a``.

        Parameters
        ----------
        a : iterable of `lsst.daf.butler.tests.MetricsExample`
            Input datasets.
        b : `lsst.daf.butler.tests.MetricsExample`, optional
            Second input; only read when ``config.doUseB`` is true.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``out`` is a list with one new dataset per element of ``a`` when
            ``config.doUseB`` is true, and ``a`` itself otherwise.
        """
        if not self.config.doUseB:
            return Struct(out=a)
        combined = [butlerTests.MetricsExample(data=(item.data + b.data)) for item in a]
        return Struct(out=combined)
class SkyPixTask(PipelineTask):
    """Toy task that passes its single input straight through."""

    ConfigClass = SkyPixConfig
    _DefaultName = "skypix"

    def run(self, a):
        """Return ``a`` unchanged as the ``out`` field of a `Struct`."""
        result = Struct(out=a)
        return result
class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
    """Tests for the helpers imported from `lsst.pipe.base.testUtils`.

    A single test repository is created for the whole class; each test gets
    its own collection via ``butlerTests.makeTestCollection`` in `setUp`.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Repository should be re-created for each test case, but
        # this has a prohibitive run-time cost at present
        cls.root = tempfile.mkdtemp()

        try:
            cls.repo = butlerTests.makeTestRepo(cls.root)
            butlerTests.addDataIdValue(cls.repo, "instrument", "notACam")
            butlerTests.addDataIdValue(cls.repo, "visit", 101)
            butlerTests.addDataIdValue(cls.repo, "visit", 102)
            butlerTests.addDataIdValue(cls.repo, "skymap", "sky")
            butlerTests.addDataIdValue(cls.repo, "tract", 42)
            butlerTests.addDataIdValue(cls.repo, "patch", 0)
            butlerTests.addDataIdValue(cls.repo, "patch", 1)
            butlerTests.registerMetricsExample(cls.repo)

            for typeName in {"VisitA", "VisitB", "VisitOutA", "VisitOutB"}:
                butlerTests.addDatasetType(cls.repo, typeName, {"instrument", "visit"}, "StructuredData")
            for typeName in {"PatchA", "PatchOut"}:
                butlerTests.addDatasetType(
                    cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData"
                )
            butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
            for typeName in {"PixA", "PixOut"}:
                butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")
            butlerTests.addDatasetType(cls.repo, "VisitInitIn", set(), "StructuredData")
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the temporary directory must be removed here to avoid a leak.
            shutil.rmtree(cls.root, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.root, ignore_errors=True)
        super().tearDownClass()

    def setUp(self):
        super().setUp()
        self.butler = butlerTests.makeTestCollection(self.repo)

    def _makeVisitTestData(self, dataId):
        """Create dummy datasets suitable for VisitTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create.

        Returns
        -------
        datasets : `dict` [`str`, `list`]
            A dictionary keyed by dataset type. Its values are the list of
            integers used to create each dataset. The datasets stored in the
            butler are `lsst.daf.butler.tests.MetricsExample` objects with
            these lists as their ``data`` argument, but the lists are easier
            to manipulate in test code.
        """
        inInit = [4, 2]
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inInit), "VisitInitIn", set())
        return {
            "VisitA": inA,
            "VisitB": inB,
            "VisitInitIn": inInit,
        }

    def _makePatchTestData(self, dataId):
        """Create dummy datasets suitable for PatchTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create. Any patch ID is
            overridden to create multiple datasets.

        Returns
        -------
        datasets : `dict` [`str`, `list` [`tuple` [data ID, `list`]]]
            A dictionary keyed by dataset type. Its values are the data ID
            of each dataset and the list of integers used to create each. The
            datasets stored in the butler are
            `lsst.daf.butler.tests.MetricsExample` objects with these lists as
            their ``data`` argument, but the lists are easier to manipulate
            in test code.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        datasets = {"PatchA": [], "PatchB": []}
        for patch in {0, 1}:
            self.butler.put(butlerTests.MetricsExample(data=(inA + [patch])), "PatchA", dataId, patch=patch)
            datasets["PatchA"].append((dict(dataId, patch=patch), inA + [patch]))
        self.butler.put(butlerTests.MetricsExample(data=inB), "PatchB", dataId)
        datasets["PatchB"].append((dataId, inB))
        return datasets

    def testMakeQuantumNoSuchDatatype(self):
        """makeQuantum raises ValueError when connection ``a`` is renamed to
        "Visit", a dataset type that setUpClass never registers.
        """
        config = VisitConfig()
        config.connections.a = "Visit"
        task = VisitTask(config=config)

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

    def testMakeQuantumInvalidDimension(self):
        """makeQuantum raises ValueError when a data ID's dimensions do not
        match what the quantum or a connection expects.
        """
        config = VisitConfig()
        config.connections.a = "PatchA"
        task = VisitTask(config=config)
        dataIdV = {"instrument": "notACam", "visit": 102}
        dataIdVExtra = {"instrument": "notACam", "visit": 102, "detector": 42}
        dataIdP = {"skymap": "sky", "tract": 42, "patch": 0}

        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataIdV)
        self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

        # dataIdV is correct everywhere, dataIdP should error
        with self.assertRaises(ValueError):
            makeQuantum(
                task,
                self.butler,
                dataIdV,
                {
                    "a": dataIdP,
                    "b": dataIdV,
                    "outA": dataIdV,
                    "outB": dataIdV,
                },
            )
        with self.assertRaises(ValueError):
            makeQuantum(
                task,
                self.butler,
                dataIdP,
                {
                    "a": dataIdV,
                    "b": dataIdV,
                    "outA": dataIdV,
                    "outB": dataIdV,
                },
            )
        # should not accept small changes, either
        with self.assertRaises(ValueError):
            makeQuantum(
                task,
                self.butler,
                dataIdV,
                {
                    "a": dataIdV,
                    "b": dataIdV,
                    "outA": dataIdVExtra,
                    "outB": dataIdV,
                },
            )
        with self.assertRaises(ValueError):
            makeQuantum(
                task,
                self.butler,
                dataIdVExtra,
                {
                    "a": dataIdV,
                    "b": dataIdV,
                    "outA": dataIdV,
                    "outB": dataIdV,
                },
            )

    def testMakeQuantumMissingMultiple(self):
        """makeQuantum raises ValueError when the ``multiple=True`` connection
        ``a`` is given a single data ID instead of a list.
        """
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(
                task,
                self.butler,
                dataId,
                {
                    "a": dict(dataId, patch=0),
                    "b": dataId,
                    "out": [dict(dataId, patch=patch) for patch in {0, 1}],
                },
            )

    def testMakeQuantumExtraMultiple(self):
        """makeQuantum raises ValueError when the ``multiple=False``
        connection ``b`` is given a list of data IDs.
        """
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(
                task,
                self.butler,
                dataId,
                {
                    "a": [dict(dataId, patch=patch) for patch in {0, 1}],
                    "b": [dataId],
                    "out": [dict(dataId, patch=patch) for patch in {0, 1}],
                },
            )

    def testMakeQuantumMissingDataId(self):
        """makeQuantum raises ValueError when the data ID mapping omits an
        input (``b``) or an output (``outA``) connection.
        """
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

    def testMakeQuantumCorruptedDataId(self):
        """makeQuantum raises ValueError when handed a bare data ID where the
        per-connection mapping is expected.
        """
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            # fourth argument should be a mapping keyed by component name
            makeQuantum(task, self.butler, dataId, dataId)

    def testRunTestQuantumVisitWithRun(self):
        """runTestQuantum with ``mockRun=False`` runs VisitTask for real and
        stores both outputs in the butler.
        """
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        self.assertTrue(self.butler.datasetExists("VisitOutA", dataId))
        self.assertEqual(
            self.butler.get("VisitOutA", dataId),
            butlerTests.MetricsExample(data=(data["VisitA"] + data["VisitB"])),
        )
        self.assertTrue(self.butler.datasetExists("VisitOutB", dataId))
        self.assertEqual(
            self.butler.get("VisitOutB", dataId),
            butlerTests.MetricsExample(data=(data["VisitA"] * max(data["VisitB"]))),
        )

    def testRunTestQuantumPatchWithRun(self):
        """runTestQuantum with ``mockRun=False`` runs PatchTask for real and
        stores one output per input patch.
        """
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(
            task,
            self.butler,
            dataId,
            {
                "a": [dataset[0] for dataset in data["PatchA"]],
                "b": dataId,
                "out": [dataset[0] for dataset in data["PatchA"]],
            },
        )
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        inB = data["PatchB"][0][1]
        for dataset in data["PatchA"]:
            patchId = dataset[0]
            self.assertTrue(self.butler.datasetExists("PatchOut", patchId))
            self.assertEqual(
                self.butler.get("PatchOut", patchId), butlerTests.MetricsExample(data=(dataset[1] + inB))
            )

    def testRunTestQuantumVisitMockRun(self):
        """runTestQuantum with ``mockRun=True`` returns a mock that records
        the inputs passed to VisitTask.run.
        """
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=butlerTests.MetricsExample(data=data["VisitA"]),
            b=butlerTests.MetricsExample(data=data["VisitB"]),
        )

    def testRunTestQuantumPatchMockRun(self):
        """runTestQuantum with ``mockRun=True`` returns a mock that records
        the inputs passed to PatchTask.run.
        """
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(
            task,
            self.butler,
            dataId,
            {
                # Use lists, not sets, to ensure order agrees with test
                # assertion.
                "a": [dataset[0] for dataset in data["PatchA"]],
                "b": dataId,
                "out": [dataset[0] for dataset in data["PatchA"]],
            },
        )
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]],
            b=butlerTests.MetricsExample(data=data["PatchB"][0][1]),
        )

    def testRunTestQuantumPatchOptionalInput(self):
        """With ``doUseB=False`` the ``b`` connection is absent, so run is
        called with ``a`` only.
        """
        config = PatchConfig()
        config.doUseB = False
        task = PatchTask(config=config)

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(
            task,
            self.butler,
            dataId,
            {
                # Use lists, not sets, to ensure order agrees with test
                # assertion.
                "a": [dataset[0] for dataset in data["PatchA"]],
                "out": [dataset[0] for dataset in data["PatchA"]],
            },
        )
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]]
        )

    def testAssertValidOutputPass(self):
        """assertValidOutput accepts the result of the real VisitTask.run."""
        task = VisitTask()

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        # should not throw
        assertValidOutput(task, result)

    def testAssertValidOutputMissing(self):
        """assertValidOutput fails when the result lacks ``outB``."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=a)

        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputSingle(self):
        """assertValidOutput fails when the ``multiple=True`` output ``out``
        holds a single object.
        """
        task = PatchTask()

        def run(a, b):
            return Struct(out=b)

        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run([inA], inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputMultiple(self):
        """assertValidOutput fails when the ``multiple=False`` output
        ``outA`` holds a list.
        """
        task = VisitTask()

        def run(a, b):
            return Struct(outA=[a], outB=b)

        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidInitOutputPass(self):
        """assertValidInitOutput accepts correctly constructed tasks."""
        task = VisitTask()
        # should not throw
        assertValidInitOutput(task)

        task = PatchTask()
        # should not throw
        assertValidInitOutput(task)

    def testAssertValidInitOutputMissing(self):
        """assertValidInitOutput fails when ``initOut`` is never set."""

        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                pass  # do not set fields

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputSingle(self):
        """assertValidInitOutput fails when the ``multiple=True`` init-output
        ``initOut`` holds a single object.
        """

        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                self.initOut = butlerTests.MetricsExample(data=[1, 2])

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputMultiple(self):
        """assertValidInitOutput fails when the ``multiple=False`` init-output
        ``initOutA`` holds a list.
        """

        class BadPatchTask(PatchTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass PatchTask constructor
                self.initOutA = [butlerTests.MetricsExample(data=[1, 2, 4])]
                self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

        task = BadPatchTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testGetInitInputs(self):
        """getInitInputs returns an empty dict by default, and the stored
        ``VisitInitIn`` dataset when ``doUseInitIn`` is enabled.
        """
        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        self.assertEqual(getInitInputs(self.butler, VisitConfig()), {})

        config = VisitConfig()
        config.doUseInitIn = True
        self.assertEqual(
            getInitInputs(self.butler, config),
            {"initIn": butlerTests.MetricsExample(data=data["VisitInitIn"])},
        )

    def testSkypixHandling(self):
        """makeQuantum and runTestQuantum accept an ``htm7`` data ID for
        connections declared with the ``skypix`` dimension.
        """
        task = SkyPixTask()

        dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
        data = butlerTests.MetricsExample(data=[1, 2, 3])
        self.butler.put(data, "PixA", dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # PixA dataset should have been retrieved by runTestQuantum
        run.assert_called_once_with(a=data)

    def testLintConnectionsOk(self):
        """lintConnections accepts all three connections classes above."""
        lintConnections(VisitConnections)
        lintConnections(PatchConnections)
        lintConnections(SkyPixConnections)

    def testLintConnectionsMissingMultiple(self):
        """lintConnections fails for an input with an extra dimension
        (``band``) and no ``multiple=True``, unless that check is disabled.
        """

        class BadConnections(PipelineTaskConnections, dimensions={"tract", "patch", "skymap"}):
            coadds = connectionTypes.Input(
                name="coadd_calexp",
                storageClass="ExposureF",
                # Some authors use list rather than set; check that linter
                # can handle it.
                dimensions=["tract", "patch", "band", "skymap"],
            )

        with self.assertRaises(AssertionError):
            lintConnections(BadConnections)
        # With the relevant check switched off the same class must pass.
        lintConnections(BadConnections, checkMissingMultiple=False)

    def testLintConnectionsExtraMultiple(self):
        """lintConnections fails for a ``multiple=True`` input whose
        dimensions equal the quantum's, unless that check is disabled.
        """

        class BadConnections(
            PipelineTaskConnections,
            # Some authors use list rather than set.
            dimensions=["tract", "patch", "band", "skymap"],
        ):
            coadds = connectionTypes.Input(
                name="coadd_calexp",
                storageClass="ExposureF",
                multiple=True,
                dimensions={"tract", "patch", "band", "skymap"},
            )

        with self.assertRaises(AssertionError):
            lintConnections(BadConnections)
        # With the relevant check switched off the same class must pass.
        lintConnections(BadConnections, checkUnnecessaryMultiple=False)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Pull the checks inherited from `lsst.utils.tests.MemoryTestCase`
    into this module; no additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: call ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module : module
        The test module being set up; not used.
    """
    lsst.utils.tests.init()
# Script entry point: perform the same initialisation as setup_module and
# then hand control to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()