Coverage for python/lsst/analysis/tools/interfaces/_task.py: 21%
143 statements
coverage.py v7.2.5, created at 2023-05-10 10:36 +0000
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations
22"""Base class implementation for the classes needed in creating `PipelineTasks`
23which execute `AnalysisTools`.
25The classes defined in this module have all the required behaviors for
26defining, introspecting, and executing `AnalysisTools` against an input dataset
27type.

Subclasses of these tasks should specify the concrete datasets to consume in
their connection classes, and should specify a unique name.
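
As a minimal sketch, a concrete task triple might look like the following
(the class names and dataset type names here are illustrative only, not
part of this module)::

    class MyAnalysisConnections(
        AnalysisBaseConnections,
        dimensions=("skymap", "tract"),
        defaultTemplates={"outputName": "myObjectTableAnalysis"},
    ):
        data = ct.Input(
            doc="Catalog on which to run the configured AnalysisTools.",
            name="objectTable_tract",
            storageClass="DataFrame",
            deferLoad=True,
            dimensions=("skymap", "tract"),
        )

    class MyAnalysisConfig(
        AnalysisBaseConfig, pipelineConnections=MyAnalysisConnections
    ):
        pass

    class MyAnalysisTask(AnalysisPipelineTask):
        ConfigClass = MyAnalysisConfig
        _DefaultName = "myObjectTableAnalysis"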
31"""
33__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")
35from copy import deepcopy
36from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, cast
38from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
60 r"""Base class for Connections used for AnalysisTools PipelineTasks.
62 This class has a pre-defined output connection for the
63 MetricMeasurementMapping. The dataset type name for this connection is
64 determined by the template ``outputName``.
66 Output connections for plots created by `AnalysisPlot`\ s are created
67 dynamically when an instance of the class is created. The init method
68 examines all the `AnalysisPlot` actions specified in the associated
69 `AnalysisBaseConfig` subclass accumulating all the info needed to
70 create the output connections.
72 The dimensions for all of the output connections (metric and plot) will
73 be the same as the dimensions specified for the AnalysisBaseConnections
74 subclass (i.e. quantum dimensions).
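
    As a sketch of the resulting behavior: if ``outputName`` is configured
    as ``"demo"`` and one of the configured tools produces a plot named
    ``skyPlot`` (both names illustrative), instantiating the connections
    adds a dynamic output connection roughly equivalent to::

        self.demo_skyPlot = ct.Output(
            name="demo_skyPlot",
            storageClass="Plot",
            doc="Dynamic connection for plotting",
            dimensions=self.dimensions,
        )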
75 """
77 metrics = ct.Output(
78 doc="Metrics calculated on input dataset type",
79 name="{outputName}_metrics",
80 storageClass="MetricMeasurementBundle",
81 )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in the config.
        # This should have been checked early with the config's validate
        # method, but it is possible for someone to manually create everything
        # in a script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but Python has no way to
        # require that without also specifying a default, so None is used.
        # Validate that it is not None. This is largely for typing reasons:
        # in the normal course of operation, code execution paths ensure this
        # will not be None.
        assert config is not None

        # Determine whether any of the configured tools will actually produce
        # metrics; the for/else sets hasMetrics to False only when no tool
        # triggers a break.
        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metrics output connection.
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce; remove the output connection.
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisTools.
        names: list[str] = []
        for action in config.atools:
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in action.getOutputNames())
            else:
                names.extend(action.getOutputNames())

        # For each of the names found, create an output connection.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
162 """Base class for all configs used to define an `AnalysisPipelineTask`.
164 This base class defines two fields that should be used in all subclasses,
165 atools, and bands.
167 The ``atools`` field is where the user configures which analysis tools will
168 be run as part of this `PipelineTask`.
170 The bands field specifies which bands will be looped over for
171 `AnalysisTools` which support parameterized bands. I.e. called once for
172 each band in the list.
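
    As an illustrative sketch, a subclass config might be populated in Python
    as follows (``ExamplePlotTool`` stands in for any concrete `AnalysisTool`
    subclass and is not provided by this module)::

        config = MyAnalysisConfig()
        config.atools.exampleTool = ExamplePlotTool
        config.bands = ["g", "r", "i"]
        config.connections.outputName = "myObjectTableAnalysis"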
173 """
175 atools = ConfigurableActionStructField[AnalysisTool](
176 doc="The analysis tools that are to be run by this task at execution"
177 )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which timestamp should be used as the reference timestamp for a "
        "metric in a time series database. Valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this the first time, before the config is frozen.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no ``plotInfo`` object is present in the call to ``run``.
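
    Any missing key simply resolves to an empty string, for example::

        info = _StandinPlotInfo()
        info["tableName"]  # returns "" instead of raising KeyError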
223 """
225 def __missing__(self, key):
226 return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        results = Struct()
        results.metrics = MetricMeasurementBundle(
            dataset_identifier=self.config.dataset_identifier,
            reference_package=self.config.reference_package,
            timestamp_version=self.config.timestamp_version,
        )
        # Copy the plot info to be sure each action sees its own copy.
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some.
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the Python typing specifies that this can
            be `None`, but that is only due to a limitation in Python where,
            in order to require that all arguments be passed as keywords, the
            argument must be given a default. This argument must not actually
            be `None`.
        **kwargs
            Additional arguments that are passed through to the `AnalysisTools`
            specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
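
        Examples
        --------
        A minimal sketch of calling ``run`` directly on in-memory data; the
        column names are illustrative and depend on the configured tools::

            task = MyAnalysisTask(config=config)
            results = task.run(data={"coord_ra": ra, "coord_dec": dec})
            # Plots are attached as attributes of the returned Struct;
            # metrics, if any were produced, are in results.metrics.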
291 """
292 if data is None:
293 raise ValueError("data must not be none")
294 if "bands" not in kwargs:
295 kwargs["bands"] = list(self.config.bands)
296 if "plotInfo" not in kwargs:
297 kwargs["plotInfo"] = _StandinPlotInfo()
298 kwargs["plotInfo"]["bands"] = kwargs["bands"]
299 if "SN" not in kwargs["plotInfo"].keys():
300 kwargs["plotInfo"]["SN"] = "-"
301 return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: ButlerQuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load only the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `ButlerQuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining the table
            name.

        Returns
        -------
        plotInfo : `dict`
            The table name and run collection of the input, plus the values
            of the dataId.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary.
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
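
        As a sketch, with a deferred-load handle to a DataFrame-backed
        dataset (the column names are illustrative)::

            data = task.loadData(handle, names=["coord_ra", "coord_dec"])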
408 """
409 if names is None:
410 names = self.collectInputNames()
411 return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
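
        As an illustrative sketch, a tool whose formatted input schema
        contains ``{band}_psfFlux``, run with ``bands=["g", "r"]``, would
        contribute ``g_psfFlux`` and ``r_psfFlux`` to the returned set.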
424 """
425 inputs = set()
426 for band in self.config.bands:
427 for action in self.config.atools:
428 for key, _ in action.getFormattedInputSchema(band=band):
429 inputs.add(key)
430 return inputs