Coverage for tests/test_testUtils.py : 23%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Unit tests for `lsst.pipe.base.tests`, a library for testing
23PipelineTask subclasses.
24"""
26import shutil
27import tempfile
28import unittest
30import lsst.utils.tests
31import lsst.pex.config
32import lsst.daf.butler
33import lsst.daf.butler.tests as butlerTests
35from lsst.pipe.base import Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, connectionTypes
36from lsst.pipe.base.testUtils import runTestQuantum, makeQuantum, assertValidOutput, \
37 assertValidInitOutput, getInitInputs, lintConnections
class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
    """Connections for the per-visit toy task `VisitTask`.

    The ``initIn`` init-input is removed from the connections unless
    ``config.doUseInitIn`` is true.
    """

    initIn = connectionTypes.InitInput(
        name="VisitInitIn",
        multiple=False,
        storageClass="StructuredData",
    )
    a = connectionTypes.Input(
        name="VisitA",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )
    b = connectionTypes.Input(
        name="VisitB",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )
    initOut = connectionTypes.InitOutput(
        name="VisitInitOut",
        multiple=True,
        storageClass="StructuredData",
    )
    outA = connectionTypes.Output(
        name="VisitOutA",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )
    outB = connectionTypes.Output(
        name="VisitOutB",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        keepInitIn = config.doUseInitIn
        if not keepInitIn:
            # The init-input is optional; drop it when the config disables it.
            self.initInputs.remove("initIn")
class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
    """Connections for the per-tract toy task `PatchTask`.

    Inputs ``a`` and output ``out`` are per-patch with ``multiple=True``. The
    ``b`` prerequisite input is removed unless ``config.doUseB`` is true.
    """

    a = connectionTypes.Input(
        name="PatchA",
        dimensions={"skymap", "tract", "patch"},
        multiple=True,
        storageClass="StructuredData",
    )
    b = connectionTypes.PrerequisiteInput(
        name="PatchB",
        dimensions={"skymap", "tract"},
        multiple=False,
        storageClass="StructuredData",
    )
    initOutA = connectionTypes.InitOutput(
        name="PatchInitOutA",
        multiple=False,
        storageClass="StructuredData",
    )
    initOutB = connectionTypes.InitOutput(
        name="PatchInitOutB",
        multiple=False,
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        name="PatchOut",
        dimensions={"skymap", "tract", "patch"},
        multiple=True,
        storageClass="StructuredData",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        useB = config.doUseB
        if not useB:
            # ``b`` is optional; drop the prerequisite when it is disabled.
            self.prerequisiteInputs.remove("b")
class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
    """Connections declared with the generic ``skypix`` dimension.

    Used to check that test utilities map ``skypix`` onto the concrete
    pixelization registered in the test repository.
    """

    a = connectionTypes.Input(
        name="PixA",
        dimensions={"skypix"},
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        name="PixOut",
        dimensions={"skypix"},
        storageClass="StructuredData",
    )
class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
    """Configuration for `VisitTask`."""

    doUseInitIn = lsst.pex.config.Field(
        dtype=bool,
        default=False,
        # Previously an empty doc string; config fields should be documented.
        doc="Whether to keep the ``initIn`` init-input connection; if False it is removed.",
    )
class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
    """Configuration for `PatchTask`."""

    doUseB = lsst.pex.config.Field(
        dtype=bool,
        default=True,
        # Previously an empty doc string; config fields should be documented.
        doc="Whether to use the ``b`` prerequisite input; if False it is removed "
            "and the ``a`` inputs are passed through unchanged.",
    )
class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
    """Configuration for `SkyPixTask`; it defines no fields of its own."""
class VisitTask(PipelineTask):
    """Toy per-visit task combining two inputs into two outputs."""

    ConfigClass = VisitConfig
    _DefaultName = "visit"

    def __init__(self, initInputs=None, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        # Two init-outputs, matching the ``multiple=True`` initOut connection.
        self.initOut = [butlerTests.MetricsExample(data=values) for values in ([1, 2], [3, 4])]

    def run(self, a, b):
        """Build ``outA`` from ``a.data + b.data`` and ``outB`` from
        ``a.data * max(b.data)``.
        """
        combined = a.data + b.data
        scaled = a.data * max(b.data)
        return Struct(
            outA=butlerTests.MetricsExample(data=combined),
            outB=butlerTests.MetricsExample(data=scaled),
        )
class PatchTask(PipelineTask):
    """Toy per-tract task that optionally folds ``b`` into each ``a`` input."""

    ConfigClass = PatchConfig
    _DefaultName = "patch"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initOutA = butlerTests.MetricsExample(data=[1, 2, 4])
        self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

    def run(self, a, b=None):
        """Return ``out``: each ``a`` item's data plus ``b.data``, or ``a``
        unchanged when ``config.doUseB`` is false.
        """
        if not self.config.doUseB:
            return Struct(out=a)
        return Struct(out=[butlerTests.MetricsExample(data=(item.data + b.data)) for item in a])
class SkyPixTask(PipelineTask):
    """Toy task whose single output is its single input."""

    ConfigClass = SkyPixConfig
    _DefaultName = "skypix"

    def run(self, a):
        """Pass ``a`` through unchanged as ``out``."""
        result = Struct(out=a)
        return result
class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
    """Tests of the `lsst.pipe.base.testUtils` helpers using toy tasks.

    A single repository is created in `setUpClass` and shared by all tests;
    each test gets its own collection through `setUp`.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Repository should be re-created for each test case, but
        # this has a prohibitive run-time cost at present
        cls.root = tempfile.mkdtemp()
        try:
            cls.repo = butlerTests.makeTestRepo(cls.root)
            butlerTests.addDataIdValue(cls.repo, "instrument", "notACam")
            butlerTests.addDataIdValue(cls.repo, "visit", 101)
            butlerTests.addDataIdValue(cls.repo, "visit", 102)
            butlerTests.addDataIdValue(cls.repo, "skymap", "sky")
            butlerTests.addDataIdValue(cls.repo, "tract", 42)
            butlerTests.addDataIdValue(cls.repo, "patch", 0)
            butlerTests.addDataIdValue(cls.repo, "patch", 1)
            butlerTests.registerMetricsExample(cls.repo)

            for typeName in ("VisitA", "VisitB", "VisitOutA", "VisitOutB"):
                butlerTests.addDatasetType(cls.repo, typeName, {"instrument", "visit"}, "StructuredData")
            for typeName in ("PatchA", "PatchOut"):
                butlerTests.addDatasetType(cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData")
            butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
            # The connections declare "skypix"; the repository uses htm7.
            for typeName in ("PixA", "PixOut"):
                butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")
            # The init-input dataset type is registered with no dimensions.
            butlerTests.addDatasetType(cls.repo, "VisitInitIn", set(), "StructuredData")
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the temporary directory here rather than leaking it.
            shutil.rmtree(cls.root, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.root, ignore_errors=True)
        super().tearDownClass()

    def setUp(self):
        super().setUp()
        self.butler = butlerTests.makeTestCollection(self.repo)

    def _makeVisitTestData(self, dataId):
        """Create dummy datasets suitable for VisitTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create.

        Returns
        -------
        datasets : `dict` [`str`, `list`]
            A dictionary keyed by dataset type. Its values are the list of
            integers used to create each dataset. The datasets stored in the
            butler are `lsst.daf.butler.tests.MetricsExample` objects with
            these lists as their ``data`` argument, but the lists are easier
            to manipulate in test code.
        """
        inInit = [4, 2]
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataId)
        # VisitInitIn has no dimensions, so its data ID is an empty mapping
        # (previously ``set()``, which is not a mapping).
        self.butler.put(butlerTests.MetricsExample(data=inInit), "VisitInitIn", {})
        return {"VisitA": inA, "VisitB": inB, "VisitInitIn": inInit}

    def _makePatchTestData(self, dataId):
        """Create dummy datasets suitable for PatchTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create. Any patch ID is
            overridden to create multiple datasets.

        Returns
        -------
        datasets : `dict` [`str`, `list` [`tuple` [data ID, `list`]]]
            A dictionary keyed by dataset type. Its values are the data ID
            of each dataset and the list of integers used to create each. The
            datasets stored in the butler are
            `lsst.daf.butler.tests.MetricsExample` objects with these lists as
            their ``data`` argument, but the lists are easier to manipulate
            in test code. ``PatchA`` entries are in ascending patch order.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        datasets = {"PatchA": [], "PatchB": []}
        # Iterate a tuple, not a set: tests compare lists built from this data
        # against run() arguments, so the order must be well defined.
        for patch in (0, 1):
            self.butler.put(butlerTests.MetricsExample(data=(inA + [patch])), "PatchA", dataId, patch=patch)
            datasets["PatchA"].append((dict(dataId, patch=patch), inA + [patch]))
        self.butler.put(butlerTests.MetricsExample(data=inB), "PatchB", dataId)
        datasets["PatchB"].append((dataId, inB))
        return datasets

    def testMakeQuantumNoSuchDatatype(self):
        config = VisitConfig()
        config.connections.a = "Visit"
        task = VisitTask(config=config)

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

    def testMakeQuantumInvalidDimension(self):
        config = VisitConfig()
        config.connections.a = "PatchA"
        task = VisitTask(config=config)
        dataIdV = {"instrument": "notACam", "visit": 102}
        dataIdVExtra = {"instrument": "notACam", "visit": 102, "detector": 42}
        dataIdP = {"skymap": "sky", "tract": 42, "patch": 0}

        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataIdV)
        self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

        # dataIdV is correct everywhere, dataIdP should error
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdV, {
                "a": dataIdP,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdP, {
                "a": dataIdV,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })
        # should not accept small changes, either
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdV, {
                "a": dataIdV,
                "b": dataIdV,
                "outA": dataIdVExtra,
                "outB": dataIdV,
            })
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdVExtra, {
                "a": dataIdV,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })

    def testMakeQuantumMissingMultiple(self):
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": dict(dataId, patch=0),
                "b": dataId,
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumExtraMultiple(self):
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": [dict(dataId, patch=patch) for patch in (0, 1)],
                "b": [dataId],
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumMissingDataId(self):
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

    def testMakeQuantumCorruptedDataId(self):
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            # fourth argument should be a mapping keyed by component name
            makeQuantum(task, self.butler, dataId, dataId)

    def testRunTestQuantumVisitWithRun(self):
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        self.assertTrue(self.butler.datasetExists("VisitOutA", dataId))
        self.assertEqual(self.butler.get("VisitOutA", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] + data["VisitB"])))
        self.assertTrue(self.butler.datasetExists("VisitOutB", dataId))
        self.assertEqual(self.butler.get("VisitOutB", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] * max(data["VisitB"]))))

    def testRunTestQuantumPatchWithRun(self):
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # Can we use runTestQuantum to verify that task.run got called with
        # correct inputs/outputs?
        inB = data["PatchB"][0][1]
        for dataset in data["PatchA"]:
            patchId = dataset[0]
            self.assertTrue(self.butler.datasetExists("PatchOut", patchId))
            self.assertEqual(self.butler.get("PatchOut", patchId),
                             butlerTests.MetricsExample(data=(dataset[1] + inB)))

    def testRunTestQuantumVisitMockRun(self):
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(a=butlerTests.MetricsExample(data=data["VisitA"]),
                                    b=butlerTests.MetricsExample(data=data["VisitB"]))

    def testRunTestQuantumPatchMockRun(self):
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]],
            b=butlerTests.MetricsExample(data=data["PatchB"][0][1])
        )

    def testRunTestQuantumPatchOptionalInput(self):
        config = PatchConfig()
        config.doUseB = False
        task = PatchTask(config=config)

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # Can we use the mock to verify that task.run got called with the
        # correct inputs?
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]]
        )

    def testAssertValidOutputPass(self):
        task = VisitTask()

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        # should not throw
        assertValidOutput(task, result)

    def testAssertValidOutputMissing(self):
        task = VisitTask()

        def run(a, b):
            return Struct(outA=a)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputSingle(self):
        task = PatchTask()

        def run(a, b):
            return Struct(out=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run([inA], inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputMultiple(self):
        task = VisitTask()

        def run(a, b):
            return Struct(outA=[a], outB=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidInitOutputPass(self):
        task = VisitTask()
        # should not throw
        assertValidInitOutput(task)

        task = PatchTask()
        # should not throw
        assertValidInitOutput(task)

    def testAssertValidInitOutputMissing(self):
        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                # do not set fields

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputSingle(self):
        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                self.initOut = butlerTests.MetricsExample(data=[1, 2])

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputMultiple(self):
        class BadPatchTask(PatchTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass PatchTask constructor
                self.initOutA = [butlerTests.MetricsExample(data=[1, 2, 4])]
                self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

        task = BadPatchTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testGetInitInputs(self):
        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        self.assertEqual(getInitInputs(self.butler, VisitConfig()), {})

        config = VisitConfig()
        config.doUseInitIn = True
        self.assertEqual(getInitInputs(self.butler, config),
                         {"initIn": butlerTests.MetricsExample(data=data["VisitInitIn"])})

    def testSkypixHandling(self):
        task = SkyPixTask()

        dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
        data = butlerTests.MetricsExample(data=[1, 2, 3])
        self.butler.put(data, "PixA", dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # PixA dataset should have been retrieved by runTestQuantum
        run.assert_called_once_with(a=data)

    def testLintConnectionsOk(self):
        lintConnections(VisitConnections)
        lintConnections(PatchConnections)
        lintConnections(SkyPixConnections)

    def testLintConnectionsMissingMultiple(self):
        class BadConnections(PipelineTaskConnections,
                             dimensions={"tract", "patch", "skymap"}):
            coadds = connectionTypes.Input(
                name="coadd_calexp",
                storageClass="ExposureF",
                # Some authors use list rather than set; check that linter
                # can handle it.
                dimensions=["tract", "patch", "band", "skymap"],
            )

        with self.assertRaises(AssertionError):
            lintConnections(BadConnections)
        lintConnections(BadConnections, checkMissingMultiple=False)

    def testLintConnectionsExtraMultiple(self):
        class BadConnections(PipelineTaskConnections,
                             # Some authors use list rather than set.
                             dimensions=["tract", "patch", "band", "skymap"]):
            coadds = connectionTypes.Input(
                name="coadd_calexp",
                storageClass="ExposureF",
                multiple=True,
                dimensions={"tract", "patch", "band", "skymap"},
            )

        with self.assertRaises(AssertionError):
            lintConnections(BadConnections)
        lintConnections(BadConnections, checkUnnecessaryMultiple=False)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit every check from `lsst.utils.tests.MemoryTestCase`; no tests are added here."""
def setup_module(module):
    """Initialize the LSST test utilities once per module.

    Presumably called by the test runner's module-level setup hook.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: initialize LSST test utilities, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()