Coverage for python/lsst/analysis/tools/tasks/base.py: 18%

129 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-16 02:54 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23"""Base class implementation for the classes needed in creating `PipelineTasks` 

24which execute `AnalysisTools`. 

25 

26The classes defined in this module have all the required behaviors for 

27defining, introspecting, and executing `AnalysisTools` against an input dataset 

28type. 

29 

30Subclasses of these tasks should specify specific datasets to consume in their 

31connection classes and should specify a unique name 

32""" 

33 

34from collections import abc 

35from typing import TYPE_CHECKING, Any, Iterable, Mapping, cast 

36 

37if TYPE_CHECKING:

38 from lsst.daf.butler import DeferredDatasetHandle 

39 

40from lsst.daf.butler import DataCoordinate 

41from lsst.pex.config import ListField 

42from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

43from lsst.pipe.base import connectionTypes as ct 

44from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext 

45from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection 

46from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField 

47 

48from ..analysisMetrics.metricMeasurementBundle import MetricMeasurementBundle 

49from ..interfaces import AnalysisMetric, AnalysisPlot, KeyedData 

50 

51 

class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """

    # Static output connection; its dimensions are replaced in __init__ with
    # the quantum dimensions declared by the concrete subclass.
    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        """Build the connections, dynamically adding one ``Plot`` output
        connection per output name advertised by each configured
        `AnalysisPlot` action.

        Parameters
        ----------
        config : `AnalysisBaseConfig`
            The config instance this connections class is paired with. The
            default of `None` exists only because keyword-only arguments
            require one; passing `None` is an error.

        Raises
        ------
        RuntimeError
            Raised if the ``outputName`` connection template was left at its
            ``Placeholder`` default.
        NameError
            Raised if a dynamically generated plot connection name collides
            with an existing connection or attribute.
        """
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the configs validate method, but
        # it is possible for someone to manually create everything in a script
        # without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by kw, but python has not method to do
        # that without specifying a default, so None is used. Validate that
        # it is not None. This is largely for typing reasons, as in the normal
        # course of operation code execution paths ensure this will not be None
        assert config is not None

        # Set the dimensions for the metric output by rebuilding the
        # connection with this instance's (subclass-declared) dimensions.
        self.metrics = ct.Output(
            name=self.metrics.name,
            doc=self.metrics.doc,
            storageClass=self.metrics.storageClass,
            dimensions=self.dimensions,
            multiple=False,
            isCalibration=False,
        )

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots. Band-parameterized plots contribute one name per
        # configured band.
        names: list[str] = []
        for plotAction in config.plots:
            if plotAction.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in plotAction.getOutputNames())
            else:
                names.extend(plotAction.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with existing connection"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            # NOTE(review): object.__setattr__ appears to bypass a restricted
            # __setattr__ on PipelineTaskConnections — confirm before changing.
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)

136 

137 

class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    Three fields are provided for use by every subclass:

    ``plots``
        Which `AnalysisPlots` the `PipelineTask` will execute.
    ``metrics``
        Which `AnalysisMetrics` the `PipelineTask` will execute.
    ``bands``
        The filter bands looped over for any `AnalysisTools` that support
        band parameterization; such tools run once per listed band.
    """

    plots = ConfigurableActionStructField[AnalysisPlot](doc="AnalysisPlots to run with this Task")
    metrics = ConfigurableActionStructField[AnalysisMetric](doc="AnalysisMetrics to run with this Task")
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )

    def validate(self):
        """Validate the config, additionally requiring that the connections
        ``outputName`` template has been explicitly configured.
        """
        super().validate()
        # Dynamic output connections cannot be named until a real outputName
        # is configured, so reject the placeholder default early.
        outputName = self.connections.outputName  # type: ignore
        if outputName == "Placeholder":
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")

165 

166 

167class _StandinPlotInfo(dict): 

168 """This class is an implementation detail to support plots in the instance 

169 no PlotInfo object is present in the call to run. 

170 """ 

171 

172 def __missing__(self, key): 

