Coverage for python/lsst/analysis/tools/interfaces/_task.py: 20% (171 statements)

# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

22"""Base class implementation for the classes needed in creating `PipelineTasks`
23which execute `AnalysisTools`.
25The classes defined in this module have all the required behaviors for
26defining, introspecting, and executing `AnalysisTools` against an input dataset
27type.
29Subclasses of these tasks should specify specific datasets to consume in their
30connection classes and should specify a unique name
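
For example, a minimal set of concrete subclasses might look like the
following sketch (the connection class, the ``sourceTable_visit`` dataset
type, and the dimensions shown are illustrative assumptions, not part of this
module)::

    from lsst.pipe.base import connectionTypes as ct

    class ExampleAnalysisConnections(
        AnalysisBaseConnections,
        dimensions=("visit", "band"),
        defaultTemplates={"outputName": "exampleAnalysis"},
    ):
        data = ct.Input(
            doc="Catalog to be analyzed.",
            name="sourceTable_visit",
            storageClass="DataFrame",
            dimensions=("visit", "band"),
            deferLoad=True,
        )

    class ExampleAnalysisConfig(
        AnalysisBaseConfig, pipelineConnections=ExampleAnalysisConnections
    ):
        pass

    class ExampleAnalysisTask(AnalysisPipelineTask):
        ConfigClass = ExampleAnalysisConfig
        _DefaultName = "exampleAnalysis"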
"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function and
# all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementBundle. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. the quantum dimensions).
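
    For example (an illustrative sketch, not a real pipeline), with
    ``outputName`` set to ``exampleAnalysis`` and a single configured tool
    producing a plot named ``examplePlot``, an instance of this class would
    carry output connections for these dataset types::

        exampleAnalysis_metrics      # storageClass MetricMeasurementBundle
        exampleAnalysis_examplePlot  # storageClass Plot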
    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in the config.
        # This should have been checked early with the config's validate
        # method, but it is possible for someone to manually create everything
        # in a script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # enforce that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will not
        # be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metrics connection.
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce; remove the output connection.
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.dynamicOutputNames:
                outNames = action.getOutputNames(config=config)
            else:
                outNames = action.getOutputNames()
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in outNames)
            else:
                names.extend(outNames)

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. the tool is called
    once for each band in the list.
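
    For example, a pipeline configuration file might contain the following
    (the tool name and its import are hypothetical, shown for illustration
    only)::

        from lsst.analysis.tools.atools import ExampleTool

        config.atools.exampleTool = ExampleTool
        config.bands = ["g", "r", "i"]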
    """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility.
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which time stamp should be used as the reference timestamp for a "
        "metric in a time series database. Valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "and dataset_timestamp",
        default="run_timestamp",
        check=lambda x: x
        in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"),
    )

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
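        # The sasquatch_* values read below come from the ``parameters``
        # section of a pipeline definition. An illustrative (hypothetical)
        # snippet of such a pipeline YAML:
        #
        #     parameters:
        #       sasquatch_dataset_identifier: example_dataset
        #       sasquatch_reference_package: lsst_distrib
        #       sasquatch_timestamp_version: run_timestamp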
        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools;
        # only do this if the tool has not been further specialized.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    where no PlotInfo object is present in the call to run.
    """

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic.
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        results = Struct()
        results.metrics = MetricMeasurementBundle(
            dataset_identifier=self.config.dataset_identifier,
            reference_package=self.config.reference_package,
            timestamp_version=self.config.timestamp_version,
        )
        # Copy the plot info to be sure each action sees its own copy.
        plotInfo = kwargs.get("plotInfo")
        plotKey = f"{self.config.connections.outputName}_{{name}}"
        weakrefArgs = []
        for name, action in self.config.atools.items():
            kwargs["plotInfo"] = deepcopy(plotInfo)
            actionResult = action(data, **kwargs)
            metricAccumulate = []
            for resultName, value in actionResult.items():
                match value:
                    case PlotTypes():
                        setattr(results, plotKey.format(name=resultName), value)
                        weakrefArgs.append(value)
                    case Measurement():
                        metricAccumulate.append(value)
            # Only add the metrics if there are some.
            if metricAccumulate:
                results.metrics[name] = metricAccumulate
        # Wrap the return struct in a finalizer so that when results is
        # garbage collected the plots will be closed.
        # TODO: This finalize step closes all open plots at the conclusion of
        # a task. When DM-39114 is implemented, this step should no longer be
        # required and may be removed.
        weakref.finalize(results, _plotCloser, *weakrefArgs)
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. Note that the python typing allows this to be `None`, but
            this is only due to a limitation in python where, in order to
            specify that all arguments be passed only as keywords, the
            argument must be given a default. This argument must not actually
            be `None`.
        **kwargs
            Additional arguments that are passed through to the `AnalysisTools`
            specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
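
        Examples
        --------
        An illustrative call outside of a quantum context, where ``catalog``
        is a `KeyedData` mapping of column name to array and ``ExampleTask``
        is a hypothetical subclass of this task::

            task = ExampleTask()
            results = task.run(data=catalog, bands=["r"])
            metrics = results.metrics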
        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        if "SN" not in kwargs["plotInfo"].keys():
            kwargs["plotInfo"]["SN"] = "-"
        return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load only the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        # TODO: This rcParams modification is a temporary solution, hiding
        # a matplotlib warning indicating too many figures have been opened.
        # When DM-39114 is implemented, this should be removed.
        plt.rcParams.update({"figure.max_open_warning": 0})
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs.keys():
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            for dataInfo in dataId:
                plotInfo[dataInfo.name] = dataId[dataInfo.name]

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining the table
            name.

        Returns
        -------
        plotInfo : `dict`
            A dictionary containing the table name, run, and any dataId
            entries.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary.
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
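
        Examples
        --------
        An illustrative call, assuming ``handle`` is a
        `DeferredDatasetHandle` for a DataFrame-backed dataset and that the
        named columns exist in that dataset::

            data = task.loadData(handle, names=["coord_ra", "coord_dec"])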
        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs