Coverage for tests / test_adjust_all_quanta.py: 32%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:59 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = () 

31 

32import operator 

33import unittest 

34from collections import defaultdict 

35 

36import lsst.pipe.base.connectionTypes as cT 

37from lsst.daf.butler import Butler, DataCoordinate 

38from lsst.daf.butler.tests.utils import create_populated_sqlite_registry 

39from lsst.pipe.base import ( 

40 PipelineGraph, 

41 PipelineTask, 

42 PipelineTaskConfig, 

43 PipelineTaskConnections, 

44 QuantaAdjuster, 

45) 

46from lsst.pipe.base.all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder 

47from lsst.resources import ResourcePath 

48 

49 

class GroupTestConnections(PipelineTaskConnections, dimensions=("detector",)):
    """Connections for a task whose quanta read in all of the biases for all
    detectors with the same purpose, use flat-field exposures as
    prerequisites, and (theoretically) write both a single summary output for
    all inputs and another output for each input.

    The data IDs of the quanta and of the summary outputs use the detector
    with the lowest ID in each group.  Quanta start out with one data ID per
    detector (the connections' dimensions are just ``detector``);
    `adjust_all_quanta` then merges them by detector purpose.
    """

    input_group = cT.Input(
        "bias",
        "Exposure",
        multiple=True,
        dimensions=("detector",),
        isCalibration=True,
    )
    prereq_input_group = cT.PrerequisiteInput(
        "flat",
        "Exposure",
        multiple=True,
        dimensions=("detector", "physical_filter", "band"),
        isCalibration=True,
    )
    output_group = cT.Output(
        "bias_stuff",
        "StructuredDataDict",
        multiple=True,
        dimensions=("detector",),
    )
    single_output = cT.Output(
        "bias_summary",
        "StructuredDataDict",
        multiple=False,
        dimensions=("detector",),
    )

    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
        """Merge all quanta whose detectors share a purpose into one quantum.

        Parameters
        ----------
        adjuster : `QuantaAdjuster`
            Helper used to inspect and modify this task's quanta.

        Notes
        -----
        Within each purpose group, the quantum with the lowest detector ID is
        kept.  The ``input_group`` and ``prereq_input_group`` inputs of every
        other quantum in the group are added to it, and their
        ``output_group`` outputs are moved to it.  The other quanta are then
        removed.  Only ``output_group`` outputs are moved; the
        ``single_output`` of a removed quantum is not transferred.
        """
        # Group the quanta by their detector's purpose.  Each data ID is
        # expanded first so that its detector dimension record can be read.
        quanta_by_detector_purpose: defaultdict[str, list[DataCoordinate]] = defaultdict(list)
        for unexpanded_data_id in adjuster.iter_data_ids():
            quantum_data_id = adjuster.expand_quantum_data_id(unexpanded_data_id)
            purpose = quantum_data_id.records["detector"].purpose
            quanta_by_detector_purpose[purpose].append(quantum_data_id)
        # Within each group, keep only the one with the lowest detector ID,
        # while transferring the inputs and outputs of the others to that
        # quantum.
        for data_id_group in quanta_by_detector_purpose.values():
            # Sorting by detector ID puts the quantum to keep first.
            data_id_group.sort(key=operator.itemgetter("detector"))
            keep, *drop = data_id_group
            for drop_data_id in drop:
                for input_data_id in adjuster.get_inputs(drop_data_id)["input_group"]:
                    adjuster.add_input(keep, "input_group", input_data_id)
                for input_uuid in adjuster.get_prerequisite_inputs(drop_data_id)["prereq_input_group"]:
                    adjuster.add_prerequisite_input(keep, "prereq_input_group", input_uuid)
                for output_data_id in adjuster.get_outputs(drop_data_id)["output_group"]:
                    adjuster.move_output(keep, "output_group", output_data_id)
                adjuster.remove_quantum(drop_data_id)

105 

106 

class GroupTestConfig(PipelineTaskConfig, pipelineConnections=GroupTestConnections):
    """Configuration for `GroupTestTask`, bound to `GroupTestConnections`.

    It declares no configuration fields of its own.
    """

109 

110 

class GroupTestTask(PipelineTask):
    """Minimal test task that attaches `GroupTestConnections` (through
    `GroupTestConfig`) to a pipeline.

    It defines no ``run`` method; this module only uses it to exercise
    quantum-graph generation.
    """

    ConfigClass = GroupTestConfig

113 

114 

class AdjustAllQuantaTestCase(unittest.TestCase):
    """Tests for the `PipelineTaskConnections.adjust_all_quanta` hook in
    quantum-graph generation.
    """

    @staticmethod
    def make_butler() -> Butler:
        """Return a butler whose new SQLite registry is populated from the
        ``base.yaml`` and ``datasets.yaml`` test data shipped with daf_butler.
        """
        data_root = ResourcePath("resource://lsst.daf.butler/tests/registry_data", forceDirectory=True)
        return create_populated_sqlite_registry(
            *[data_root.join(filename) for filename in ("base.yaml", "datasets.yaml")]
        )

    def test_adjust_all_quanta(self) -> None:
        """Build a quantum graph for a task that implements the
        adjust_all_quanta hook, and check that it works as expected.
        """
        butler = self.make_butler()
        self.enterContext(butler)
        pipeline_graph = PipelineGraph(universe=butler.dimensions)
        pipeline_graph.add_task("grouper", GroupTestTask)
        collections = ["imported_g", "imported_r"]
        qgb = AllDimensionsQuantumGraphBuilder(
            pipeline_graph,
            butler,
            input_collections=collections,
            output_run="irrelevant",
        )
        qg = qgb.finish(attach_datastore_records=False).assemble()
        quanta = {
            quantum.dataId["detector"]: quantum
            for quantum in qg.build_execution_quanta(task_label="grouper").values()
        }
        # This test camera (defined in daf_butler test data) has 4 detectors;
        # 1-3 have purpose=SCIENCE, and 4 has purpose=WAVEFRONT.  Each purpose
        # group should collapse into the quantum of its lowest-ID detector.
        self.assertEqual(quanta.keys(), {1, 4})
        # Surviving detector -> (purpose, expected bias count, expected flat
        # count).  There is one ``bias_stuff`` output per bias input.
        expectations = {
            1: ("SCIENCE", 3, 5),
            4: ("WAVEFRONT", 1, 2),
        }
        for detector, (purpose, n_biases, n_flats) in expectations.items():
            with self.subTest(detector=detector, purpose=purpose):
                quantum = quanta[detector]
                where = f"detector.purpose = '{purpose}'"
                self.assertEqual(len(quantum.inputs["bias"]), n_biases)
                self.assertEqual(len(quantum.inputs["flat"]), n_flats)
                self.assertEqual(len(quantum.outputs["bias_stuff"]), n_biases)
                self.assertCountEqual(
                    quantum.inputs["bias"],
                    butler.query_datasets("bias", collections=collections, where=where),
                )
                self.assertCountEqual(
                    quantum.inputs["flat"],
                    butler.query_datasets("flat", collections=collections, where=where),
                )
                # Only the kept quantum's own summary output should remain.
                self.assertEqual(len(quantum.outputs["bias_summary"]), 1)
                self.assertEqual(quantum.outputs["bias_summary"][0].dataId["detector"], detector)