Coverage for python/lsst/analysis/tools/tasks/base.py: 19%
130 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-30 11:37 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-30 11:37 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23"""Base class implementation for the classes needed in creating `PipelineTasks`
24which execute `AnalysisTools`.
26The classes defined in this module have all the required behaviors for
27defining, introspecting, and executing `AnalysisTools` against an input dataset
28type.
30Subclasses of these tasks should specify specific datasets to consume in their
31 connection classes and should specify a unique name.
32"""
34__all__ = ("AnalysisBaseConfig", "AnalysisPipelineTask")
36from collections import abc
37from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, cast
39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true
40 from lsst.daf.butler import DeferredDatasetHandle
42from lsst.daf.butler import DataCoordinate
43from lsst.pex.config import ListField
44from lsst.pex.config.configurableActions import ConfigurableActionStructField
45from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
46from lsst.pipe.base import connectionTypes as ct
47from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext
48from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
50from ..analysisMetrics.metricMeasurementBundle import MetricMeasurementBundle
51from ..interfaces import AnalysisMetric, AnalysisPlot, KeyedData
class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """

    # Static output for the bundle of all metric measurements; its dimensions
    # are filled in during __init__ to match the subclass's quantum dimensions.
    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        """Build the connections, dynamically adding one ``Plot`` output
        connection per output name declared by each configured `AnalysisPlot`.

        Parameters
        ----------
        config : `AnalysisBaseConfig`
            The config this connections class is paired with. The `None`
            default exists only to satisfy keyword-only syntax; passing
            `None` is an error.

        Raises
        ------
        RuntimeError
            Raised if the ``outputName`` connections template was left at its
            "Placeholder" default.
        NameError
            Raised if a dynamically generated plot connection name collides
            with an existing connection or attribute.
        """
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the configs validate method, but
        # it is possible for someone to manually create everything in a script
        # without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by kw, but python has no method to do
        # that without specifying a default, so None is used. Validate that
        # it is not None. This is largely for typing reasons, as in the normal
        # course of operation code execution paths ensure this will not be None
        assert config is not None

        # Set the dimensions for the metric output, re-creating the connection
        # so it carries this instance's (i.e. the subclass's) dimensions.
        self.metrics = ct.Output(
            name=self.metrics.name,
            doc=self.metrics.doc,
            storageClass=self.metrics.storageClass,
            dimensions=self.dimensions,
            multiple=False,
            isCalibration=False,
        )

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots. Band-parameterized plots contribute one name per
        # configured band.
        names: list[str] = []
        for plotAction in config.plots:
            if plotAction.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in plotAction.getOutputNames())
            else:
                names.extend(plotAction.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with existing connection"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            # object.__setattr__ is used to bypass any specialized attribute
            # handling the connections class hierarchy may define.
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)
class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`

    This base class defines three fields that should be used in all subclasses,
    plots, metrics, and bands.

    The ``plots`` field is where a user configures which `AnalysisPlots` will
    be run in this `PipelineTask`.

    Likewise ``metrics`` defines which `AnalysisMetrics` will be run.

    The bands field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands. I.e. called once for
    each band in the list.
    """

    # AnalysisPlot actions to execute; each contributes dynamically created
    # plot output connections (see AnalysisBaseConnections.__init__).
    plots = ConfigurableActionStructField[AnalysisPlot](doc="AnalysisPlots to run with this Task")
    # AnalysisMetric actions to execute; results are bundled into the single
    # metrics output connection.
    metrics = ConfigurableActionStructField[AnalysisMetric](doc="AnalysisMetrics to run with this Task")
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )

    def validate(self):
        """Validate the configuration, additionally requiring that the
        ``outputName`` connections template has been explicitly overridden.

        Raises
        ------
        RuntimeError
            Raised if ``connections.outputName`` is still "Placeholder".
        """
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")
169class _StandinPlotInfo(dict):
170 """This class is an implementation detail to support plots in the instance
171 no PlotInfo object is present in the call to run.
172 """
174 def __missing__(self, key):
175 return ""
178class AnalysisPipelineTask(PipelineTask):
179 """Base class for `PipelineTasks` intended to run `AnalysisTools`.
181 The run method will run all of the `AnalysisMetrics` and `AnalysisPlots`
182 defined in the config class.
184 To support interactive investigations, the actual work is done in
185 ``runMetrics`` and ``runPlots`` methods. These can be called interactively
186 with the same arguments as ``run`` but only the corresponding outputs will
187 be produced.
188 """
190 # Typing config because type checkers dont know about our Task magic
191 config: AnalysisBaseConfig
192 ConfigClass = AnalysisBaseConfig
194 def runPlots(self, data: KeyedData, **kwargs) -> Struct:
195 results = Struct()
196 # allow not sending in plot info
197 if "plotInfo" not in kwargs:
198 kwargs["plotInfo"] = _StandinPlotInfo()
199 for name, action in self.config.plots.items():
200 for selector in action.prep.selectors:
201 if "threshold" in selector.keys():
202 kwargs["plotInfo"]["SN"] = selector.threshold
203 kwargs["plotInfo"]["plotName"] = name
204 match action(data, **kwargs):
205 case abc.Mapping() as val:
206 for n, v in val.items():
207 setattr(results, n, v)
208 case value:
209 setattr(results, name, value)
210 if "SN" not in kwargs["plotInfo"].keys():
211 kwargs["plotInfo"]["SN"] = "-"
212 return results
214 def runMetrics(self, data: KeyedData, **kwargs) -> Struct:
215 metricsMapping = MetricMeasurementBundle()
216 for name, action in self.config.metrics.items():
217 match action(data, **kwargs):
218 case abc.Mapping() as val:
219 results = list(val.values())
220 case val:
221 results = [val]
222 metricsMapping[name] = results # type: ignore
223 return Struct(metrics=metricsMapping)
225 def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
226 """Produce the outputs associated with this `PipelineTask`
228 Parameters
229 ----------
230 data : `KeyedData`
231 The input data from which all `AnalysisTools` will run and produce
232 outputs. A side note, the python typing specifies that this can be
233 None, but this is only due to a limitation in python where in order
234 to specify that all arguments be passed only as keywords the
235 argument must be given a default. This argument most not actually
236 be None.
237 **kwargs
238 Additional arguments that are passed through to the `AnalysisTools`
239 specified in the configuration.
241 Returns
242 -------
243 results : `~lsst.pipe.base.Struct`
244 The accumulated results of all the plots and metrics produced by
245 this `PipelineTask`.
247 Raises
248 ------
249 ValueError
250 Raised if the supplied data argument is `None`
251 """
252 if data is None:
253 raise ValueError("data must not be none")
254 results = Struct()
255 plotKey = f"{self.config.connections.outputName}_{{name}}" # type: ignore
256 if "bands" not in kwargs:
257 kwargs["bands"] = list(self.config.bands)
258 kwargs["plotInfo"]["bands"] = kwargs["bands"]
259 for name, value in self.runPlots(data, **kwargs).getDict().items():
260 setattr(results, plotKey.format(name=name), value)
261 for name, value in self.runMetrics(data, **kwargs).getDict().items():
262 setattr(results, name, value)
264 return results
266 def runQuantum(
267 self,
268 butlerQC: ButlerQuantumContext,
269 inputRefs: InputQuantizedConnection,
270 outputRefs: OutputQuantizedConnection,
271 ) -> None:
272 """Override default runQuantum to load the minimal columns necessary
273 to complete the action.
275 Parameters
276 ----------
277 butlerQC : `ButlerQuantumContext`
278 A butler which is specialized to operate in the context of a
279 `lsst.daf.butler.Quantum`.
280 inputRefs : `InputQuantizedConnection`
281 Datastructure whose attribute names are the names that identify
282 connections defined in corresponding `PipelineTaskConnections`
283 class. The values of these attributes are the
284 `lsst.daf.butler.DatasetRef` objects associated with the defined
285 input/prerequisite connections.
286 outputRefs : `OutputQuantizedConnection`
287 Datastructure whose attribute names are the names that identify
288 connections defined in corresponding `PipelineTaskConnections`
289 class. The values of these attributes are the
290 `lsst.daf.butler.DatasetRef` objects associated with the defined
291 output connections.
292 """
293 inputs = butlerQC.get(inputRefs)
294 dataId = butlerQC.quantum.dataId
295 plotInfo = self.parsePlotInfo(inputs, dataId)
296 data = self.loadData(inputs["data"])
297 if "skymap" in inputs.keys():
298 skymap = inputs["skymap"]
299 else:
300 skymap = None
301 outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
302 butlerQC.put(outputs, outputRefs)
304 def _populatePlotInfoWithDataId(
305 self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
306 ) -> None:
307 """Update the plotInfo with the dataId values.
309 Parameters
310 ----------
311 plotInfo : `dict`
312 The plotInfo dictionary to update.
313 dataId : `lsst.daf.butler.DataCoordinate`
314 The dataId to use to update the plotInfo.
315 """
316 if dataId is not None:
317 for dataInfo in dataId:
318 plotInfo[dataInfo.name] = dataId[dataInfo.name]
320 def parsePlotInfo(
321 self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
322 ) -> Mapping[str, str]:
323 """Parse the inputs and dataId to get the information needed to
324 to add to the figure.
326 Parameters
327 ----------
328 inputs: `dict`
329 The inputs to the task
330 dataCoordinate: `lsst.daf.butler.DataCoordinate`
331 The dataId that the task is being run on.
332 connectionName: `str`, optional
333 Name of the input connection to use for determining table name.
335 Returns
336 -------
337 plotInfo : `dict`
338 """
340 if inputs is None:
341 tableName = ""
342 run = ""
343 else:
344 tableName = inputs[connectionName].ref.datasetType.name
345 run = inputs[connectionName].ref.run
347 # Initialize the plot info dictionary
348 plotInfo = {"tableName": tableName, "run": run}
350 self._populatePlotInfoWithDataId(plotInfo, dataId)
351 return plotInfo
353 def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
354 """Load the minimal set of keyed data from the input dataset.
356 Parameters
357 ----------
358 handle : `DeferredDatasetHandle`
359 Handle to load the dataset with only the specified columns.
360 names : `Iterable` of `str`
361 The names of keys to extract from the dataset.
362 If `names` is `None` then the `collectInputNames` method
363 is called to generate the names.
364 For most purposes these are the names of columns to load from
365 a catalog or data frame.
367 Returns
368 -------
369 result: `KeyedData`
370 The dataset with only the specified keys loaded.
371 """
372 if names is None:
373 names = self.collectInputNames()
374 return cast(KeyedData, handle.get(parameters={"columns": names}))
376 def collectInputNames(self) -> Iterable[str]:
377 """Get the names of the inputs.
379 If using the default `loadData` method this will gather the names
380 of the keys to be loaded from an input dataset.
382 Returns
383 -------
384 inputs : `Iterable` of `str`
385 The names of the keys in the `KeyedData` object to extract.
387 """
388 inputs = set()
389 for band in self.config.bands:
390 for name, action in self.config.plots.items():
391 for column, dataType in action.getFormattedInputSchema(band=band):
392 inputs.add(column)
393 for name, action in self.config.metrics.items():
394 for column, dataType in action.getFormattedInputSchema(band=band):
395 inputs.add(column)
396 return inputs