Coverage for tests / test_adjust_all_quanta.py: 32%
65 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:20 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:20 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ()
32import operator
33import unittest
34from collections import defaultdict
36import lsst.pipe.base.connectionTypes as cT
37from lsst.daf.butler import Butler, DataCoordinate
38from lsst.daf.butler.tests.utils import create_populated_sqlite_registry
39from lsst.pipe.base import (
40 PipelineGraph,
41 PipelineTask,
42 PipelineTaskConfig,
43 PipelineTaskConnections,
44 QuantaAdjuster,
45)
46from lsst.pipe.base.all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder
47from lsst.resources import ResourcePath
class GroupTestConnections(PipelineTaskConnections, dimensions=("detector",)):
    """Connections for a task whose quanta read in all of the biases for all
    detectors with the same purpose, use flat-field exposures as prerequisites,
    and (theoretically) write both a single summary output for all inputs and
    another output for each input. The data IDs of the quanta and the summary
    outputs use the detector with the lowest ID in that group.

    Because the task's dimensions are just ``detector``, ordinary graph
    generation makes one quantum per detector; `adjust_all_quanta` then
    collapses those into one quantum per detector purpose.
    """

    input_group = cT.Input(
        "bias",
        "Exposure",
        multiple=True,
        dimensions=("detector",),
        isCalibration=True,
    )
    prereq_input_group = cT.PrerequisiteInput(
        "flat",
        "Exposure",
        multiple=True,
        dimensions=("detector", "physical_filter", "band"),
        isCalibration=True,
    )
    output_group = cT.Output(
        "bias_stuff",
        "StructuredDataDict",
        multiple=True,
        dimensions=("detector",),
    )
    single_output = cT.Output(
        "bias_summary",
        "StructuredDataDict",
        multiple=False,
        dimensions=("detector",),
    )

    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
        """Merge all quanta whose detectors share a purpose into one quantum.

        Parameters
        ----------
        adjuster : `QuantaAdjuster`
            Helper used to inspect and modify the quanta generated for this
            task.

        Notes
        -----
        For each detector purpose, the quantum with the lowest detector ID is
        kept. The ``input_group`` and ``prereq_input_group`` inputs of every
        other quantum in that group are added to the kept quantum, their
        ``output_group`` outputs are moved to it, and the other quanta are
        then removed.
        """
        # Group the quanta by their detector's purpose.  The data IDs are
        # expanded first so that the detector dimension record (which carries
        # the purpose) is available.
        quanta_by_detector_purpose: defaultdict[str, list[DataCoordinate]] = defaultdict(list)
        for quantum_data_id in adjuster.iter_data_ids():
            expanded_data_id = adjuster.expand_quantum_data_id(quantum_data_id)
            purpose = expanded_data_id.records["detector"].purpose
            quanta_by_detector_purpose[purpose].append(expanded_data_id)
        # Within each group, keep only the one with the lowest detector ID,
        # while transferring the inputs and outputs of the others to that
        # quantum.
        for data_id_group in quanta_by_detector_purpose.values():
            data_id_group.sort(key=operator.itemgetter("detector"))
            keep, *drop = data_id_group
            for drop_data_id in drop:
                for input_data_id in adjuster.get_inputs(drop_data_id)["input_group"]:
                    adjuster.add_input(keep, "input_group", input_data_id)
                for input_uuid in adjuster.get_prerequisite_inputs(drop_data_id)["prereq_input_group"]:
                    adjuster.add_prerequisite_input(keep, "prereq_input_group", input_uuid)
                # Snapshot the outputs before moving them: a "move" presumably
                # removes each output from the dropped quantum, so we must not
                # iterate over a collection that move_output may be mutating.
                for output_data_id in list(adjuster.get_outputs(drop_data_id)["output_group"]):
                    adjuster.move_output(keep, "output_group", output_data_id)
                # The dropped quantum's single_output ("bias_summary") is not
                # transferred, so only the kept quantum's summary survives.
                adjuster.remove_quantum(drop_data_id)
class GroupTestConfig(PipelineTaskConfig, pipelineConnections=GroupTestConnections):
    """Configuration for `GroupTestTask`; it adds no fields beyond those
    provided by `PipelineTaskConfig`.
    """
class GroupTestTask(PipelineTask):
    """Test task whose connections (`GroupTestConnections`) merge quanta by
    detector purpose.  It is only used here for quantum-graph generation and
    defines no ``run`` method.
    """

    ConfigClass = GroupTestConfig
class AdjustAllQuantaTestCase(unittest.TestCase):
    """Tests for the `PipelineTaskConnections.adjust_all_quanta` hook in
    quantum-graph generation.
    """

    @staticmethod
    def make_butler() -> Butler:
        """Return a butler backed by a new SQLite registry populated from the
        daf_butler ``base.yaml`` and ``datasets.yaml`` test data files.
        """
        data_root = ResourcePath("resource://lsst.daf.butler/tests/registry_data", forceDirectory=True)
        yaml_files = (data_root.join(name) for name in ("base.yaml", "datasets.yaml"))
        return create_populated_sqlite_registry(*yaml_files)

    def test_adjust_all_quanta(self) -> None:
        """Build a quantum graph for `GroupTestTask`, whose connections
        implement ``adjust_all_quanta``, and check that its quanta were merged
        by detector purpose.
        """
        butler = self.make_butler()
        self.enterContext(butler)
        pipeline_graph = PipelineGraph(universe=butler.dimensions)
        pipeline_graph.add_task("grouper", GroupTestTask)
        collections = ["imported_g", "imported_r"]
        builder = AllDimensionsQuantumGraphBuilder(
            pipeline_graph,
            butler,
            input_collections=collections,
            output_run="irrelevant",
        )
        graph = builder.finish(attach_datastore_records=False).assemble()
        execution_quanta = graph.build_execution_quanta(task_label="grouper")
        quanta = {quantum.dataId["detector"]: quantum for quantum in execution_quanta.values()}
        # This test camera (defined in daf_butler test data) has 4 detectors;
        # 1-3 have purpose=SCIENCE, and 4 has purpose=WAVEFRONT, so only the
        # lowest-ID detector of each purpose should still have a quantum.
        self.assertEqual(quanta.keys(), {1, 4})
        # Rows are (kept detector, detector purpose, number of biases, number
        # of flats).  Each bias input yields one "bias_stuff" output, so the
        # bias count doubles as the expected "bias_stuff" count.
        expectations = (
            (1, "SCIENCE", 3, 5),
            (4, "WAVEFRONT", 1, 2),
        )
        for detector, purpose, n_biases, n_flats in expectations:
            quantum = quanta[detector]
            where = f"detector.purpose = '{purpose}'"
            self.assertEqual(len(quantum.inputs["bias"]), n_biases)
            self.assertEqual(len(quantum.inputs["flat"]), n_flats)
            self.assertEqual(len(quantum.outputs["bias_stuff"]), n_biases)
            self.assertCountEqual(
                quantum.inputs["bias"],
                butler.query_datasets("bias", collections=collections, where=where),
            )
            self.assertCountEqual(
                quantum.inputs["flat"],
                butler.query_datasets("flat", collections=collections, where=where),
            )
            summaries = quantum.outputs["bias_summary"]
            self.assertEqual(len(summaries), 1)
            self.assertEqual(summaries[0].dataId["detector"], detector)