Coverage for python/lsst/source/injection/utils/make_injection_pipeline.py: 5%

103 statements

coverage.py v7.3.2, created at 2023-12-01 14:09 +0000

# This file is part of source_injection.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["make_injection_pipeline"]

import logging

from lsst.analysis.tools.interfaces import AnalysisPipelineTask
from lsst.pipe.base import Pipeline


def _get_dataset_type_names(conns, fields):
    """Return the set of dataset type names for the given connection fields."""
    dataset_type_names = set()
    for field in fields:
        dataset_type_names.add(getattr(conns, field).name)
    return dataset_type_names


def _parse_config_override(config_override: str) -> tuple[str, str, str]:
    """Parse a config override string into a label, a key and a value.

    Parameters
    ----------
    config_override : `str`
        Config override string to parse, in the format "label:key=value".

    Returns
    -------
    label : `str`
        Label to override.
    key : `str`
        Key to override.
    value : `str`
        Value to override.

    Raises
    ------
    TypeError
        If the config override string cannot be parsed.
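
    Examples
    --------
    Splitting is on the first ":" and the first "="; the label and key below
    are hypothetical:

    >>> _parse_config_override("inject_coadd:process_all_data_ids=True")
    ('inject_coadd', 'process_all_data_ids', 'True')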

    """
    try:
        label, keyvalue = config_override.split(":", 1)
    except ValueError:
        raise TypeError(
            f"Unrecognized syntax for option 'config': '{config_override}' (does not match pattern "
            "(?P<label>.+):(?P<value>.+=.+))"
        ) from None
    try:
        key, value = keyvalue.split("=", 1)
    except ValueError as e:
        raise TypeError(
            f"Could not parse key-value pair '{config_override}' using separator '=': {e}"
        ) from None
    return label, key, value


def make_injection_pipeline(
    dataset_type_name: str,
    reference_pipeline: Pipeline | str,
    injection_pipeline: Pipeline | str | None = None,
    exclude_subsets: bool = False,
    excluded_tasks: set[str]
    | str = {
        "jointcal",
        "gbdesAstrometricFit",
        "fgcmBuildFromIsolatedStars",
        "fgcmFitCycle",
        "fgcmOutputProducts",
    },
    prefix: str = "injected_",
    instrument: str | None = None,
    config: str | list[str] | None = None,
    log_level: int = logging.INFO,
) -> Pipeline:
    """Make an expanded source injection pipeline.

    This function takes a reference pipeline definition file in YAML format
    and prefixes all post-injection dataset type names with the given
    injection prefix. If an optional injection pipeline definition YAML file
    is also provided, the injection task will be merged into the pipeline.

    Unless explicitly excluded, all subsets from the reference pipeline that
    contain the task generating the injection dataset type will also be
    updated to include the injection task.

    Parameters
    ----------
    dataset_type_name : `str`
        Name of the dataset type being injected into.
    reference_pipeline : `lsst.pipe.base.Pipeline` | `str`
        Location of a reference pipeline definition YAML file.
    injection_pipeline : `lsst.pipe.base.Pipeline` | `str`, optional
        Location of an injection pipeline definition YAML file stub. If not
        provided, an attempt to infer the injection pipeline will be made
        based on the injected dataset type name.
    exclude_subsets : `bool`, optional
        If True, do not update pipeline subsets to include the injection task.
    excluded_tasks : `set` [`str`] | `str`, optional
        Set or comma-separated string of task labels to exclude from the
        injection pipeline.
    prefix : `str`, optional
        Prefix to prepend to each affected post-injection dataset type name.
    instrument : `str`, optional
        Add instrument overrides. Must be a fully qualified class name.
    config : `str` | `list` [`str`], optional
        Config override(s) to apply, each in the format "label:key=value".
    log_level : `int`, optional
        The log level to use for logging.

    Returns
    -------
    pipeline : `lsst.pipe.base.Pipeline`
        An expanded source injection pipeline.
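
    Examples
    --------
    A minimal usage sketch; the reference pipeline path is illustrative:

    >>> pipeline = make_injection_pipeline(
    ...     dataset_type_name="deepCoadd",
    ...     reference_pipeline="/path/to/reference_pipeline.yaml",
    ... )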

    """
    # Instantiate logger.
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level)

    # Load the pipeline and apply config overrides, if supplied.
    if isinstance(reference_pipeline, str):
        pipeline = Pipeline.fromFile(reference_pipeline)
    else:
        pipeline = reference_pipeline
    if config:
        if isinstance(config, str):
            config = [config]
        for conf in config:
            config_label, config_key, config_value = _parse_config_override(conf)
            pipeline.addConfigOverride(config_label, config_key, config_value)

    # Add an instrument override, if provided.
    if instrument:
        pipeline.addInstrument(instrument)

    # Remove all tasks which are not to be included in the injection pipeline.
    if isinstance(excluded_tasks, str):
        excluded_tasks = set(excluded_tasks.split(","))
    not_excluded_tasks = set()
    for task_label in excluded_tasks:
        # First remove tasks from their host subsets, if present.
        try:
            host_subsets = pipeline.findSubsetsWithLabel(task_label)
        except ValueError:
            pass
        else:
            for host_subset in host_subsets:
                pipeline.removeLabelFromSubset(host_subset, task_label)
        # Then remove the task from the pipeline.
        try:
            pipeline.removeTask(task_label)
        except KeyError:
            not_excluded_tasks.add(task_label)
    if len(not_excluded_tasks) > 0:
        grammar = "Task" if len(not_excluded_tasks) == 1 else "Tasks"
        logger.warning(
            "%s marked for exclusion not found in the reference pipeline: %s.",
            grammar,
            ", ".join(sorted(not_excluded_tasks)),
        )

    # Determine the set of dataset type names affected by source injection.
    all_connection_type_names = set()
    injected_types = {dataset_type_name}
    precursor_injection_task_labels = set()
    # Loop over all tasks in the pipeline.
    for taskDef in pipeline.toExpandedPipeline():
        # Add an override for Analysis Tools taskDefs. Connections in
        # Analysis Tools are dynamically assigned, and so cannot be modified
        # in the same way as a static connection. Instead, we add a config
        # override here to the connections.outputName field. This field is
        # prepended to all Analysis Tools connections, and so will prepend
        # the injection prefix to all plot/metric outputs. Further processing
        # of this taskDef is skipped thereafter.
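        # For example, with the default prefix an output dataset type named
        # "objectTableCore_metrics" (a hypothetical Analysis Tools name)
        # would be written as "injected_objectTableCore_metrics".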

        if issubclass(taskDef.taskClass, AnalysisPipelineTask):
            pipeline.addConfigOverride(
                taskDef.label, "connections.outputName", prefix + taskDef.config.connections.outputName
            )
            continue

        conns = taskDef.connections
        input_types = _get_dataset_type_names(conns, conns.initInputs | conns.inputs)
        output_types = _get_dataset_type_names(conns, conns.initOutputs | conns.outputs)
        all_connection_type_names |= input_types | output_types
        # Identify the precursor task; this allows the injection task to be
        # appended to its subsets later on.
        if dataset_type_name in output_types:
            precursor_injection_task_labels.add(taskDef.label)
        # If the task has any injected dataset type names as inputs, add all
        # of its outputs to the set of injected types.
        if len(input_types & injected_types) > 0:
            injected_types |= output_types
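        # For example, if "calexp" is the injected dataset type and a task
        # reads "calexp" and writes "src" (hypothetical names), "src" is also
        # treated as injected and will be prefixed below.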

        # Add the injection prefix to all affected dataset type names.
        for field in conns.initInputs | conns.inputs | conns.initOutputs | conns.outputs:
            if hasattr(taskDef.config.connections.ConnectionsClass, field):
                # If the connection type is not dynamic, modify as usual.
                if (conn_type := getattr(conns, field).name) in injected_types:
                    pipeline.addConfigOverride(taskDef.label, "connections." + field, prefix + conn_type)
            else:
                # Log a warning if the connection type is dynamic.
                logger.warning(
                    "Dynamic connection %s in task %s is not supported here. This connection will "
                    "neither be modified nor merged into the output injection pipeline.",
                    field,
                    taskDef.label,
                )
    # Raise if the injected dataset type does not exist in the pipeline.
    if dataset_type_name not in all_connection_type_names:
        raise RuntimeError(
            f"Dataset type '{dataset_type_name}' not found in the reference pipeline; "
            "no connection type edits to be made."
        )

    # Attempt to infer the injection pipeline from the dataset type name.
    if not injection_pipeline:
        match dataset_type_name:
            case "postISRCCD":
                injection_pipeline = "$SOURCE_INJECTION_DIR/pipelines/inject_exposure.yaml"
            case "icExp" | "calexp":
                injection_pipeline = "$SOURCE_INJECTION_DIR/pipelines/inject_visit.yaml"
            case "deepCoadd" | "deepCoadd_calexp" | "goodSeeingCoadd":
                injection_pipeline = "$SOURCE_INJECTION_DIR/pipelines/inject_coadd.yaml"
            case _:
                # Log a warning rather than raising, as the user may wish to
                # edit connection names without merging an injection pipeline.
                logger.warning(
                    "Unable to infer injection pipeline stub from dataset type name '%s' and none was "
                    "provided. No injection pipeline will be merged into the output pipeline.",
                    dataset_type_name,
                )
        if injection_pipeline:
            logger.info(
                "Injected dataset type '%s' used to infer injection pipeline: %s",
                dataset_type_name,
                injection_pipeline,
            )

    # Merge the injection pipeline into the modified pipeline, if provided.
    if injection_pipeline:
        if isinstance(injection_pipeline, str):
            pipeline2 = Pipeline.fromFile(injection_pipeline)
        else:
            pipeline2 = injection_pipeline
        if len(pipeline2) != 1:
            raise RuntimeError(
                f"The injection pipeline contains {len(pipeline2)} tasks; only one task is allowed."
            )
        pipeline.mergePipeline(pipeline2)
        # Loop over all injection tasks and modify the connection names.
        for injection_taskDef in pipeline2.toExpandedPipeline():
            conns = injection_taskDef.connections
            pipeline.addConfigOverride(
                injection_taskDef.label, "connections.input_exposure", dataset_type_name
            )
            pipeline.addConfigOverride(
                injection_taskDef.label, "connections.output_exposure", prefix + dataset_type_name
            )
            # Optionally update subsets to include the injection task.
            if not exclude_subsets:
                for label in precursor_injection_task_labels:
                    precursor_subsets = pipeline.findSubsetsWithLabel(label)
                    for subset in precursor_subsets:
                        pipeline.addLabelToSubset(subset, injection_taskDef.label)

    logger.info("Made an injection pipeline containing %d tasks.", len(pipeline))
    return pipeline