Coverage for tests/qg_test_utils.py: 54%

68 statements  

coverage.py v7.3.1, created at 2023-10-02 08:09 +0000

# This file is part of ctrl_bps.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""QuantumGraph-related utilities to support ctrl_bps testing."""

# Not actually running the Quantum, so no need to override the 'run' method.
# pylint: disable=abstract-method

# Many dummy classes for testing.
# pylint: disable=missing-class-docstring

import lsst.pipe.base.connectionTypes as cT
from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
from lsst.pex.config import Field
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, QuantumGraph, TaskDef
from lsst.utils.introspection import get_full_type_name

METADATA = {"D1": [1, 2, 3]}
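# Attached as graph-level metadata when make_test_quantum_graph builds the
# QuantumGraph below.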

# For each dummy task, create a Connections, Config, and PipelineTask


class Dummy1Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Connections class used for tests."""

    initOutput = cT.InitOutput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Input", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config class used for testing."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy1PipelineTask(PipelineTask):
    """PipelineTask used for testing."""

    ConfigClass = Dummy1Config


class Dummy2Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Second connections class used for testing."""

    initInput = cT.InitInput(name="Dummy1InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy1Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config class used for second pipeline task."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy2PipelineTask(PipelineTask):
    """Second test PipelineTask."""

    ConfigClass = Dummy2Config


class Dummy3Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Third connections class used for testing."""

    initInput = cT.InitInput(name="Dummy2InitOutput", storageClass="ExposureF", doc="n/a")
    initOutput = cT.InitOutput(name="Dummy3InitOutput", storageClass="ExposureF", doc="n/a")
    input = cT.Input(name="Dummy2Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy3Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Third config used for testing."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy3PipelineTask(PipelineTask):
    """Third test PipelineTask."""

    ConfigClass = Dummy3Config


# Test if a Task that does not interact with the other Tasks works fine in
# the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("D1", "D2")):
    """Fourth connections class used for testing."""

    input = cT.Input(name="Dummy4Input", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))
    output = cT.Output(name="Dummy4Output", storageClass="ExposureF", doc="n/a", dimensions=("D1", "D2"))


class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Fourth config used for testing."""

    conf1 = Field(dtype=int, default=1, doc="dummy config")


class Dummy4PipelineTask(PipelineTask):
    """Fourth test PipelineTask."""

    ConfigClass = Dummy4Config
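
# The dummy tasks above chain through their dataset names: Dummy1's outputs
# ("Dummy1InitOutput", "Dummy1Output") are Dummy2's inputs, Dummy2's outputs
# feed Dummy3, and Dummy4 shares no datasets with the others.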


def make_test_quantum_graph(run: str = "run"):
    """Create a QuantumGraph for unit tests.

    Parameters
    ----------
    run : `str`, optional
        Name of the RUN collection for output datasets.

    Returns
    -------
    qgraph : `lsst.pipe.base.QuantumGraph`
        A test QuantumGraph looking like the following
        (DummyTask4 is completely independent).

        Numbers in parentheses are the values for the two dimensions
        (D1, D2)::

            T1(1,2)   T1(3,4)   T4(1,2)   T4(3,4)
               |          |
            T2(1,2)   T2(3,4)
               |          |
            T3(1,2)   T3(3,4)
    """
    config = Config(
        {
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                },
            },
            "elements": {
                "D1": {
                    "keys": [
                        {
                            "name": "id",
                            "type": "int",
                        }
                    ],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "D2": {
                    "keys": [
                        {
                            "name": "id",
                            "type": "int",
                        }
                    ],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
            },
            "packers": {},
        }
    )

    universe = DimensionUniverse(config=config)
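    # The config above defines a minimal dimension universe providing the two
    # test dimensions D1 and D2 used by the dummy connections (plus the
    # required skypix system).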

192 # need to make a mapping of TaskDef to set of quantum 

193 quantum_map = {} 

194 tasks = [] 

195 # Map to keep output/intermediate refs. 

196 intermediate_refs: dict[tuple[DatasetType, DataCoordinate], DatasetRef] = {} 

197 for task, label in ( 

198 (Dummy1PipelineTask, "T1"), 

199 (Dummy2PipelineTask, "T2"), 

200 (Dummy3PipelineTask, "T3"), 

201 (Dummy4PipelineTask, "T4"), 

202 ): 

203 task_def = TaskDef(get_full_type_name(task), task.ConfigClass(), task, label) 

204 tasks.append(task_def) 

205 quantum_set = set() 

206 for dim1, dim2 in ((1, 2), (3, 4)): 

207 if task_def.connections.initInputs: 

208 init_init_ds_type = DatasetType( 

209 task_def.connections.initInput.name, 

210 (), 

211 storageClass=task_def.connections.initInput.storageClass, 

212 universe=universe, 

213 ) 

214 init_refs = [DatasetRef(init_init_ds_type, DataCoordinate.makeEmpty(universe), run=run)] 

215 else: 

216 init_refs = None 

217 input_ds_type = DatasetType( 

218 task_def.connections.input.name, 

219 task_def.connections.input.dimensions, 

220 storageClass=task_def.connections.input.storageClass, 

221 universe=universe, 

222 ) 

223 data_id = DataCoordinate.standardize({"D1": dim1, "D2": dim2}, universe=universe) 

224 if ref := intermediate_refs.get((input_ds_type, data_id)): 

225 input_refs = [ref] 

226 else: 

227 input_refs = [DatasetRef(input_ds_type, data_id, run=run)] 

228 output_ds_type = DatasetType( 

229 task_def.connections.output.name, 

230 task_def.connections.output.dimensions, 

231 storageClass=task_def.connections.output.storageClass, 

232 universe=universe, 

233 ) 

234 ref = DatasetRef(output_ds_type, data_id, run=run) 

235 intermediate_refs[(output_ds_type, data_id)] = ref 

236 output_refs = [ref] 
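            # The output ref was stored in intermediate_refs above so that the
            # downstream task's quantum at this data ID consumes exactly the
            # same ref via the lookup earlier in this loop.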

            quantum_set.add(
                Quantum(
                    taskName=task.__qualname__,
                    dataId=data_id,
                    taskClass=task,
                    initInputs=init_refs,
                    inputs={input_ds_type: input_refs},
                    outputs={output_ds_type: output_refs},
                )
            )
        quantum_map[task_def] = quantum_set
    qgraph = QuantumGraph(quantum_map, metadata=METADATA)

    return qgraph
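

# A minimal usage sketch, assuming QuantumGraph exposes ``__len__``,
# ``taskGraph``, and ``metadata`` as in recent lsst.pipe.base releases:
if __name__ == "__main__":
    qgraph = make_test_quantum_graph(run="test_run")
    # Four dummy tasks, each with two quanta (one per (D1, D2) data ID).
    print("tasks:", len(qgraph.taskGraph))
    print("quanta:", len(qgraph))
    print("metadata:", dict(qgraph.metadata))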