Coverage for python/lsst/analysis/tools/interfaces/_task.py: 20% of 168 statements
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
22"""Base class implementation for the classes needed in creating `PipelineTasks`
23which execute `AnalysisTools`.
25The classes defined in this module have all the required behaviors for
26defining, introspecting, and executing `AnalysisTools` against an input dataset
27type.
29Subclasses of these tasks should specify specific datasets to consume in their
30connection classes and should specify a unique name
31"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function and
# all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
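
    Examples
    --------
    A minimal sketch of a subclass; the dataset type name, storage class,
    and dimensions here are illustrative assumptions, not values prescribed
    by this module::

        class MyToolConnections(
            AnalysisBaseConnections,
            dimensions=("skymap", "tract"),
            defaultTemplates={"outputName": "myToolOutput"},
        ):
            # Deferred loading lets loadData fetch only the needed columns.
            data = ct.Input(
                doc="Catalog from which metrics and plots will be computed.",
                name="objectTable_tract",
                storageClass="ArrowAstropy",
                deferLoad=True,
                dimensions=("skymap", "tract"),
            )

    Naming the input connection ``data`` matches what the default
    `AnalysisPipelineTask.runQuantum` implementation expects.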
88 """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in action.getOutputNames())
            else:
                names.extend(action.getOutputNames())

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. such tools are
    called once for each band in the list.
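
    Examples
    --------
    A sketch of a subclass that pre-loads a tool; ``MyToolConnections`` and
    ``MyScatterTool`` are hypothetical names standing in for a real
    connections class and `AnalysisTool` subclass::

        class MyToolConfig(AnalysisBaseConfig, pipelineConnections=MyToolConnections):
            def setDefaults(self):
                super().setDefaults()
                # Register a tool to run by default; users may add more
                # through pipeline or command-line config overrides.
                self.atools.myScatter = MyScatterTool
                self.bands = ["g", "r", "i"]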
186 """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database; valid values are "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )
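
    # Note: the sasquatch_* pipeline parameters consumed in
    # applyConfigOverrides below feed the three fields above. A pipeline YAML
    # sketch (the parameter values shown are illustrative only):
    #
    #     parameters:
    #       sasquatch_dataset_identifier: my_dataset
    #       sasquatch_reference_package: lsst_distrib
    #       sasquatch_timestamp_version: run_timestamp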

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this if the tool has not been further specialized
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    where no PlotInfo object is present in the call to run.
    """

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
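
    Examples
    --------
    A sketch of a concrete subclass; the ``My*`` names are hypothetical
    placeholders for a real connections/config pair like those described
    above::

        class MyAnalysisTask(AnalysisPipelineTask):
            ConfigClass = MyToolConfig
            _DefaultName = "myAnalysisTask"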
269 """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        results = Struct()
        results.metrics = MetricMeasurementBundle(
            dataset_identifier=self.config.dataset_identifier,
            reference_package=self.config.reference_package,
            timestamp_version=self.config.timestamp_version,
        )
        # Copy plot info to be sure each action sees its own copy
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        weakrefArgs = []
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                        weakrefArgs.append(value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        # Wrap the return struct in a finalizer so that when results is
        # garbage collected the plots will be closed.
        # TODO: This finalize step closes all open plots at the conclusion of
        # a task. When DM-39114 is implemented, this step should no longer be
        # required and may be removed.
        weakref.finalize(results, _plotCloser, *weakrefArgs)
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be None, but that is only due to a limitation in python where, in
            order to specify that all arguments be passed only as keywords,
            the argument must be given a default. This argument must not
            actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
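
        Examples
        --------
        A sketch of calling ``run`` interactively, where ``task`` is an
        instance of a concrete subclass and ``table`` is any `KeyedData`
        mapping accepted by the configured tools::

            results = task.run(data=table, bands=["r"])
            bundle = results.metrics  # MetricMeasurementBundle of any metrics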
334 """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        # TODO: This rcParams modification is a temporary solution, hiding
        # a matplotlib warning indicating too many figures have been opened.
        # When DM-39114 is implemented, this should be removed.
        plt.rcParams.update({"figure.max_open_warning": 0})
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
            The table name and run collection, plus one entry per dimension
            of the dataId.
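
        Examples
        --------
        A sketch of the result for a tract-level dataId; all values shown
        are illustrative::

            {"tableName": "objectTable_tract", "run": "u/someone/run",
             "skymap": "hsc_rings_v1", "tract": 9813}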
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
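
        Examples
        --------
        A sketch of loading a column subset through a deferred handle; the
        repository path, collection, dataset type, dataId, and column names
        are illustrative::

            from lsst.daf.butler import Butler

            butler = Butler("/path/to/repo", collections=["u/someone/run"])
            handle = butler.getDeferred("objectTable_tract", skymap="hsc_rings_v1", tract=9813)
            data = task.loadData(handle, names=["coord_ra", "coord_dec"])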
455 """
456 if names is None:
457 names = self.collectInputNames()
458 return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs