Coverage for python/lsst/analysis/tools/interfaces/_analysisTools.py: 21%
136 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-01 04:04 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-01 04:04 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("AnalysisTool",)
26from collections.abc import Mapping
27from functools import wraps
28from typing import Callable, Iterable, Protocol, runtime_checkable
30from lsst.pex.config import Field, ListField
31from lsst.pex.config.configurableActions import ConfigurableActionField
32from lsst.verify import Measurement
34from ._actions import AnalysisAction, JointAction, JointResults, NoPlot, PlotAction
35from ._interfaces import KeyedData, KeyedDataSchema, KeyedResults, PlotTypes
36from ._stages import BasePrep, BaseProcess, BaseProduce
@runtime_checkable
class _HasOutputNames(Protocol):
    """Structural type for any object that provides ``getOutputNames``.

    The protocol is runtime-checkable, so it may be used with ``isinstance``
    or as a class pattern in a ``match`` statement.
    """

    def getOutputNames(self) -> Iterable[str]:
        """Return an iterable of output-name strings."""
        ...
def _finalizeWrapper(
    f: Callable[[AnalysisTool], None], cls: type[AnalysisTool]
) -> Callable[[AnalysisTool], None]:
    """Wrap a class's ``finalize`` so that ``_baseFinalize`` runs exactly
    once, after the most derived ``finalize`` has completed.

    Parameters
    ----------
    f : `Callable`
        The ``finalize`` function being wrapped.
    cls : `type` of `AnalysisTool`
        The class on which ``f`` is defined.

    Returns
    -------
    function : `Callable`
        A replacement for ``f`` that calls ``f`` and then, only when ``cls``
        is the first class in the instance's MRO that defines ``finalize``,
        calls ``self._baseFinalize()``.
    """

    @wraps(f)
    def wrapper(self: AnalysisTool) -> None:
        # Run the wrapped finalize first; if it chains to a parent through
        # ``super().finalize()``, the parents' wrappers run inside this call.
        f(self)

        # Walk the MRO (most derived class first) and stop at the first class
        # whose *own* namespace defines ``finalize``. ``vars`` is used rather
        # than attribute lookup because normal lookup would follow the MRO
        # and find an inherited definition. The result is not necessarily
        # ``type(self)``: that class may simply inherit ``finalize``.
        definers = (klass for klass in self.__class__.mro() if "finalize" in vars(klass))
        mostDerived: type | None = next(definers, None)

        # The class this particular wrapper was created for, i.e. how far
        # along the chain of finalize calls execution currently is.
        this = super(cls, self).__thisclass__

        if mostDerived is None:
            return
        # Only the wrapper belonging to the most derived definition fires the
        # private base finalize, so it runs last and only once.
        if this == mostDerived:
            self._baseFinalize()

    return wrapper
101class AnalysisTool(AnalysisAction):
102 r"""A tool which which calculates a single type of analysis on input data,
103 though it may return more than one result.
105 Although `AnalysisTool`\ s are considered a single type of analysis, the
106 classes themselves can be thought of as a container. `AnalysisTool`\ s
107 are aggregations of `AnalysisAction`\ s to form prep, process, and
108 produce stages. These stages allow better reuse of individual
109 `AnalysisActions` and easier introspection in contexts such as a notebook
110 or interpreter.
112 An `AnalysisTool` can be thought of an an individual configuration that
113 specifies which `AnalysisAction` should run for each stage.
115 The stages themselves are also configurable, allowing control over various
116 aspects of the individual `AnalysisAction`\ s.
117 """
118 prep = ConfigurableActionField[AnalysisAction](doc="Action to run to prepare inputs", default=BasePrep)
119 process = ConfigurableActionField[AnalysisAction](
120 doc="Action to process data into intended form", default=BaseProcess
121 )
122 produce = ConfigurableActionField[AnalysisAction](
123 doc="Action to perform any finalization steps", default=BaseProduce
124 )
125 metric_tags = ListField[str](
126 doc="List of tags which will be associated with metric measurement(s)", default=[]
127 )
129 def __init_subclass__(cls: type[AnalysisTool], **kwargs):
130 super().__init_subclass__(**kwargs)
131 # Wrap all definitions of the finalize method in a special wrapper that
132 # ensures that the bases classes private finalize is called last.
133 if "finalize" in vars(cls):
134 cls.finalize = _finalizeWrapper(cls.finalize, cls)
136 parameterizedBand: bool | Field[bool] = True
137 """Specifies if an `AnalysisTool` may parameterize a band within any field
138 in any stage, or if the set of bands is already uniquely determined though
139 configuration. I.e. can this `AnalysisTool` be automatically looped over to
140 produce a result for multiple bands.
141 """
143 def __call__(self, data: KeyedData, **kwargs) -> KeyedResults:
144 bands = kwargs.pop("bands", None)
145 if "plotInfo" in kwargs and kwargs.get("plotInfo") is not None:
146 kwargs["plotInfo"]["plotName"] = self.identity
147 if not self.parameterizedBand or bands is None:
148 if "band" not in kwargs:
149 # Some tasks require a "band" key for naming. This shouldn't
150 # affect the results. DM-35813 should make this unnecessary.
151 kwargs["band"] = "analysisTools"
152 return self._call_single(data, **kwargs)
153 results: KeyedResults = {}
154 for band in bands:
155 kwargs["band"] = band
156 if "plotInfo" in kwargs:
157 kwargs["plotInfo"]["bands"] = band
158 subResult = self._call_single(data, **kwargs)
159 for key, value in subResult.items():
160 match value:
161 case PlotTypes():
162 results[f"{band}_{key}"] = value
163 case Measurement():
164 results[key] = value
165 return results
167 def _call_single(self, data: KeyedData, **kwargs) -> KeyedResults:
168 # create a shallow copy of kwargs
169 kwargs = dict(**kwargs)
170 kwargs["metric_tags"] = list(self.metric_tags or ())
171 prepped: KeyedData = self.prep(data, **kwargs) # type: ignore
172 processed: KeyedData = self.process(prepped, **kwargs) # type: ignore
173 finalized: Mapping[str, PlotTypes] | PlotTypes | Mapping[
174 str, Measurement
175 ] | Measurement | JointResults = self.produce(
176 processed, **kwargs
177 ) # type: ignore
178 return self._process_single_results(finalized)
180 def _getPlotType(self) -> str:
181 match self.produce:
182 case PlotAction():
183 return type(self.produce).__name__
184 case JointAction(plot=NoPlot()):
185 pass
186 case JointAction(plot=plotter):
187 return type(plotter).__name__
189 return ""
191 def _process_single_results(
192 self,
193 results: Mapping[str, PlotTypes] | PlotTypes | Mapping[str, Measurement] | Measurement | JointResults,
194 ) -> KeyedResults:
195 accumulation = {}
196 suffix = self._getPlotType()
197 predicate = f"{self.identity}" if self.identity else ""
198 match results:
199 case Mapping():
200 for key, value in results.items():
201 match value:
202 case PlotTypes():
203 iterable = (predicate, key, suffix)
204 case Measurement():
205 iterable = (predicate, key)
206 refKey = "_".join(x for x in iterable if x)
207 accumulation[refKey] = value
208 case PlotTypes():
209 refKey = "_".join(x for x in (predicate, suffix) if x)
210 accumulation[refKey] = results
211 case Measurement():
212 accumulation[f"{predicate}"] = results
213 case JointResults(plot=plotResults, metric=metricResults):
214 if plotResults is not None:
215 subResult = self._process_single_results(plotResults)
216 accumulation.update(subResult)
217 if metricResults is not None:
218 subResult = self._process_single_results(metricResults)
219 accumulation.update(subResult)
220 return accumulation
222 def getInputSchema(self) -> KeyedDataSchema:
223 return self.prep.getInputSchema()
225 def populatePrepFromProcess(self):
226 """Add additional inputs to the prep stage if supported.
228 If the configured prep action supports adding to it's input schema,
229 attempt to add the required inputs schema from the process stage to the
230 prep stage.
232 This method will be a no-op if the prep action does not support this
233 feature.
234 """
235 self.prep.addInputSchema(self.process.getInputSchema())
237 def getOutputNames(self) -> Iterable[str]:
238 """Return the names of the plots produced by this analysis tool.
240 If there is a `PlotAction` defined in the produce action, these names
241 will either come from the `PlotAction` if it defines a
242 ``getOutputNames`` method (likely if it returns a mapping of figures),
243 or a default value is used and a single figure is assumed.
245 Returns
246 -------
247 result : `tuple` of `str`
248 Names for each plot produced by this action.
249 """
250 match self.produce:
251 case JointAction(plot=NoPlot()):
252 return tuple()
253 case _HasOutputNames():
254 outNames = tuple(self.produce.getOutputNames())
255 case _:
256 raise ValueError(f"Unsupported Action type {type(self.produce)} for getting output names")
258 results = []
259 suffix = self._getPlotType()
260 if self.parameterizedBand:
261 prefix = "_".join(x for x in ("{band}", self.identity) if x)
262 else:
263 prefix = f"{self.identity}" if self.identity else ""
265 if outNames:
266 for name in outNames:
267 results.append("_".join(x for x in (prefix, name, suffix) if x))
268 else:
269 results.append("_".join(x for x in (prefix, suffix) if x))
270 return results
272 def finalize(self) -> None:
273 """Run any finalization code that depends on configuration being
274 complete.
275 """
276 pass
278 def _baseFinalize(self) -> None:
279 self.populatePrepFromProcess()
281 def freeze(self):
282 if not self.__dict__.get("_finalizeRun"):
283 self.finalize()
284 self.__dict__["_finalizeRun"] = True
285 super().freeze()
# ``__init_subclass__`` only runs for subclasses, so the base class's own
# ``finalize`` has to be wrapped by hand.
AnalysisTool.finalize = _finalizeWrapper(
    f=AnalysisTool.finalize,
    cls=AnalysisTool,
)