173 return "" 

174 

175 

176class AnalysisPipelineTask(PipelineTask): 

177 """Base class for `PipelineTasks` intended to run `AnalysisTools`. 

178 

179 The run method will run all of the `AnalysisMetrics` and `AnalysisPlots` 

180 defined in the config class. 

181 

182 To support interactive investigations, the actual work is done in 

183 ``runMetrics`` and ``runPlots`` methods. These can be called interactively 

184 with the same arguments as ``run`` but only the corresponding outputs will 

185 be produced. 

186 """ 

187 

188 # Typing config because type checkers dont know about our Task magic 

189 config: AnalysisBaseConfig 

190 ConfigClass = AnalysisBaseConfig 

191 

192 def runPlots(self, data: KeyedData, **kwargs) -> Struct: 

193 results = Struct() 

194 # allow not sending in plot info 

195 if "plotInfo" not in kwargs: 

196 kwargs["plotInfo"] = _StandinPlotInfo() 

197 for name, action in self.config.plots.items(): 

198 for selector in action.prep.selectors: 

199 if "threshold" in selector.keys(): 

200 kwargs["plotInfo"]["SN"] = selector.threshold 

201 kwargs["plotInfo"]["plotName"] = name 

202 match action(data, **kwargs): 

203 case abc.Mapping() as val: 

204 for n, v in val.items(): 

205 setattr(results, n, v) 

206 case value: 

207 setattr(results, name, value) 

208 if "SN" not in kwargs["plotInfo"].keys(): 

209 kwargs["plotInfo"]["SN"] = "-" 

210 return results 

211 

212 def runMetrics(self, data: KeyedData, **kwargs) -> Struct: 

213 metricsMapping = MetricMeasurementBundle() 

214 for name, action in self.config.metrics.items(): 

215 match action(data, **kwargs): 

216 case abc.Mapping() as val: 

217 results = list(val.values()) 

218 case val: 

219 results = [val] 

220 metricsMapping[name] = results # type: ignore 

221 return Struct(metrics=metricsMapping) 

222 

223 def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct: 

224 """Produce the outputs associated with this `PipelineTask` 

225 

226 Parameters 

227 ---------- 

228 data : `KeyedData` 

229 The input data from which all `AnalysisTools` will run and produce 

230 outputs. A side note, the python typing specifies that this can be 

231 None, but this is only due to a limitation in python where in order 

232 to specify that all arguments be passed only as keywords the 

233 argument must be given a default. This argument most not actually 

234 be None. 

235 **kwargs 

236 Additional arguments that are passed through to the `AnalysisTools` 

237 specified in the configuration. 

238 

239 Returns 

240 ------- 

241 results : `~lsst.pipe.base.Struct` 

242 The accumulated results of all the plots and metrics produced by 

243 this `PipelineTask`. 

244 

245 Raises 

246 ------ 

247 ValueError 

248 Raised if the supplied data argument is `None` 

249 """ 

250 if data is None: 

251 raise ValueError("data must not be none") 

252 results = Struct() 

253 plotKey = f"{self.config.connections.outputName}_{{name}}" # type: ignore 

254 if "bands" not in kwargs: 

255 kwargs["bands"] = list(self.config.bands) 

256 kwargs["plotInfo"]["bands"] = kwargs["bands"] 

257 for name, value in self.runPlots(data, **kwargs).getDict().items(): 

258 setattr(results, plotKey.format(name=name), value) 

259 for name, value in self.runMetrics(data, **kwargs).getDict().items(): 

260 setattr(results, name, value) 

261 

262 return results 

263 

264 def runQuantum( 

265 self, 

266 butlerQC: ButlerQuantumContext, 

267 inputRefs: InputQuantizedConnection, 

268 outputRefs: OutputQuantizedConnection, 

269 ) -> None: 

270 """Override default runQuantum to load the minimal columns necessary 

271 to complete the action. 

272 

273 Parameters 

274 ---------- 

275 butlerQC : `ButlerQuantumContext` 

276 A butler which is specialized to operate in the context of a 

277 `lsst.daf.butler.Quantum`. 

278 inputRefs : `InputQuantizedConnection` 

279 Datastructure whose attribute names are the names that identify 

280 connections defined in corresponding `PipelineTaskConnections` 

281 class. The values of these attributes are the 

282 `lsst.daf.butler.DatasetRef` objects associated with the defined 

283 input/prerequisite connections. 

284 outputRefs : `OutputQuantizedConnection` 

285 Datastructure whose attribute names are the names that identify 

286 connections defined in corresponding `PipelineTaskConnections` 

287 class. The values of these attributes are the 

288 `lsst.daf.butler.DatasetRef` objects associated with the defined 

289 output connections. 

290 """ 

291 inputs = butlerQC.get(inputRefs) 

292 dataId = butlerQC.quantum.dataId 

293 if dataId is not None: 

294 dataId = DataCoordinate.standardize(dataId, universe=butlerQC.registry.dimensions) 

295 plotInfo = self.parsePlotInfo(inputs, dataId) 

296 data = self.loadData(inputs["data"]) 

297 if "skymap" in inputs.keys(): 

298 skymap = inputs["skymap"] 

299 else: 

300 skymap = None 

301 outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap) 

302 butlerQC.put(outputs, outputRefs) 

303 

304 def parsePlotInfo( 

305 self, inputs: Mapping[str, Any], dataId: DataCoordinate | None, connectionName: str = "data" 

306 ) -> Mapping[str, str]: 

307 """Parse the inputs and dataId to get the information needed to 

308 to add to the figure. 

309 

310 Parameters 

311 ---------- 

312 inputs: `dict` 

313 The inputs to the task 

314 dataCoordinate: `lsst.daf.butler.DataCoordinate` 

315 The dataId that the task is being run on. 

316 connectionName: `str` 

317 Name of the input connection to use for determining table name. 

318 

319 Returns 

320 ------- 

321 plotInfo : `dict` 

322 """ 

323 

324 if inputs is None: 

325 tableName = "" 

326 run = "" 

327 else: 

328 tableName = inputs[connectionName].ref.datasetType.name 

329 run = inputs[connectionName].ref.run 

330 

331 plotInfo = {"tableName": tableName, "run": run} 

332 

333 if dataId is not None: 

334 for dataInfo in dataId: 

335 plotInfo[dataInfo.name] = dataId[dataInfo.name] 

336 

337 return plotInfo 

338 

339 def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData: 

340 """Load the minimal set of keyed data from the input dataset. 

341 

342 Parameters 

343 ---------- 

344 handle : `DeferredDatasetHandle` 

345 Handle to load the dataset with only the specified columns. 

346 names : `Iterable` of `str` 

347 The names of keys to extract from the dataset. 

348 If `names` is `None` then the `collectInputNames` method 

349 is called to generate the names. 

350 For most purposes these are the names of columns to load from 

351 a catalog or data frame. 

352 

353 Returns 

354 ------- 

355 result: `KeyedData` 

356 The dataset with only the specified keys loaded. 

357 """ 

358 if names is None: 

359 names = self.collectInputNames() 

360 return cast(KeyedData, handle.get(parameters={"columns": names})) 

361 

362 def collectInputNames(self) -> Iterable[str]: 

363 """Get the names of the inputs. 

364 

365 If using the default `loadData` method this will gather the names 

366 of the keys to be loaded from an input dataset. 

367 

368 Returns 

369 ------- 

370 inputs : `Iterable` of `str` 

371 The names of the keys in the `KeyedData` object to extract. 

372 

373 """ 

374 inputs = set() 

375 for band in self.config.bands: 

376 for name, action in self.config.plots.items(): 

377 for column, dataType in action.getFormattedInputSchema(band=band): 

378 inputs.add(column) 

379 for name, action in self.config.metrics.items(): 

380 for column, dataType in action.getFormattedInputSchema(band=band): 

381 inputs.add(column) 

382 return inputs