Coverage for python/lsst/analysis/tools/interfaces/_analysisTools.py: 22%
173 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-05 04:13 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-05 04:13 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("AnalysisTool",)
26from collections import ChainMap
27from collections.abc import Mapping
28from functools import wraps
29from operator import attrgetter
30from typing import Callable, Iterable, Protocol, runtime_checkable
32import lsst.pex.config as pexConfig
33from lsst.obs.base import Instrument
34from lsst.pex.config import Field, FieldValidationError, ListField
35from lsst.pex.config.configurableActions import ConfigurableActionField
36from lsst.pipe.base import Pipeline
37from lsst.verify import Measurement
39from ._actions import AnalysisAction, JointAction, JointResults, NoPlot, PlotAction
40from ._interfaces import KeyedData, KeyedDataSchema, KeyedResults, PlotTypes
41from ._stages import BasePrep, BaseProcess, BaseProduce
@runtime_checkable
class _HasOutputNames(Protocol):
    """Structural type for objects that expose a ``getOutputNames`` method.

    The protocol is runtime checkable, so it may be used with `isinstance`
    and in ``match`` class patterns. Such checks only verify that the method
    exists, not its signature.
    """

    def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]:
        """Return the names of the outputs this object produces.

        Parameters
        ----------
        config : `lsst.pex.config.Config`, optional
            Configuration that implementations may consult when naming
            outputs.

        Returns
        -------
        names : `~collections.abc.Iterable` of `str`
            The output names.
        """
        ...
def _finalizeWrapper(
    f: Callable[[AnalysisTool], None], cls: type[AnalysisTool]
) -> Callable[[AnalysisTool], None]:
    """Wrap a class's ``finalize`` so the private ``_baseFinalize`` hook runs
    exactly once, after the most derived ``finalize`` has returned.

    Parameters
    ----------
    f : `Callable`
        The ``finalize`` implementation being wrapped.
    cls : `type` of `AnalysisTool`
        The class whose ``finalize`` is being wrapped.

    Returns
    -------
    function : `Callable`
        A replacement ``finalize`` that calls ``f`` and then, when
        appropriate, ``self._baseFinalize()``.
    """

    @wraps(f)
    def finalizeThenBase(self: AnalysisTool) -> None:
        # Run the wrapped finalize first. If it chains to a parent finalize
        # through super(), the parent's wrapper runs (and returns) inside
        # this call.
        f(self)

        # The MRO lists the runtime type first and the root base last, so the
        # first class whose own namespace holds ``finalize`` is the most
        # derived definition. ``vars`` is used rather than attribute lookup
        # because normal lookup would itself walk the MRO. The runtime type
        # need not be that class; it may simply inherit the method.
        definingClasses = (klass for klass in self.__class__.mro() if "finalize" in vars(klass))
        mostDerived: type | None = next(definingClasses, None)

        # Identify which class's wrapper is executing in the current step of
        # the (possibly chained) finalize call.
        this = super(cls, self).__thisclass__

        # Only the wrapper around the most derived finalize triggers the base
        # class hook, so that it fires after every user-defined finalize.
        if mostDerived is not None and this == mostDerived:
            self._baseFinalize()

    return finalizeThenBase
