# This file is part of analysis_tools.
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

"""Base class implementation for the classes needed in creating `PipelineTasks`
which execute `AnalysisTools`.

The classes defined in this module have all the required behaviors for
defining, introspecting, and executing `AnalysisTools` against an input dataset
type.

Subclasses of these tasks should specify specific datasets to consume in their
connection classes and should specify a unique name.
"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

from copy import deepcopy
from typing import TYPE_CHECKING, Any, Iterable, Mapping, MutableMapping, cast

from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.butlerQuantumContext import ButlerQuantumContext
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
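
    Examples
    --------
    A minimal sketch of a concrete subclass; the connection name, dataset
    type name, storage class, and dimensions below are illustrative only,
    not part of this module::

        class MyAnalysisConnections(
            AnalysisBaseConnections,
            dimensions=("visit", "band"),
            defaultTemplates={"outputName": "myAnalysis"},
        ):
            data = ct.Input(
                doc="Catalog to run analysis tools against",
                name="sourceTable_visit",
                storageClass="DataFrame",
                deferLoad=True,
                dimensions=("visit", "band"),
            )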
    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in action.getOutputNames())
            else:
                names.extend(action.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. each such tool is
    called once for each band in the list.
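
    Examples
    --------
    A sketch of wiring tools into a config subclass; ``MyAnalysisConnections``
    and ``ExampleTool`` stand in for real subclasses and are not defined in
    this module::

        class MyAnalysisConfig(
            AnalysisBaseConfig, pipelineConnections=MyAnalysisConnections
        ):
            def setDefaults(self):
                super().setDefaults()
                self.atools.exampleTool = ExampleTool
                self.bands = ["g", "r", "i"]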
    """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database, valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools,
        # but only if the tool has not been further specialized.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
                if tool.dataset_identifier is None:
                    tool.dataset_identifier = self.dataset_identifier
                if tool.reference_package == "lsst_distrib":
                    tool.reference_package = self.reference_package
                if tool.timestamp_version == "run_timestamp":
                    tool.timestamp_version = self.timestamp_version
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no PlotInfo object is present in the call to run.
    """

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
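
    Examples
    --------
    A sketch of a complete subclass, assuming connections and config classes
    like the illustrative ``MyAnalysisConfig`` above have been defined::

        class MyAnalysisTask(AnalysisPipelineTask):
            ConfigClass = MyAnalysisConfig
            _DefaultName = "myAnalysis"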
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        results = Struct()
        results.metrics = MetricMeasurementBundle()  # type: ignore
        # Copy plot info to be sure each action sees its own copy.
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some.
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be None, but this is only due to a limitation in python where, in
            order to specify that all arguments be passed only as keywords,
            the argument must be given a default. This argument must not
            actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
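
        Examples
        --------
        A sketch of calling ``run`` directly; the key names depend entirely
        on the configured tools and are illustrative here::

            results = task.run(data={"psfFlux": flux, "psfFluxErr": fluxErr})
            bundle = results.metrics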
        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: ButlerQuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load only the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `ButlerQuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
            The information to add to the figure, such as the table name,
            run, and dataId entries.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
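
        Examples
        --------
        A sketch of loading a subset of keys; the column names here are
        illustrative::

            data = task.loadData(handle, names=["psfFlux", "psfFluxErr"])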
        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs