# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
22"""Base class implementation for the classes needed in creating `PipelineTasks`
23which execute `AnalysisTools`.
25The classes defined in this module have all the required behaviors for
26defining, introspecting, and executing `AnalysisTools` against an input dataset
27type.
29Subclasses of these tasks should specify specific datasets to consume in their
30connection classes and should specify a unique name
31"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import datetime
import logging
import warnings
import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function
# and all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementBundle. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the information needed
    to create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e., the quantum dimensions).
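
    Examples
    --------
    A minimal sketch of a subclass; the class name, dimensions, and dataset
    type names here are hypothetical::

        class MyCatalogAnalysisConnections(
            AnalysisBaseConnections,
            dimensions=("skymap", "tract"),
            defaultTemplates={"outputName": "myCatalogAnalysis"},
        ):
            data = ct.Input(
                doc="Catalog on which the configured AnalysisTools will run.",
                name="objectTable_tract",
                storageClass="DataFrame",
                dimensions=("skymap", "tract"),
                deferLoad=True,
            )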
91 """
93 metrics = ct.Output(
94 doc="Metrics calculated on input dataset type",
95 name="{outputName}_metrics",
96 storageClass="MetricMeasurementBundle",
97 )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, so remove the output connection
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.dynamicOutputNames:
                outNames = action.getOutputNames(config=config)
            else:
                outNames = action.getOutputNames()
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in outNames)
            else:
                names.extend(outNames)

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)


def _timestampValidator(value: str) -> bool:
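    """Check that ``value`` is a supported timestamp specifier.

    Examples
    --------
    A few illustrative calls; the explicit datetime value is arbitrary:

    >>> _timestampValidator("run_timestamp")
    True
    >>> _timestampValidator("explicit_timestamp:20240328T120000+0000")
    True
    >>> _timestampValidator("explicit_timestamp:not-a-date")
    False
    """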
    if value in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"):
        return True
    elif "explicit_timestamp" in value:
        try:
            _, splitTime = value.split(":")
        except ValueError:
            logging.error(
                "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                r"where datetime is given in the form '%Y%m%dT%H%M%S%z'"
            )
            return False
        try:
            datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
        except ValueError:
            # This is explicitly chosen to be an f-string even though the
            # string contains % control characters.
            logging.error(
                f"The supplied datetime {splitTime} could not be parsed correctly into "
                r"%Y%m%dT%H%M%S%z format"
            )
            return False
        return True
    else:
        return False


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e., each such tool
    is called once for each band in the list.
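
    Examples
    --------
    A minimal configuration sketch; ``MyAnalysisTask`` and
    ``SomeAnalysisTool`` are hypothetical stand-ins for a concrete task and
    a concrete `AnalysisTool` subclass::

        config = MyAnalysisTask.ConfigClass()
        config.atools.demo = SomeAnalysisTool()
        config.bands = ["g", "r", "i"]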
220 """
222 atools = ConfigurableActionStructField[AnalysisTool](
223 doc="The analysis tools that are to be run by this task at execution"
224 )
225 # Temporarally alias these for backwards compatibility
226 plots = atools
227 metrics = atools
228 bands = ListField[str](
229 doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
230 )
231 metric_tags = ListField[str](
232 doc="List of tags which will be added to all configurable actions", default=[]
233 )
234 dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
235 reference_package = Field[str](
236 doc="A package who's version, at the time of metric upload to a "
237 "time series database, will be converted to a timestamp of when "
238 "that version was produced",
239 default="lsst_distrib",
240 )
241 timestamp_version = Field[str](
242 doc="Which time stamp should be used as the reference timestamp for a "
243 "metric in a time series database, valid values are; "
244 "reference_package_timestamp, run_timestamp, current_timestamp, "
245 "dataset_timestamp and explicit_timestamp:datetime where datetime is "
246 "given in the form %Y%m%dT%H%M%S%z",
247 default="run_timestamp",
248 check=_timestampValidator,
249 )

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            if "explicit_timestamp" in value:
                try:
                    _, splitTime = value.split(":")
                except ValueError as excpt:
                    raise ValueError(
                        "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                        "where datetime is given in the form '%Y%m%dT%H%M%S%z'"
                    ) from excpt
                try:
                    datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
                except ValueError as excpt:
                    raise ValueError(
                        f"The supplied datetime {splitTime} could not be parsed correctly into "
                        "%Y%m%dT%H%M%S%z format"
                    ) from excpt
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)
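
    # The sasquatch_* parameters consumed by applyConfigOverrides above are
    # normally supplied in a pipeline definition's parameters section. As a
    # rough sketch in terms of this API (the values are hypothetical):
    #
    #     parameters = ParametersIR(mapping={
    #         "sasquatch_dataset_identifier": "my_dataset",
    #         "sasquatch_timestamp_version": "explicit_timestamp:20240328T120000+0000",
    #     })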

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools,
        # but only do this if the tool has not been further specialized.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    that no PlotInfo object is present in the call to run.
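
    Examples
    --------
    Missing keys resolve to an empty string:

    >>> info = _StandinPlotInfo()
    >>> info["anything"]
    ''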
307 """
309 def __missing__(self, key):
310 return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    warnings_all = (
        "divide by zero encountered in divide",
        "invalid value encountered in arcsin",
        "invalid value encountered in cos",
        "invalid value encountered in divide",
        "invalid value encountered in log10",
        "invalid value encountered in scalar divide",
        "invalid value encountered in sin",
        "invalid value encountered in sqrt",
        "invalid value encountered in true_divide",
        "Mean of empty slice",
    )

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        with warnings.catch_warnings():
            # Change below to "in self.warnings_all" to find otherwise
            # unfiltered numpy warnings.
            for warning in ():
                warnings.filterwarnings("error", warning, RuntimeWarning)
            results = Struct()
            results.metrics = MetricMeasurementBundle(
                dataset_identifier=self.config.dataset_identifier,
                reference_package=self.config.reference_package,
                timestamp_version=self.config.timestamp_version,
            )
            # Copy plot info to be sure each action sees its own copy.
            plotInfo = kwargs.get("plotInfo")
            plotKey = f"{self.config.connections.outputName}_{{name}}"
            weakrefArgs = []
            for name, action in self.config.atools.items():
                kwargs["plotInfo"] = deepcopy(plotInfo)
                actionResult = action(data, **kwargs)
                metricAccumulate = []
                for resultName, value in actionResult.items():
                    match value:
                        case PlotTypes():
                            setattr(results, plotKey.format(name=resultName), value)
                            weakrefArgs.append(value)
                        case Measurement():
                            metricAccumulate.append(value)
                # Only add the metrics if there are some.
                if metricAccumulate:
                    results.metrics[name] = metricAccumulate
            # Wrap the return struct in a finalizer so that when results is
            # garbage collected the plots will be closed.
            # TODO: This finalize step closes all open plots at the conclusion
            # of a task. When DM-39114 is implemented, this step should not
            # be required and may be removed.
            weakref.finalize(results, _plotCloser, *weakrefArgs)
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and
            produce outputs. As a side note, the python typing specifies that
            this can be None, but this is only due to a limitation in python
            where, in order to specify that all arguments be passed only as
            keywords, the argument must be given a default. This argument
            must not actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
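
        Examples
        --------
        A minimal sketch; ``MyAnalysisTask`` and the column names in ``data``
        are hypothetical::

            task = MyAnalysisTask()
            results = task.run(data={"objectId": ids, "g_psfFlux": fluxes})
            bundle = results.metrics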
401 """
402 if data is None:
403 raise ValueError("data must not be none")
404 if "bands" not in kwargs:
405 kwargs["bands"] = list(self.config.bands)
406 if "plotInfo" not in kwargs:
407 kwargs["plotInfo"] = _StandinPlotInfo()
408 kwargs["plotInfo"]["bands"] = kwargs["bands"]
409 return self._runTools(data, **kwargs)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override the default runQuantum to load the minimal columns
        necessary to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding
            `PipelineTaskConnections` class. The values of these attributes
            are the `lsst.daf.butler.DatasetRef` objects associated with the
            defined input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Data structure whose attribute names are the names that identify
            connections defined in the corresponding
            `PipelineTaskConnections` class. The values of these attributes
            are the `lsst.daf.butler.DatasetRef` objects associated with the
            defined output connections.
        """
        # TODO: This rcParams modification is a temporary solution, hiding
        # a matplotlib warning indicating too many figures have been opened.
        # When DM-39114 is implemented, this should be removed.
        plt.rcParams.update({"figure.max_open_warning": 0})
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        data = self.loadData(inputs["data"])
        if "skymap" in inputs:
            skymap = inputs["skymap"]
        else:
            skymap = None
        outputs = self.run(data=data, plotInfo=plotInfo, skymap=skymap)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            plotInfo.update(dataId.mapping)

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
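
        Examples
        --------
        Given an `AnalysisPipelineTask` instance ``task``, calling with no
        inputs or dataId yields empty-valued entries::

            >>> task.parsePlotInfo(None, None)  # doctest: +SKIP
            {'tableName': '', 'run': ''}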
486 """
488 if inputs is None:
489 tableName = ""
490 run = ""
491 else:
492 tableName = inputs[connectionName].ref.datasetType.name
493 run = inputs[connectionName].ref.run
495 # Initialize the plot info dictionary
496 plotInfo = {"tableName": tableName, "run": run}
498 self._populatePlotInfoWithDataId(plotInfo, dataId)
499 return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
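
        Examples
        --------
        Given a `DeferredDatasetHandle` ``handle`` for a catalog dataset (the
        column names here are hypothetical)::

            >>> data = task.loadData(handle, names=["coord_ra", "coord_dec"])  # doctest: +SKIP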
519 """
520 if names is None:
521 names = self.collectInputNames()
522 return cast(KeyedData, handle.get(parameters={"columns": names}))

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs