Coverage for tests/test_testUtils.py : 24%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Unit tests for `lsst.pipe.base.testUtils`, a library for testing
PipelineTask subclasses.
"""
26import shutil
27import tempfile
28import unittest
30import lsst.utils.tests
31import lsst.pex.config
32import lsst.daf.butler
33import lsst.daf.butler.tests as butlerTests
35from lsst.pipe.base import Struct, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, connectionTypes
36from lsst.pipe.base.testUtils import runTestQuantum, makeQuantum, assertValidOutput, \
37 assertValidInitOutput, getInitInputs
class VisitConnections(PipelineTaskConnections, dimensions={"instrument", "visit"}):
    """Connections for `VisitTask`, whose quanta are (instrument, visit).

    The ``initIn`` init-input is only kept when ``config.doUseInitIn`` is
    true; otherwise it is removed from ``initInputs`` at construction.
    """

    initIn = connectionTypes.InitInput(
        name="VisitInitIn",
        multiple=False,
        storageClass="StructuredData",
    )
    a = connectionTypes.Input(
        name="VisitA",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )
    b = connectionTypes.Input(
        name="VisitB",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )
    # Declared multiple=True, so the task must supply a list here.
    initOut = connectionTypes.InitOutput(
        name="VisitInitOut",
        multiple=True,
        storageClass="StructuredData",
    )
    outA = connectionTypes.Output(
        name="VisitOutA",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )
    outB = connectionTypes.Output(
        name="VisitOutB",
        dimensions={"instrument", "visit"},
        multiple=False,
        storageClass="StructuredData",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Keep the optional init-input only when the config asks for it.
        if config.doUseInitIn:
            return
        self.initInputs.remove("initIn")
class PatchConnections(PipelineTaskConnections, dimensions={"skymap", "tract"}):
    """Connections for `PatchTask`, whose quanta are (skymap, tract).

    Per-patch datasets are gathered into lists (``multiple=True``); the
    tract-level prerequisite ``b`` is removed from ``prerequisiteInputs``
    when ``config.doUseB`` is false.
    """

    a = connectionTypes.Input(
        name="PatchA",
        dimensions={"skymap", "tract", "patch"},
        multiple=True,
        storageClass="StructuredData",
    )
    b = connectionTypes.PrerequisiteInput(
        name="PatchB",
        dimensions={"skymap", "tract"},
        multiple=False,
        storageClass="StructuredData",
    )
    initOutA = connectionTypes.InitOutput(
        name="PatchInitOutA",
        multiple=False,
        storageClass="StructuredData",
    )
    initOutB = connectionTypes.InitOutput(
        name="PatchInitOutB",
        multiple=False,
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        name="PatchOut",
        dimensions={"skymap", "tract", "patch"},
        multiple=True,
        storageClass="StructuredData",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Drop the optional prerequisite when the task will not use it.
        if config.doUseB:
            return
        self.prerequisiteInputs.remove("b")
class SkyPixConnections(PipelineTaskConnections, dimensions={"skypix"}):
    """Connections for `SkyPixTask`, declared with the generic ``skypix``
    dimension rather than a concrete pixelization.
    """

    a = connectionTypes.Input(
        name="PixA",
        dimensions={"skypix"},
        storageClass="StructuredData",
    )
    out = connectionTypes.Output(
        name="PixOut",
        dimensions={"skypix"},
        storageClass="StructuredData",
    )
class VisitConfig(PipelineTaskConfig, pipelineConnections=VisitConnections):
    """Config for `VisitTask`; ``doUseInitIn`` toggles the init-input."""

    doUseInitIn = lsst.pex.config.Field(dtype=bool, default=False, doc="")
class PatchConfig(PipelineTaskConfig, pipelineConnections=PatchConnections):
    """Config for `PatchTask`; ``doUseB`` toggles the ``b`` prerequisite."""

    doUseB = lsst.pex.config.Field(dtype=bool, default=True, doc="")
class SkyPixConfig(PipelineTaskConfig, pipelineConnections=SkyPixConnections):
    """Config for `SkyPixTask`; it adds no fields of its own."""
class VisitTask(PipelineTask):
    """Toy per-visit task combining two `MetricsExample` inputs.

    ``outA`` holds ``a.data + b.data`` and ``outB`` holds
    ``a.data * max(b.data)``.
    """

    ConfigClass = VisitConfig
    _DefaultName = "visit"

    def __init__(self, initInputs=None, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        # ``initOut`` is a multiple=True connection, so provide a list.
        self.initOut = [
            butlerTests.MetricsExample(data=values) for values in ([1, 2], [3, 4])
        ]

    def run(self, a, b):
        combined = a.data + b.data
        repeated = a.data * max(b.data)
        return Struct(
            outA=butlerTests.MetricsExample(data=combined),
            outB=butlerTests.MetricsExample(data=repeated),
        )
class PatchTask(PipelineTask):
    """Toy per-tract task operating on a list of per-patch inputs.

    When ``config.doUseB`` is true each output is ``oneA.data + b.data``;
    otherwise the inputs are passed straight through.
    """

    ConfigClass = PatchConfig
    _DefaultName = "patch"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initOutA = butlerTests.MetricsExample(data=[1, 2, 4])
        self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

    def run(self, a, b=None):
        if not self.config.doUseB:
            return Struct(out=a)
        return Struct(out=[butlerTests.MetricsExample(data=(item.data + b.data)) for item in a])
class SkyPixTask(PipelineTask):
    """Toy task that forwards its single input unchanged as ``out``."""

    ConfigClass = SkyPixConfig
    _DefaultName = "skypix"

    def run(self, a):
        result = Struct(out=a)
        return result
class PipelineTaskTestSuite(lsst.utils.tests.TestCase):
    """Tests for the helpers in `lsst.pipe.base.testUtils`.

    A single butler repository is created in ``setUpClass`` and shared by
    all tests; ``setUp`` creates a butler for each test via
    ``butlerTests.makeTestCollection``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Repository should be re-created for each test case, but
        # this has a prohibitive run-time cost at present
        cls.root = tempfile.mkdtemp()
        try:
            cls.repo = butlerTests.makeTestRepo(cls.root)
            butlerTests.addDataIdValue(cls.repo, "instrument", "notACam")
            butlerTests.addDataIdValue(cls.repo, "visit", 101)
            butlerTests.addDataIdValue(cls.repo, "visit", 102)
            butlerTests.addDataIdValue(cls.repo, "skymap", "sky")
            butlerTests.addDataIdValue(cls.repo, "tract", 42)
            butlerTests.addDataIdValue(cls.repo, "patch", 0)
            butlerTests.addDataIdValue(cls.repo, "patch", 1)
            butlerTests.registerMetricsExample(cls.repo)

            # Tuples (not sets) so registration order is deterministic.
            for typeName in ("VisitA", "VisitB", "VisitOutA", "VisitOutB"):
                butlerTests.addDatasetType(cls.repo, typeName, {"instrument", "visit"}, "StructuredData")
            for typeName in ("PatchA", "PatchOut"):
                butlerTests.addDatasetType(cls.repo, typeName, {"skymap", "tract", "patch"}, "StructuredData")
            butlerTests.addDatasetType(cls.repo, "PatchB", {"skymap", "tract"}, "StructuredData")
            for typeName in ("PixA", "PixOut"):
                butlerTests.addDatasetType(cls.repo, typeName, {"htm7"}, "StructuredData")
            butlerTests.addDatasetType(cls.repo, "VisitInitIn", set(), "StructuredData")
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so remove the temporary repository here rather than leak it.
            shutil.rmtree(cls.root, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.root, ignore_errors=True)
        super().tearDownClass()

    def setUp(self):
        super().setUp()
        self.butler = butlerTests.makeTestCollection(self.repo)

    def _makeVisitTestData(self, dataId):
        """Create dummy datasets suitable for VisitTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create.

        Returns
        -------
        datasets : `dict` [`str`, `list`]
            A dictionary keyed by dataset type. Its values are the list of
            integers used to create each dataset. The datasets stored in the
            butler are `lsst.daf.butler.tests.MetricsExample` objects with
            these lists as their ``data`` argument, but the lists are easier
            to manipulate in test code.
        """
        inInit = [4, 2]
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataId)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataId)
        # VisitInitIn is registered with no dimensions, hence the empty ID.
        self.butler.put(butlerTests.MetricsExample(data=inInit), "VisitInitIn", set())
        return {"VisitA": inA, "VisitB": inB, "VisitInitIn": inInit}

    def _makePatchTestData(self, dataId):
        """Create dummy datasets suitable for PatchTask.

        This method updates ``self.butler`` directly.

        Parameters
        ----------
        dataId : any data ID type
            The (shared) ID for the datasets to create. Any patch ID is
            overridden to create multiple datasets.

        Returns
        -------
        datasets : `dict` [`str`, `list` [`tuple` [data ID, `list`]]]
            A dictionary keyed by dataset type. Its values are the data ID
            of each dataset and the list of integers used to create each. The
            datasets stored in the butler are
            `lsst.daf.butler.tests.MetricsExample` objects with these lists as
            their ``data`` argument, but the lists are easier to manipulate
            in test code.
        """
        inA = [1, 2, 3]
        inB = [4, 0, 1]
        datasets = {"PatchA": [], "PatchB": []}
        # A tuple fixes the patch order, which some tests rely on.
        for patch in (0, 1):
            self.butler.put(butlerTests.MetricsExample(data=(inA + [patch])), "PatchA", dataId, patch=patch)
            datasets["PatchA"].append((dict(dataId, patch=patch), inA + [patch]))
        self.butler.put(butlerTests.MetricsExample(data=inB), "PatchB", dataId)
        datasets["PatchB"].append((dataId, inB))
        return datasets

    def testMakeQuantumNoSuchDatatype(self):
        """makeQuantum rejects a connection naming an unregistered dataset type."""
        config = VisitConfig()
        config.connections.a = "Visit"
        task = VisitTask(config=config)

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outA", "outB"}})

    def testMakeQuantumInvalidDimension(self):
        """makeQuantum rejects an input whose dimensions differ from the connection's."""
        config = VisitConfig()
        config.connections.a = "PatchA"
        task = VisitTask(config=config)
        dataIdV = {"instrument": "notACam", "visit": 102}
        dataIdP = {"skymap": "sky", "tract": 42, "patch": 0}

        inA = [1, 2, 3]
        inB = [4, 0, 1]
        self.butler.put(butlerTests.MetricsExample(data=inA), "VisitA", dataIdV)
        self.butler.put(butlerTests.MetricsExample(data=inA), "PatchA", dataIdP)
        self.butler.put(butlerTests.MetricsExample(data=inB), "VisitB", dataIdV)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataIdV, {
                "a": dataIdP,
                "b": dataIdV,
                "outA": dataIdV,
                "outB": dataIdV,
            })

    def testMakeQuantumMissingMultiple(self):
        """makeQuantum rejects a single data ID for a multiple=True connection."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": dict(dataId, patch=0),
                "b": dataId,
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumExtraMultiple(self):
        """makeQuantum rejects a list of data IDs for a multiple=False connection."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        self._makePatchTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {
                "a": [dict(dataId, patch=patch) for patch in (0, 1)],
                "b": [dataId],
                "out": [dict(dataId, patch=patch) for patch in (0, 1)],
            })

    def testMakeQuantumMissingDataId(self):
        """makeQuantum rejects a mapping that omits a required connection."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "outA", "outB"}})
        with self.assertRaises(ValueError):
            makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "b", "outB"}})

    def testMakeQuantumCorruptedDataId(self):
        """makeQuantum rejects a bare data ID in place of the connection mapping."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        self._makeVisitTestData(dataId)

        with self.assertRaises(ValueError):
            # fourth argument should be a mapping keyed by component name
            makeQuantum(task, self.butler, dataId, dataId)

    def testRunTestQuantumVisitWithRun(self):
        """runTestQuantum with mockRun=False writes VisitTask's real outputs."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # The stored outputs show task.run received the correct inputs.
        self.assertTrue(self.butler.datasetExists("VisitOutA", dataId))
        self.assertEqual(self.butler.get("VisitOutA", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] + data["VisitB"])))
        self.assertTrue(self.butler.datasetExists("VisitOutB", dataId))
        self.assertEqual(self.butler.get("VisitOutB", dataId),
                         butlerTests.MetricsExample(data=(data["VisitA"] * max(data["VisitB"]))))

    def testRunTestQuantumPatchWithRun(self):
        """runTestQuantum with mockRun=False writes one PatchOut per patch."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        runTestQuantum(task, self.butler, quantum, mockRun=False)

        # The stored outputs show task.run received the correct inputs.
        inB = data["PatchB"][0][1]
        for dataset in data["PatchA"]:
            patchId = dataset[0]
            self.assertTrue(self.butler.datasetExists("PatchOut", patchId))
            self.assertEqual(self.butler.get("PatchOut", patchId),
                             butlerTests.MetricsExample(data=(dataset[1] + inB)))

    def testRunTestQuantumVisitMockRun(self):
        """With mockRun=True, the returned mock records VisitTask's inputs."""
        task = VisitTask()

        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId,
                              {key: dataId for key in {"a", "b", "outA", "outB"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # The mock lets us check the exact arguments task.run received.
        run.assert_called_once_with(a=butlerTests.MetricsExample(data=data["VisitA"]),
                                    b=butlerTests.MetricsExample(data=data["VisitB"]))

    def testRunTestQuantumPatchMockRun(self):
        """With mockRun=True, the returned mock records PatchTask's inputs."""
        task = PatchTask()

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "b": dataId,
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # The mock lets us check the exact arguments task.run received.
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]],
            b=butlerTests.MetricsExample(data=data["PatchB"][0][1])
        )

    def testRunTestQuantumPatchOptionalInput(self):
        """A removed prerequisite is not passed to run when doUseB is False."""
        config = PatchConfig()
        config.doUseB = False
        task = PatchTask(config=config)

        dataId = {"skymap": "sky", "tract": 42}
        data = self._makePatchTestData(dataId)

        quantum = makeQuantum(task, self.butler, dataId, {
            # Use lists, not sets, to ensure order agrees with test assertion
            "a": [dataset[0] for dataset in data["PatchA"]],
            "out": [dataset[0] for dataset in data["PatchA"]],
        })
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # The mock lets us check the exact arguments task.run received.
        run.assert_called_once_with(
            a=[butlerTests.MetricsExample(data=dataset[1]) for dataset in data["PatchA"]]
        )

    def testAssertValidOutputPass(self):
        """assertValidOutput accepts VisitTask's genuine run result."""
        task = VisitTask()

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        # should not throw
        assertValidOutput(task, result)

    def testAssertValidOutputMissing(self):
        """assertValidOutput fails when an output connection is absent."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=a)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputSingle(self):
        """assertValidOutput fails when a multiple=True output is a single object."""
        task = PatchTask()

        def run(a, b):
            return Struct(out=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run([inA], inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidOutputMultiple(self):
        """assertValidOutput fails when a multiple=False output is a list."""
        task = VisitTask()

        def run(a, b):
            return Struct(outA=[a], outB=b)
        task.run = run

        inA = butlerTests.MetricsExample(data=[1, 2, 3])
        inB = butlerTests.MetricsExample(data=[4, 0, 1])
        result = task.run(inA, inB)

        with self.assertRaises(AssertionError):
            assertValidOutput(task, result)

    def testAssertValidInitOutputPass(self):
        """assertValidInitOutput accepts correctly constructed tasks."""
        task = VisitTask()
        # should not throw
        assertValidInitOutput(task)

        task = PatchTask()
        # should not throw
        assertValidInitOutput(task)

    def testAssertValidInitOutputMissing(self):
        """assertValidInitOutput fails when init-output attributes are unset."""
        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                # do not set fields

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputSingle(self):
        """assertValidInitOutput fails when a multiple=True init-output is not a list."""
        class BadVisitTask(VisitTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass VisitTask constructor
                self.initOut = butlerTests.MetricsExample(data=[1, 2])

        task = BadVisitTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testAssertValidInitOutputMultiple(self):
        """assertValidInitOutput fails when a multiple=False init-output is a list."""
        class BadPatchTask(PatchTask):
            def __init__(self, **kwargs):
                PipelineTask.__init__(self, **kwargs)  # Bypass PatchTask constructor
                self.initOutA = [butlerTests.MetricsExample(data=[1, 2, 4])]
                self.initOutB = butlerTests.MetricsExample(data=[1, 2, 3])

        task = BadPatchTask()

        with self.assertRaises(AssertionError):
            assertValidInitOutput(task)

    def testGetInitInputs(self):
        """getInitInputs returns only the init-inputs the config enables."""
        dataId = {"instrument": "notACam", "visit": 102}
        data = self._makeVisitTestData(dataId)

        self.assertEqual(getInitInputs(self.butler, VisitConfig()), {})

        config = VisitConfig()
        config.doUseInitIn = True
        self.assertEqual(getInitInputs(self.butler, config),
                         {"initIn": butlerTests.MetricsExample(data=data["VisitInitIn"])})

    def testSkypixHandling(self):
        """A ``skypix`` connection works with data stored under ``htm7``."""
        task = SkyPixTask()

        dataId = {"htm7": 157227}  # connection declares skypix, but Butler uses htm7
        data = butlerTests.MetricsExample(data=[1, 2, 3])
        self.butler.put(data, "PixA", dataId)

        quantum = makeQuantum(task, self.butler, dataId, {key: dataId for key in {"a", "out"}})
        run = runTestQuantum(task, self.butler, quantum, mockRun=True)

        # PixA dataset should have been retrieved by runTestQuantum
        run.assert_called_once_with(a=data)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory/resource-leak checks inherited unchanged from `lsst.utils.tests`."""
def setup_module(module):
    """Module-level setup hook; initializes `lsst.utils.tests`."""
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Initialize the LSST test utilities before handing off to unittest.
    lsst.utils.tests.init()
    unittest.main()