Coverage for python/lsst/analysis/tools/interfaces/_task.py: 19% of 201 statements
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
22"""Base class implementation for the classes needed in creating `PipelineTasks`
23which execute `AnalysisTools`.
25The classes defined in this module have all the required behaviors for
26defining, introspecting, and executing `AnalysisTools` against an input dataset
27type.
29Subclasses of these tasks should specify specific datasets to consume in their
30connection classes and should specify a unique name
31"""

__all__ = ("AnalysisBaseConnections", "AnalysisBaseConfig", "AnalysisPipelineTask")

import datetime
import logging
import warnings
import weakref
from collections.abc import Iterable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, cast

import matplotlib.pyplot as plt
from lsst.verify import Measurement

if TYPE_CHECKING:
    from lsst.daf.butler import DeferredDatasetHandle
    from lsst.pipe.base import QuantumContext

from lsst.daf.butler import DataCoordinate
from lsst.pex.config import Field, ListField
from lsst.pex.config.configurableActions import ConfigurableActionStructField
from lsst.pipe.base import Instrument, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base import connectionTypes as ct
from lsst.pipe.base.connections import InputQuantizedConnection, OutputQuantizedConnection
from lsst.pipe.base.pipelineIR import ConfigIR, ParametersIR

from ._actions import JointAction, MetricAction, NoMetric
from ._analysisTools import AnalysisTool
from ._interfaces import KeyedData, PlotTypes
from ._metricMeasurementBundle import MetricMeasurementBundle

# TODO: This rcParams modification is a temporary solution, hiding
# a matplotlib warning indicating too many figures have been opened.
# When DM-39114 is implemented, this should be removed.
plt.rcParams.update({"figure.max_open_warning": 0})


# TODO: This _plotCloser function assists in closing all open plots at the
# conclusion of a PipelineTask. When DM-39114 is implemented, this function and
# all associated usage thereof should be removed.
def _plotCloser(*args):
    """Close all the plots in the given list."""
    for plot in args:
        plt.close(plot)


class AnalysisBaseConnections(
    PipelineTaskConnections, dimensions={}, defaultTemplates={"outputName": "Placeholder"}
):
    r"""Base class for Connections used for AnalysisTools PipelineTasks.

    This class has a pre-defined output connection for the
    MetricMeasurementMapping. The dataset type name for this connection is
    determined by the template ``outputName``.

    Output connections for plots created by `AnalysisPlot`\ s are created
    dynamically when an instance of the class is created. The init method
    examines all the `AnalysisPlot` actions specified in the associated
    `AnalysisBaseConfig` subclass, accumulating all the info needed to
    create the output connections.

    The dimensions for all of the output connections (metric and plot) will
    be the same as the dimensions specified for the AnalysisBaseConnections
    subclass (i.e. quantum dimensions).
    """

    metrics = ct.Output(
        doc="Metrics calculated on input dataset type",
        name="{outputName}_metrics",
        storageClass="MetricMeasurementBundle",
    )

    def __init__(self, *, config: AnalysisBaseConfig = None):  # type: ignore
        # Validate that the outputName template has been set in config. This
        # should have been checked early with the config's validate method,
        # but it is possible for someone to manually create everything in a
        # script without running validate, so also check it late here.
        if (outputName := config.connections.outputName) == "Placeholder":  # type: ignore
            raise RuntimeError(
                "Subclasses must specify an alternative value for the defaultTemplate `outputName`"
            )
        super().__init__(config=config)

        # All arguments must be passed by keyword, but python has no way to
        # require that without specifying a default, so None is used. Validate
        # that it is not None. This is largely for typing reasons, as in the
        # normal course of operation code execution paths ensure this will
        # not be None.
        assert config is not None

        for tool in config.atools:
            match tool.produce:
                case JointAction():
                    if isinstance(tool.produce.metric, NoMetric):
                        continue
                    if len(tool.produce.metric.units) != 0:
                        hasMetrics = True
                        break
                case MetricAction():
                    hasMetrics = True
                    break
        else:
            hasMetrics = False

        # Set the dimensions for the metric
        if hasMetrics:
            self.metrics = ct.Output(
                name=self.metrics.name,
                doc=self.metrics.doc,
                storageClass=self.metrics.storageClass,
                dimensions=self.dimensions,
                multiple=False,
                isCalibration=False,
            )
        else:
            # There are no metrics to produce, so remove the output
            # connection.
            self.outputs.remove("metrics")

        # Look for any conflicting names, creating a set of them, as these
        # will be added to the instance as well as recorded in the outputs
        # set.
        existingNames = set(dir(self))

        # Accumulate all the names to be used from all of the defined
        # AnalysisPlots.
        names: list[str] = []
        for action in config.atools:
            if action.dynamicOutputNames:
                outNames = action.getOutputNames(config=config)
            else:
                outNames = action.getOutputNames()
            if action.parameterizedBand:
                for band in config.bands:
                    names.extend(name.format(band=band) for name in outNames)
            else:
                names.extend(outNames)

        # For each of the names found, create output connections.
        for name in names:
            name = f"{outputName}_{name}"
            if name in self.outputs or name in existingNames:
                raise NameError(
                    f"Plot with name {name} conflicts with an existing connection;"
                    " are two plots named the same?"
                )
            outConnection = ct.Output(
                name=name,
                storageClass="Plot",
                doc="Dynamic connection for plotting",
                dimensions=self.dimensions,
            )
            object.__setattr__(self, name, outConnection)
            self.outputs.add(name)
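

# A minimal sketch of how a concrete subclass might declare its input; the
# class, connection, and dataset type names below are hypothetical and for
# illustration only:
#
#     class MyAnalysisConnections(
#         AnalysisBaseConnections,
#         dimensions=("skymap", "tract"),
#         defaultTemplates={"outputName": "myObjectTable"},
#     ):
#         data = ct.Input(
#             doc="Catalog to run the configured AnalysisTools against",
#             name="objectTable_tract",
#             storageClass="DataFrame",
#             deferLoad=True,
#             dimensions=("skymap", "tract"),
#         )
#
# ``deferLoad=True`` matters here because the default ``loadData`` method
# below expects a `DeferredDatasetHandle` so that only needed columns are
# read.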


def _timestampValidator(value: str) -> bool:
    """Check that a ``timestamp_version`` value is one of the supported
    forms, logging an error and returning `False` if it is not.
    """
    if value in ("reference_package_timestamp", "run_timestamp", "current_timestamp", "dataset_timestamp"):
        return True
    elif "explicit_timestamp" in value:
        try:
            _, splitTime = value.split(":")
        except ValueError:
            logging.error(
                "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                r"where datetime is given in the form '%Y%m%dT%H%M%S%z'"
            )
            return False
        try:
            datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
        except ValueError:
            # This is explicitly chosen to be an f-string as the string
            # contains control characters.
            logging.error(
                f"The supplied datetime {splitTime} could not be parsed correctly into "
                r"%Y%m%dT%H%M%S%z format"
            )
            return False
        return True
    else:
        return False
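

# Illustrative behavior of the validator above (the values are examples, not
# part of the module):
#
#     _timestampValidator("run_timestamp")                            # True
#     _timestampValidator("explicit_timestamp:20240101T120000+0000")  # True
#     _timestampValidator("explicit_timestamp:2024-01-01")            # False, logs an error
#     _timestampValidator("not_a_timestamp")                          # False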


class AnalysisBaseConfig(PipelineTaskConfig, pipelineConnections=AnalysisBaseConnections):
    """Base class for all configs used to define an `AnalysisPipelineTask`.

    This base class defines two fields that should be used in all subclasses:
    ``atools`` and ``bands``.

    The ``atools`` field is where the user configures which analysis tools
    will be run as part of this `PipelineTask`.

    The ``bands`` field specifies which bands will be looped over for
    `AnalysisTools` which support parameterized bands, i.e. each such tool is
    called once for each band in the list.
    """

    atools = ConfigurableActionStructField[AnalysisTool](
        doc="The analysis tools that are to be run by this task at execution"
    )
    # Temporarily alias these for backwards compatibility
    plots = atools
    metrics = atools
    bands = ListField[str](
        doc="Filter bands on which to run all of the actions", default=["u", "g", "r", "i", "z", "y"]
    )
    metric_tags = ListField[str](
        doc="List of tags which will be added to all configurable actions", default=[]
    )
    dataset_identifier = Field[str](doc="An identifier to be associated with output Metrics", optional=True)
    reference_package = Field[str](
        doc="A package whose version, at the time of metric upload to a "
        "time series database, will be converted to a timestamp of when "
        "that version was produced",
        default="lsst_distrib",
    )
    timestamp_version = Field[str](
        doc="Which timestamp should be used as the reference timestamp for a "
        "metric in a time series database. Valid values are: "
        "reference_package_timestamp, run_timestamp, current_timestamp, "
        "dataset_timestamp, and explicit_timestamp:datetime where datetime is "
        "given in the form %Y%m%dT%H%M%S%z",
        default="run_timestamp",
        check=_timestampValidator,
    )

    def applyConfigOverrides(
        self,
        instrument: Instrument | None,
        taskDefaultName: str,
        pipelineConfigs: Iterable[ConfigIR] | None,
        parameters: ParametersIR,
        label: str,
    ) -> None:
        extraConfig = {}
        if (value := parameters.mapping.get("sasquatch_dataset_identifier", None)) is not None:
            extraConfig["dataset_identifier"] = value
        if (value := parameters.mapping.get("sasquatch_reference_package", None)) is not None:
            extraConfig["reference_package"] = value
        if (value := parameters.mapping.get("sasquatch_timestamp_version", None)) is not None:
            if "explicit_timestamp" in value:
                try:
                    _, splitTime = value.split(":")
                except ValueError as excpt:
                    raise ValueError(
                        "Explicit timestamp must be given in the format 'explicit_timestamp:datetime', "
                        "where datetime is given in the form '%Y%m%dT%H%M%S%z'"
                    ) from excpt
                try:
                    datetime.datetime.strptime(splitTime, r"%Y%m%dT%H%M%S%z")
                except ValueError as excpt:
                    raise ValueError(
                        f"The supplied datetime {splitTime} could not be parsed correctly into "
                        "%Y%m%dT%H%M%S%z format"
                    ) from excpt
            extraConfig["timestamp_version"] = value
        if extraConfig:
            newPipelineConfigs = [ConfigIR(rest=extraConfig)]
            if pipelineConfigs is not None:
                newPipelineConfigs.extend(pipelineConfigs)
            pipelineConfigs = newPipelineConfigs
        return super().applyConfigOverrides(instrument, taskDefaultName, pipelineConfigs, parameters, label)

    def freeze(self):
        # Copy the meta configuration values to each of the configured tools,
        # but only if the tool has not been further specialized.
        if not self._frozen:
            for tool in self.atools:
                for tag in self.metric_tags:
                    tool.metric_tags.insert(-1, tag)
        super().freeze()

    def validate(self):
        super().validate()
        # Validate that the required connections template is set.
        if self.connections.outputName == "Placeholder":  # type: ignore
            raise RuntimeError("Connections class 'outputName' must have a config explicitly set")
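

# A sketch of how these fields are typically set from a pipeline definition;
# the task label, tool assignment, and output name below are hypothetical
# YAML, shown in a comment for illustration:
#
#     tasks:
#       analyzeMyTable:
#         class: <some AnalysisPipelineTask subclass>
#         config:
#           connections.outputName: myTable
#           atools.myTool: MyAnalysisTool
#           bands: ["g", "r", "i"]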


class _StandinPlotInfo(dict):
    """This class is an implementation detail to support plots in the case
    where no PlotInfo object is present in the call to run.
    """

    def __missing__(self, key):
        return ""


class AnalysisPipelineTask(PipelineTask):
    """Base class for `PipelineTasks` intended to run `AnalysisTools`.

    The run method will run all of the `AnalysisTools` defined in the config
    class.
    """

    # Typing config because type checkers don't know about our Task magic
    config: AnalysisBaseConfig
    ConfigClass = AnalysisBaseConfig

    warnings_all = (
        "divide by zero encountered in divide",
        "invalid value encountered in arcsin",
        "invalid value encountered in cos",
        "invalid value encountered in divide",
        "invalid value encountered in log10",
        "invalid value encountered in scalar divide",
        "invalid value encountered in sin",
        "invalid value encountered in sqrt",
        "invalid value encountered in true_divide",
        "Mean of empty slice",
    )

    def _runTools(self, data: KeyedData, **kwargs) -> Struct:
        with warnings.catch_warnings():
            # Change below to "in self.warnings_all" to find otherwise
            # unfiltered numpy warnings.
            for warning in ():
                warnings.filterwarnings("error", warning, RuntimeWarning)
            results = Struct()
            results.metrics = MetricMeasurementBundle(
                dataset_identifier=self.config.dataset_identifier,
                reference_package=self.config.reference_package,
                timestamp_version=self.config.timestamp_version,
            )
            # copy plot info to be sure each action sees its own copy
            plotInfo = kwargs.get("plotInfo")
            plotKey = f"{self.config.connections.outputName}_{{name}}"
            weakrefArgs = []
            for name, action in self.config.atools.items():
                kwargs["plotInfo"] = deepcopy(plotInfo)
                actionResult = action(data, **kwargs)
                metricAccumulate = []
                for resultName, value in actionResult.items():
                    match value:
                        case PlotTypes():
                            setattr(results, plotKey.format(name=resultName), value)
                            weakrefArgs.append(value)
                        case Measurement():
                            metricAccumulate.append(value)
                # only add the metrics if there are some
                if metricAccumulate:
                    results.metrics[name] = metricAccumulate
            # Wrap the return struct in a finalizer so that when results is
            # garbage collected the plots will be closed.
            # TODO: This finalize step closes all open plots at the conclusion
            # of a task. When DM-39114 is implemented, this step should not
            # be required and may be removed.
            weakref.finalize(results, _plotCloser, *weakrefArgs)
        return results

    def run(self, *, data: KeyedData | None = None, **kwargs) -> Struct:
        """Produce the outputs associated with this `PipelineTask`.

        Parameters
        ----------
        data : `KeyedData`
            The input data from which all `AnalysisTools` will run and produce
            outputs. As a side note, the python typing specifies that this can
            be None, but this is only due to a limitation in python where, in
            order to specify that all arguments be passed only as keywords,
            the argument must be given a default. This argument must not
            actually be None.
        **kwargs
            Additional arguments that are passed through to the
            `AnalysisTools` specified in the configuration.

        Returns
        -------
        results : `~lsst.pipe.base.Struct`
            The accumulated results of all the plots and metrics produced by
            this `PipelineTask`.

        Raises
        ------
        ValueError
            Raised if the supplied data argument is `None`.
        """
        if data is None:
            raise ValueError("data must not be None")
        if "bands" not in kwargs:
            kwargs["bands"] = list(self.config.bands)
        if "plotInfo" not in kwargs:
            kwargs["plotInfo"] = _StandinPlotInfo()
        kwargs["plotInfo"]["bands"] = kwargs["bands"]
        return self._runTools(data, **kwargs)
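
    # A minimal sketch of calling ``run`` directly, e.g. from a notebook; the
    # task class, config, and catalog below are hypothetical:
    #
    #     task = MyAnalysisTask(config=myConfig)
    #     results = task.run(data=catalog)
    #     bundle = results.metrics  # a MetricMeasurementBundle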

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        """Override default runQuantum to load the minimal columns necessary
        to complete the action.

        Parameters
        ----------
        butlerQC : `~lsst.pipe.base.QuantumContext`
            A butler which is specialized to operate in the context of a
            `lsst.daf.butler.Quantum`.
        inputRefs : `InputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            input/prerequisite connections.
        outputRefs : `OutputQuantizedConnection`
            Datastructure whose attribute names are the names that identify
            connections defined in the corresponding `PipelineTaskConnections`
            class. The values of these attributes are the
            `lsst.daf.butler.DatasetRef` objects associated with the defined
            output connections.
        """
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        plotInfo = self.parsePlotInfo(inputs, dataId)
        # We implicitly assume that 'data' has been defined, but do not have a
        # corresponding input connection in the base class. Thus, we capture
        # and re-raise the error with a more helpful message.
        try:
            # data has to be popped out to avoid duplication in the call to
            # the `run` method.
            inputData = inputs.pop("data")
        except KeyError:
            raise RuntimeError("'data' is a required input connection, but is not defined.")
        data = self.loadData(inputData)
        outputs = self.run(data=data, plotInfo=plotInfo, **inputs)
        butlerQC.put(outputs, outputRefs)

    def _populatePlotInfoWithDataId(
        self, plotInfo: MutableMapping[str, Any], dataId: DataCoordinate | None
    ) -> None:
        """Update the plotInfo with the dataId values.

        Parameters
        ----------
        plotInfo : `dict`
            The plotInfo dictionary to update.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId to use to update the plotInfo.
        """
        if dataId is not None:
            plotInfo.update(dataId.mapping)

    def parsePlotInfo(
        self, inputs: Mapping[str, Any] | None, dataId: DataCoordinate | None, connectionName: str = "data"
    ) -> Mapping[str, str]:
        """Parse the inputs and dataId to get the information needed to
        add to the figure.

        Parameters
        ----------
        inputs : `dict`
            The inputs to the task.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dataId that the task is being run on.
        connectionName : `str`, optional
            Name of the input connection to use for determining table name.

        Returns
        -------
        plotInfo : `dict`
            The table name and run collection of the input, plus the values
            of the dataId.
        """
        if inputs is None:
            tableName = ""
            run = ""
        else:
            tableName = inputs[connectionName].ref.datasetType.name
            run = inputs[connectionName].ref.run

        # Initialize the plot info dictionary
        plotInfo = {"tableName": tableName, "run": run}

        self._populatePlotInfoWithDataId(plotInfo, dataId)
        return plotInfo

    def loadData(self, handle: DeferredDatasetHandle, names: Iterable[str] | None = None) -> KeyedData:
        """Load the minimal set of keyed data from the input dataset.

        Parameters
        ----------
        handle : `DeferredDatasetHandle`
            Handle to load the dataset with only the specified columns.
        names : `Iterable` of `str`
            The names of keys to extract from the dataset.
            If `names` is `None` then the `collectInputNames` method
            is called to generate the names.
            For most purposes these are the names of columns to load from
            a catalog or data frame.

        Returns
        -------
        result : `KeyedData`
            The dataset with only the specified keys loaded.
        """
        if names is None:
            names = self.collectInputNames()
        return cast(KeyedData, handle.get(parameters={"columns": names}))
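
    # For instance, with a catalog-backed handle the call above restricts the
    # read to just the needed columns (column names here are illustrative):
    #
    #     handle.get(parameters={"columns": ["coord_ra", "coord_dec"]})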

    def collectInputNames(self) -> Iterable[str]:
        """Get the names of the inputs.

        If using the default `loadData` method this will gather the names
        of the keys to be loaded from an input dataset.

        Returns
        -------
        inputs : `Iterable` of `str`
            The names of the keys in the `KeyedData` object to extract.
        """
        inputs = set()
        for band in self.config.bands:
            for action in self.config.atools:
                for key, _ in action.getFormattedInputSchema(band=band):
                    inputs.add(key)
        return inputs
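
    # Illustrative result: with bands ["g", "r"] and a tool whose formatted
    # input schema yields "{band}_psfFlux" (a hypothetical column), this
    # would collect {"g_psfFlux", "r_psfFlux"}.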