Coverage for python/lsst/analysis/tools/tasks/base.py: 19%
129 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-20 09:05 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Base class implementation for the classes needed in creating `PipelineTasks`
24which execute `AnalysisTools`.
26The classes defined in this module have all the required behaviors for
27defining, introspecting, and executing `AnalysisTools` against an input dataset
28type.
30Subclasses of these tasks should specify specific datasets to consume in their
31connection classes and should specify a unique name
32"""
34from collections import abc
35from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, cast
37if TYPE_CHECKING: 37 ↛ 38line 37 didn't jump to line 38, because the condition on line 37 was never true
38 from lsst.daf.butler import DeferredDatasetHandle
40from lsst.daf.butler import DataCoordinate
41from lsst.pex.config import ListField
42from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
43from lsst.pipe.base import connectionTypes as ct
44from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext
45from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
46from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField
48from ..analysisMetrics.metricMeasurementBundle import MetricMeasurementBundle
49from ..interfaces import AnalysisMetric, AnalysisPlot, KeyedData
class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # All arguments must be passed by keyword, but python has no way to
        # require that without supplying a default, so None is used as the
        # default. Check it up front (before any attribute access) so a
        # missing config fails with a clear message rather than an opaque
        # AttributeError on ``config.connections`` below.
        if config is None:
            raise ValueError("A config object must be supplied and must not be None")

        # Validate that the outputName template has been set in config. This
        # should have been checked early with the configs validate method, but
        # it is possible for someone to manually create everything in a script
        # without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # Re-declare the metrics output so it picks up the quantum dimensions
        # declared on the subclass (they are not known at class-definition
        # time).
        self.metrics = ct.Output(
            name=self.metrics.name,
            doc=self.metrics.doc,
            storageClass=self.metrics.storageClass,
            dimensions=self.dimensions,
            multiple=False,
            isCalibration=False,
        )

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots. Band-parameterized plots produce one output per
        # configured band.
        names: list[str] = []
        for plotAction in config.plots:
            if plotAction.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in plotAction.getOutputNames())
            else:
                names.extend(plotAction.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with existing connection"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            # Connections classes normally freeze attribute assignment, so
            # bypass __setattr__ to attach the dynamic connection.
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)
class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`

    This base class defines three fields that should be used in all subclasses,
    plots, metrics, and bands.

    The ``plots`` field is where a user configures which `AnalysisPlots` will
    be run in this `PipelineTask`.

    Likewise ``metrics`` defines which `AnalysisMetrics` will be run.

    The bands field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands. I.e. called once for
    each band in the list.
    """

    plots = ConfigurableActionStructField[AnalysisPlot](doc="AnalysisPlots to run with this Task")
    metrics = ConfigurableActionStructField[AnalysisMetric](doc="AnalysisMetrics to run with this Task")
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )

    def validate(self):
        # Run the standard pex_config validation first.
        super().validate()
        # The ``outputName`` connections template must have been overridden
        # away from its placeholder default before the task can run.
        configuredName = self.connections.outputName  # type: ignore
        if configuredName == "Placeholder":
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")
167class _StandinPlotInfo(dict):
168 """This class is an implementation detail to support plots in the instance
169 no PlotInfo object is present in the call to run.
170 """
172 def __missing__(self, key):
173 return ""
176class AnalysisPipelineTask(PipelineTask):
177 """Base class for `PipelineTasks` intended to run `AnalysisTools`.
179 The run method will run all of the `AnalysisMetrics` and `AnalysisPlots`
180 defined in the config class.
182 To support interactive investigations, the actual work is done in
183 ``runMetrics`` and ``runPlots`` methods. These can be called interactively
184 with the same arguments as ``run`` but only the corresponding outputs will
185 be produced.
186 """
188 # Typing config because type checkers dont know about our Task magic
189 config: AnalysisBaseConfig
190 ConfigClass = AnalysisBaseConfig
192 def runPlots(self, data: KeyedData, **kwargs) -> Struct:
193 results = Struct()
194 # allow not sending in plot info
195 if "plotInfo" not in kwargs:
196 kwargs["plotInfo"] = _StandinPlotInfo()
197 for name, action in self.config.plots.items():
198 for selector in action.prep.selectors:
199 if "threshold" in selector.keys():
200 kwargs["plotInfo"]["SN"] = selector.threshold
201 kwargs["plotInfo"]["plotName"] = name
202 match action(data, **kwargs):
203 case abc.Mapping() as val:
204 for n, v in val.items():
205 setattr(results, n, v)
206 case value:
207 setattr(results, name, value)
208 if "SN" not in kwargs["plotInfo"].keys():
209 kwargs["plotInfo"]["SN"] = "-"
210 return results
212 def runMetrics(self, data: KeyedData, **kwargs) -> Struct:
213 metricsMapping = MetricMeasurementBundle()
214 for name, action in self.config.metrics.items():
215 match action(data, **kwargs):
216 case abc.Mapping() as val:
217 results = list(val.values())
218 case val:
219 results = [val]
220 metricsMapping[name] = results # type: ignore
221 return Struct(metrics=metricsMapping)
223 def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
224 """Produce the outputs associated with this `PipelineTask`
226 Parameters
227 ----------
228 data : `KeyedData`
229 The input data from which all `AnalysisTools` will run and produce
230 outputs. A side note, the python typing specifies that this can be
231 None, but this is only due to a limitation in python where in order
232 to specify that all arguments be passed only as keywords the
233 argument must be given a default. This argument most not actually
234 be None.
235 **kwargs
236 Additional arguments that are passed through to the `AnalysisTools`
237 specified in the configuration.
239 Returns
240 -------
241 results : `~lsst.pipe.base.Struct`
242 The accumulated results of all the plots and metrics produced by
243 this `PipelineTask`.
245 Raises
246 ------
247 ValueError
248 Raised if the supplied data argument is `None`
249 """
250 if data is None:
251 raise ValueError("data must not be none")
252 results = Struct()
253 plotKey = f"{self.config.connections.outputName}_{{name}}" # type: ignore
254 if "bands" not in kwargs:
255 kwargs["bands"] = list(self.config.bands)
256 kwargs["plotInfo"]["bands"] = kwargs["bands"]
257 for name, value in self.runPlots(data, **kwargs).getDict().items():
258 setattr(results, plotKey.format(name=name), value)
259 for name, value in self.runMetrics(data, **kwargs).getDict().items():
260 setattr(results, name, value)
262 return results
264 def runQuantum(
265 self,
266 butlerQC: ButlerQuantumContext,
267 inputRefs: InputQuantizedConnection,
268 outputRefs: OutputQuantizedConnection,
269 ) -> None:
270 """Override default runQuantum to load the minimal columns necessary
271 to complete the action.
273 Parameters
274 ----------
275 butlerQC : `ButlerQuantumContext`
276 A butler which is specialized to operate in the context of a
277 `lsst.daf.butler.Quantum`.
278 inputRefs : `InputQuantizedConnection`
279 Datastructure whose attribute names are the names that identify
280 connections defined in corresponding `PipelineTaskConnections`
281 class. The values of these attributes are the
282 `lsst.daf.butler.DatasetRef` objects associated with the defined
283 input/prerequisite connections.
284 outputRefs : `OutputQuantizedConnection`
285 Datastructure whose attribute names are the names that identify
286 connections defined in corresponding `PipelineTaskConnections`
287 class. The values of these attributes are the
288 `lsst.daf.butler.DatasetRef` objects associated with the defined
289 output connections.
290 """
291 inputs = butlerQC.get(inputRefs)
292 dataId = butlerQC.quantum.dataId
293 plotInfo = self.parsePlotInfo(inputs, dataId)
294 data = self.loadData(inputs["data"])
295 if "skymap" in inputs.keys():
296 skymap = inputs["skymap"]
297 else:
298 skymap = None
299 outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
300 butlerQC.put(outputs, outputRefs)
302 def _populatePlotInfoWithDataId(
303 self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
304 ) -> None:
305 """Update the plotInfo with the dataId values.
307 Parameters
308 ----------
309 plotInfo : `dict`
310 The plotInfo dictionary to update.
311 dataId : `lsst.daf.butler.DataCoordinate`
312 The dataId to use to update the plotInfo.
313 """
314 if dataId is not None:
315 for dataInfo in dataId:
316 plotInfo[dataInfo.name] = dataId[dataInfo.name]
318 def parsePlotInfo(
319 self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
320 ) -> Mapping[str, str]:
321 """Parse the inputs and dataId to get the information needed to
322 to add to the figure.
324 Parameters
325 ----------
326 inputs: `dict`
327 The inputs to the task
328 dataCoordinate: `lsst.daf.butler.DataCoordinate`
329 The dataId that the task is being run on.
330 connectionName: `str`, optional
331 Name of the input connection to use for determining table name.
333 Returns
334 -------
335 plotInfo : `dict`
336 """
338 if inputs is None:
339 tableName = ""
340 run = ""
341 else:
342 tableName = inputs[connectionName].ref.datasetType.name
343 run = inputs[connectionName].ref.run
345 # Initialize the plot info dictionary
346 plotInfo = {"tableName": tableName, "run": run}
348 self._populatePlotInfoWithDataId(plotInfo, dataId)
349 return plotInfo
351 def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
352 """Load the minimal set of keyed data from the input dataset.
354 Parameters
355 ----------
356 handle : `DeferredDatasetHandle`
357 Handle to load the dataset with only the specified columns.
358 names : `Iterable` of `str`
359 The names of keys to extract from the dataset.
360 If `names` is `None` then the `collectInputNames` method
361 is called to generate the names.
362 For most purposes these are the names of columns to load from
363 a catalog or data frame.
365 Returns
366 -------
367 result: `KeyedData`
368 The dataset with only the specified keys loaded.
369 """
370 if names is None:
371 names = self.collectInputNames()
372 return cast(KeyedData, handle.get(parameters={"columns": names}))
374 def collectInputNames(self) -> Iterable[str]:
375 """Get the names of the inputs.
377 If using the default `loadData` method this will gather the names
378 of the keys to be loaded from an input dataset.
380 Returns
381 -------
382 inputs : `Iterable` of `str`
383 The names of the keys in the `KeyedData` object to extract.
385 """
386 inputs = set()
387 for band in self.config.bands:
388 for name, action in self.config.plots.items():
389 for column, dataType in action.getFormattedInputSchema(band=band):
390 inputs.add(column)
391 for name, action in self.config.metrics.items():
392 for column, dataType in action.getFormattedInputSchema(band=band):
393 inputs.add(column)
394 return inputs