105class AnalysisTool(AnalysisAction):
106 r"""A tool which which calculates a single type of analysis on input data,
107 though it may return more than one result.
109 Although `AnalysisTool`\ s are considered a single type of analysis, the
110 classes themselves can be thought of as a container. `AnalysisTool`\ s
111 are aggregations of `AnalysisAction`\ s to form prep, process, and
112 produce stages. These stages allow better reuse of individual
113 `AnalysisActions` and easier introspection in contexts such as a notebook
114 or interpreter.
116 An `AnalysisTool` can be thought of an an individual configuration that
117 specifies which `AnalysisAction` should run for each stage.
119 The stages themselves are also configurable, allowing control over various
120 aspects of the individual `AnalysisAction`\ s.
121 """
123 prep = ConfigurableActionField[AnalysisAction](doc="Action to run to prepare inputs", default=BasePrep)
124 process = ConfigurableActionField[AnalysisAction](
125 doc="Action to process data into intended form", default=BaseProcess
126 )
127 produce = ConfigurableActionField[AnalysisAction](
128 doc="Action to perform any finalization steps", default=BaseProduce
129 )
130 metric_tags = ListField[str](
131 doc="List of tags which will be associated with metric measurement(s)", default=[]
132 )
134 def __init_subclass__(cls: type[AnalysisTool], **kwargs):
135 super().__init_subclass__(**kwargs)
136 # Wrap all definitions of the finalize method in a special wrapper that
137 # ensures that the bases classes private finalize is called last.
138 if "finalize" in vars(cls):
139 cls.finalize = _finalizeWrapper(cls.finalize, cls)
141 dynamicOutputNames: bool | Field[bool] = False
142 """Determines whether to grant the ``getOutputNames`` method access to
143 config parameters.
144 """
146 parameterizedBand: bool | Field[bool] = True
147 """Specifies if an `AnalysisTool` may parameterize a band within any field
148 in any stage, or if the set of bands is already uniquely determined though
149 configuration. I.e. can this `AnalysisTool` be automatically looped over to
150 produce a result for multiple bands.
151 """
153 propagateData: bool | Field[bool] = False
154 """If this value is set to True, the input data `KeyedData` will be passed
155 to each stage in addition to the ``prep`` stage. Any keys created in a
156 stage with the same key that exists in the input ``data`` will shadow that
157 key/value.
158 """
160 def __call__(self, data: KeyedData, **kwargs) -> KeyedResults:
161 bands = kwargs.pop("bands", None)
162 if "plotInfo" in kwargs and kwargs.get("plotInfo") is not None:
163 if "plotName" not in kwargs["plotInfo"] or kwargs["plotInfo"]["plotName"] is None:
164 kwargs["plotInfo"]["plotName"] = self.identity
165 if not self.parameterizedBand or bands is None:
166 if "band" not in kwargs:
167 # Some tasks require a "band" key for naming. This shouldn't
168 # affect the results. DM-35813 should make this unnecessary.
169 kwargs["band"] = "analysisTools"
170 return self._call_single(data, **kwargs)
171 results: KeyedResults = {}
172 for band in bands:
173 kwargs["band"] = band
174 if "plotInfo" in kwargs:
175 kwargs["plotInfo"]["bands"] = band
176 subResult = self._call_single(data, **kwargs)
177 for key, value in subResult.items():
178 match value:
179 case PlotTypes():
180 results[f"{band}_{key}"] = value
181 case Measurement():
182 results[key] = value
183 return results
185 def _call_single(self, data: KeyedData, **kwargs) -> KeyedResults:
186 # create a shallow copy of kwargs
187 kwargs = dict(**kwargs)
188 kwargs["metric_tags"] = list(self.metric_tags or ())
189 prepped: KeyedData = self.prep(data, **kwargs) # type: ignore
190 if self.propagateData:
191 prepped = ChainMap(data, prepped)
192 processed: KeyedData = self.process(prepped, **kwargs) # type: ignore
193 if self.propagateData:
194 processed = ChainMap(data, processed)
195 finalized: (
196 Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults
197 ) = self.produce(
198 processed, **kwargs
199 ) # type: ignore
200 return self._process_single_results(finalized)
202 def _getPlotType(self) -> str:
203 match self.produce:
204 case PlotAction():
205 return type(self.produce).__name__
206 case JointAction(plot=NoPlot()):
207 pass
208 case JointAction(plot=plotter):
209 return type(plotter).__name__
211 return ""
213 def _process_single_results(
214 self,
215 results: Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults,
216 ) -> KeyedResults:
217 accumulation = {}
218 suffix = self._getPlotType()
219 predicate = f"{self.identity}" if self.identity else ""
220 match results:
221 case Mapping():
222 for key, value in results.items():
223 match value:
224 case PlotTypes():
225 iterable = (predicate, key, suffix)
226 case Measurement():
227 iterable = (predicate, key)
228 refKey = "_".join(x for x in iterable if x)
229 accumulation[refKey] = value
230 case PlotTypes():
231 refKey = "_".join(x for x in (predicate, suffix) if x)
232 accumulation[refKey] = results
233 case Measurement():
234 accumulation[f"{predicate}"] = results
235 case JointResults(plot=plotResults, metric=metricResults):
236 if plotResults is not None:
237 subResult = self._process_single_results(plotResults)
238 accumulation.update(subResult)
239 if metricResults is not None:
240 subResult = self._process_single_results(metricResults)
241 accumulation.update(subResult)
242 return accumulation
244 def getInputSchema(self) -> KeyedDataSchema:
245 return self.prep.getInputSchema()
247 def populatePrepFromProcess(self):
248 """Add additional inputs to the prep stage if supported.
250 If the configured prep action supports adding to it's input schema,
251 attempt to add the required inputs schema from the process stage to the
252 prep stage.
254 This method will be a no-op if the prep action does not support this
255 feature.
256 """
257 self.prep.addInputSchema(self.process.getInputSchema())
259 def getOutputNames(self, config: pexConfig.Config | None = None) -> Iterable[str]:
260 """Return the names of the plots produced by this analysis tool.
262 If there is a `PlotAction` defined in the produce action, these names
263 will either come from the `PlotAction` if it defines a
264 ``getOutputNames`` method (likely if it returns a mapping of figures),
265 or a default value is used and a single figure is assumed.
267 Parameters
268 ----------
269 config : `lsst.pex.config.Config`, optional
270 Configuration of the task. This is only used if the output naming
271 needs to be config-aware.
273 Returns
274 -------
275 result : `tuple` of `str`
276 Names for each plot produced by this action.
277 """
278 match self.produce:
279 case JointAction(plot=NoPlot()):
280 return tuple()
281 case _HasOutputNames():
282 outNames = tuple(self.produce.getOutputNames(config=config))
283 case _:
284 raise ValueError(f"Unsupported Action type {type(self.produce)} for getting output names")
286 results = []
287 suffix = self._getPlotType()
288 if self.parameterizedBand:
289 prefix = "_".join(x for x in ("{band}", self.identity) if x)
290 else:
291 prefix = f"{self.identity}" if self.identity else ""
293 if outNames:
294 for name in outNames:
295 results.append("_".join(x for x in (prefix, name, suffix) if x))
296 else:
297 results.append("_".join(x for x in (prefix, suffix) if x))
298 return results
300 @classmethod
301 def fromPipeline(
302 cls,
303 pipeline: str | Pipeline,
304 name: str,
305 fullpath: bool = False,
306 instrument: Instrument | str | None = None,
307 ) -> AnalysisTool | None:
308 """Construct an `AnalysisTool` from a definition written in a
309 `~lsst.pipe.base.Pipeline`.
311 Parameters
312 ----------
313 pipeline : `str` or `~lsst.pipe.base.Pipeline`
314 The pipeline to load the `AnalysisTool` from.
315 name : `str`
316 The name of the analysis tool to run. This can either be just the
317 name assigned to the tool, or an absolute name in a config
318 hierarchy.
319 fullpath : `bool`
320 Determines if the name is interpreted as an absolute path in a
321 config hierarchy, or is relative to an `AnalysisTool` ``atools``
322 `~lsst.pex.config.configurableActions.ConfigurableActionStructField`
323 .
324 instrument : `~lsst.daf.butler.instrument.Instrument` or `str` or\
325 `None`
326 Either a derived class object of a `lsst.daf.butler.instrument` or
327 a string corresponding to a fully qualified
328 `lsst.daf.butler.instrument` name or None if no instrument needs
329 specified or if the pipeline contains the instrument in it.
330 Defaults to None.
332 Returns
333 -------
334 tool : `AnalysisTool`
335 The loaded `AnalysisTool` as configured in the pipeline.
337 Raises
338 ------
339 ValueError
340 Raised if the config field specified does not point to an
341 `AnalysisTool`.
342 Raised if an instrument is specified and it conflicts with the
343 pipelines instrument.
344 """
345 if not isinstance(pipeline, Pipeline):
346 pipeline = Pipeline.fromFile(pipeline)
347 # If the caller specified an instrument, verify it does not conflict
348 # with the pipelines instrument, and add it to the pipeline
349 if instrument is not None:
350 if (pipeInstrument := pipeline.getInstrument()) and pipeInstrument != instrument:
351 raise ValueError(
352 f"The supplied instrument {instrument} conflicts with the pipelines instrument "
353 f"{pipeInstrument}."
354 )
355 else:
356 pipeline.addInstrument(instrument)
357 try:
358 pipelineGraph = pipeline.to_graph()
359 except (FieldValidationError, ValueError) as err:
360 raise ValueError(
361 "There was an error instantiating the pipeline, do you need to specify an instrument?"
362 ) from err
363 if not fullpath:
364 name = f"atools.{name}"
365 for task in pipelineGraph.tasks.values():
366 config = task.config
367 try:
368 attr = attrgetter(name)(config)
369 except AttributeError:
370 continue
371 if not isinstance(attr, AnalysisTool):
372 raise ValueError("The requested name did not refer to an analysisTool")
373 return attr
374 return None
376 def finalize(self) -> None:
377 """Run any finalization code that depends on configuration being
378 complete.
379 """
380 pass
382 def _baseFinalize(self) -> None:
383 self.populatePrepFromProcess()
385 def freeze(self):
386 if not self.__dict__.get("_finalizeRun"):
387 self.finalize()
388 self.__dict__["_finalizeRun"] = True
389 super().freeze()
# ``__init_subclass__`` only fires for subclasses, so the finalize defined on
# the base class itself has to be wrapped explicitly here. This ensures
# ``_baseFinalize`` also runs for tools that never override finalize.
AnalysisTool.finalize = _finalizeWrapper(AnalysisTool.finalize, AnalysisTool)