Coverage for tests/qg_test_utils.py: 54% (68 statements)

coverage.py v7.3.3, created at 2023-12-20 17:34 +0000

# This file is part of ctrl_bps.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""QuantumGraph-related utilities to support ctrl_bps testing."""

# Not actually running Quantum, so there is no need to override the 'run' method.
# pylint: disable=abstract-method

# Many dummy classes for testing.
# pylint: disable=missing-class-docstring

import lsst.pipe.base.connectionTypes as cT
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
from lsst.pex.config import Field
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, QuantumGraph, TaskDef
from lsst.utils.introspection import get_full_type_name

METADATA = {"D1": [1, 2, 3]}

# For each dummy task, create a Connections, Config, and PipelineTask.


class Dummy1Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Connections class used for tests."""

    initOutput = cT.InitOutput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Input", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config class used for testing."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy1PipelineTask(PipelineTask):
    """PipelineTask used for testing."""

    ConfigClass = Dummy1Config


class Dummy2Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Second connections class used for testing."""

    initInput = cT.InitInput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
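    # Note that the initInput and input dataset names above deliberately match
    # Dummy1's outputs; this chaining is what produces the T1 -> T2 edges in
    # the test QuantumGraph (T3 chains off T2 in the same way).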

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config class used for second pipeline task."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy2PipelineTask(PipelineTask):
    """Second test PipelineTask."""

    ConfigClass = Dummy2Config


class Dummy3Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Third connections class used for testing."""

    initInput = cT.InitInput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy3InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy3Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Third config used for testing."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy3PipelineTask(PipelineTask):
    """Third test PipelineTask."""

    ConfigClass = Dummy3Config


# Test if a Task that does not interact with the other Tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Fourth connections class used for testing."""

    input = cT.Input(name="Dummy4Input", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy4Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Fourth config used for testing."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy4PipelineTask(PipelineTask):
    """Fourth test PipelineTask."""

    ConfigClass = Dummy4Config

def make_test_quantum_graph(run: str = "run"):
    """Create a QuantumGraph for unit tests.

    Parameters
    ----------
    run : `str`, optional
        Name of the RUN collection for output datasets.

    Returns
    -------
    qgraph : `lsst.pipe.base.QuantumGraph`
        A test QuantumGraph with the structure shown below; T4 is completely
        independent of the other tasks, and the numbers in parentheses are
        the values of the two dimensions (D1, D2).

        .. code-block:: text

            T1(1,2)    T1(3,4)    T4(1,2)    T4(3,4)
               |          |
            T2(1,2)    T2(3,4)
               |          |
            T3(1,2)    T3(3,4)
    """
    config = Config(
        {
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                },
            },
            "elements": {
                "D1": {
                    "keys": [
                        {
                            "name": "id",
                            "type": "int",
                        }
                    ],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "D2": {
                    "keys": [
                        {
                            "name": "id",
                            "type": "int",
                        }
                    ],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
            },
            "packers": {},
        }
    )

    universe = DimensionUniverse(config=config)
    # Need to make a mapping of TaskDef to a set of quanta.
    quantum_map = {}
    tasks = []
    # Map to keep output/intermediate refs.
    intermediate_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {}
    for task, label in (
        (Dummy1PipelineTask, "T1"),
        (Dummy2PipelineTask, "T2"),
        (Dummy3PipelineTask, "T3"),
        (Dummy4PipelineTask, "T4"),
    ):
        task_def = TaskDef(get_full_type_name(task), task.ConfigClass(), task, label)
        tasks.append(task_def)
        quantum_set = set()
        for dim1, dim2 in ((1, 2), (3, 4)):
            if task_def.connections.initInputs:
                init_init_ds_type = DatasetType(
                    task_def.connections.initInput.name,
                    (),
                    storageClass=task_def.connections.initInput.storageClass,
                    universe=universe,
                )
                init_refs = [DatasetRef(init_init_ds_type, DataCoordinate.make_empty(universe), run=run)]
            else:
                init_refs = None
            input_ds_type = DatasetType(
                task_def.connections.input.name,
                task_def.connections.input.dimensions,
                storageClass=task_def.connections.input.storageClass,
                universe=universe,
            )
            data_id = DataCoordinate.standardize({"D1": dim1, "D2": dim2}, universe=universe)
            if ref := intermediate_refs.get((input_ds_type, data_id)):
                input_refs = [ref]
            else:
                input_refs = [DatasetRef(input_ds_type, data_id, run=run)]
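            # A hit in intermediate_refs means this input was already created
            # as an upstream task's output, so the very same DatasetRef is
            # reused and the producer and consumer quanta end up linked in
            # the graph; otherwise it is treated as an overall input.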

            output_ds_type = DatasetType(
                task_def.connections.output.name,
                task_def.connections.output.dimensions,
                storageClass=task_def.connections.output.storageClass,
                universe=universe,
            )
            ref = DatasetRef(output_ds_type, data_id, run=run)
            intermediate_refs[(output_ds_type, data_id)] = ref
            output_refs = [ref]
            quantum_set.add(
                Quantum(
                    taskName=task.__qualname__,
                    dataId=data_id,
                    taskClass=task,
                    initInputs=init_refs,
                    inputs={input_ds_type: input_refs},
                    outputs={output_ds_type: output_refs},
                )
            )
        quantum_map[task_def] = quantum_set
    qgraph = QuantumGraph(quantum_map, metadata=METADATA)

    return qgraph
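

if __name__ == "__main__":
    # Minimal usage sketch, not part of the ctrl_bps test suite.  It assumes
    # the lsst.pipe.base.QuantumGraph API used below (``len()`` and the
    # ``metadata`` property), which may differ between pipe_base versions.
    qgraph = make_test_quantum_graph(run="demo_run")
    print(f"Number of quanta: {len(qgraph)}")  # two quanta for each of the four dummy tasks
    print(f"Graph metadata: {dict(qgraph.metadata)}")  # should include the METADATA entries above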