Coverage for python / lsst / analysis / tools / interfaces / _stages.py: 19%
156 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:53 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:53 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce")
25import logging
26from collections import abc
27from typing import Any, Mapping, cast
29import astropy.units as apu
30from healsparse import HealSparseMap
31from lsst.pex.config import ListField
32from lsst.pex.config.configurableActions import ConfigurableActionStructField
33from lsst.pex.config.dictField import DictField
34from lsst.pipe.base import AlgorithmError
35from lsst.verify import Measurement
37from ._actions import (
38 AnalysisAction,
39 JointAction,
40 KeyedDataAction,
41 MetricAction,
42 MetricResultType,
43 NoPlot,
44 Tensor,
45 VectorAction,
46)
47from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector
49_LOG = logging.getLogger(__name__)
class MissingMetadataError(AlgorithmError):
    """Error signalling that a required key is absent from the input data.

    Parameters
    ----------
    key : `str`
        The key that could not be found.
    data_repr : `str`
        String representation of the input data that lacked the key.
    """

    def __init__(self, key: str, data_repr: str) -> None:
        # Keep both pieces around so they can be reported via ``metadata``.
        self._key = key
        self._data_repr = data_repr
        message = f"Key '{self._key}' could not be found in input data {self._data_repr}"
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """The missing key and the representation of the searched data."""
        return dict(metadata_key=self._key, input_data_repr=self._data_repr)
class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    Calling the action loads every key named in ``keysToLoad`` or
    ``vectorKeys`` from the input data. When at least one selector is
    configured, the product of all selector outputs is used as a mask that
    is applied to the entries named in ``vectorKeys``.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the keys this action loads, then each selector's inputs."""
        yield from (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in sorted(set(self.keysToLoad).union(self.vectorKeys))
        )
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the sorted union of ``keysToLoad`` and ``vectorKeys``."""
        return (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in sorted(set(self.keysToLoad).union(self.vectorKeys))
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Load the configured keys and apply the combined selector mask.

        Parameters
        ----------
        data : `KeyedData`
            Input data; looked up with each configured key after the key has
            been formatted with ``kwargs``.
        **kwargs
            Passed to every selector and used to format the configured keys.

        Returns
        -------
        result : `KeyedData`
            Mapping from formatted key to loaded value. Entries named in
            ``vectorKeys`` are indexed by the mask when selectors exist.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            if mask is None:
                mask = subMask
            else:
                # Combine out of place. An in-place ``*=`` would write into
                # the object returned by the first selector, which this
                # action does not own.
                mask = mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in sorted(set(self.keysToLoad).union(self.vectorKeys)):
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            # Convert to ordered set to avoid formatting duplicate keys
            # multiple times
            formattedKeys = list({key.format_map(kwargs): None for key in self.vectorKeys}.keys())
            for tempFormat in formattedKeys:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                result[tempFormat] = cast(Vector, result[tempFormat])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Merge the names in ``inputSchema`` into the configured keys.

        Every name is added to ``keysToLoad``; names whose type equals
        `Vector` are also added to ``vectorKeys``. Both lists are stored
        sorted and without duplicates.
        """
        existing = list(self.keysToLoad)
        existingVectors = list(self.vectorKeys)
        for name, typ in inputSchema:
            existing.append(name)
            if typ == Vector:
                existingVectors.append(name)
        self.keysToLoad = sorted(set(existing))
        self.vectorKeys = sorted(set(existingVectors))
class BaseProcess(KeyedDataAction):
    """Base class for actions which process data.

    Processing runs in three passes: build actions, then filter actions,
    then calculate actions. Each later pass sees the input data merged with
    the results of the passes before it.
    """

    buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which compute a Vector which will be added to results"
    )
    filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which filter one or more input or build Vectors into shorter vectors"
    )
    calculateActions = ConfigurableActionStructField[AnalysisAction](
        doc="Actions which compute quantities from the input or built data"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys that must be supplied by the caller.

        Names produced by an earlier pass are not requested as inputs of a
        later pass.
        """
        required: KeyedDataTypes = {}  # type: ignore
        builtTypes: KeyedDataTypes = {}  # type: ignore
        filteredTypes: KeyedDataTypes = {}  # type: ignore

        # Build pass: every input is required from the caller.
        for fieldName, buildAction in self.buildActions.items():
            for name, typ in buildAction.getInputSchema():
                required[name] = typ
            if not isinstance(buildAction, KeyedDataAction):
                # A non-keyed action is recorded under its field name.
                builtTypes[fieldName] = Vector
            else:
                builtTypes.update(buildAction.getOutputSchema() or {})

        # Filter pass: skip anything the build pass already provides.
        for fieldName, filterAction in self.filterActions.items():
            for name, typ in filterAction.getInputSchema():
                if name in builtTypes:
                    continue
                required[name] = typ
            if not isinstance(filterAction, KeyedDataAction):
                filteredTypes[fieldName] = Vector
            else:
                filteredTypes.update(filterAction.getOutputSchema() or {})

        # Calculate pass: skip anything provided by build or filter.
        for calcAction in self.calculateActions:
            for name, typ in calcAction.getInputSchema():
                if name in builtTypes or name in filteredTypes:
                    continue
                required[name] = typ
        return ((name, typ) for name, typ in required.items())

    def getOutputSchema(self) -> KeyedDataSchema:
        """Yield the output schemas of the keyed-data build actions."""
        for action in self.buildActions:
            if not isinstance(action, KeyedDataAction):
                continue
            schema = action.getOutputSchema()
            if schema is None:
                continue
            yield from schema

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the build, filter and calculate actions in that order.

        Parameters
        ----------
        data : `KeyedData`
            Input data; copied into a `dict` before use.
        **kwargs
            Forwarded unchanged to every action.

        Returns
        -------
        results : `KeyedData`
            The collected action outputs only; input entries are not
            included.
        """
        results = {}
        data = dict(data)

        def record(fieldName, output):
            # Mapping outputs are merged key by key; anything else is stored
            # under the name of the action that produced it.
            if isinstance(output, abc.Mapping):
                for key, result in output.items():
                    results[key] = result
            else:
                results[fieldName] = output

        for name, action in self.buildActions.items():
            record(name, action(data, **kwargs))

        # Snapshot taken once: filter actions do not see each other's output.
        view1 = data | results
        for name, action in self.filterActions.items():
            record(name, action(view1, **kwargs))

        # Snapshot taken once: calculate actions see build + filter output.
        view2 = data | results
        for name, calcAction in self.calculateActions.items():
            record(name, calcAction(view2, **kwargs))
        return results
215def _newNameChecker(value: str) -> bool:
216 if "-" in value:
217 # Yes this should be a log here, as pex config provides no other
218 # useful way to get info into the exception.
219 _LOG.error("Remapped metric names must not have a - character in them.")
220 return False
221 return True
class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key configured in ``units`` is looked up in the input data and
    turned into one `Measurement`, or into one per sub-key when the value
    is itself a mapping.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc=(
            "Mapping of key to new name if needed prior to creating metric, "
            "cannot contain a minus character in the name."
        ),
        default={},
        itemCheck=_newNameChecker,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return one ``(key, Scalar)`` pair per key configured in ``units``."""
        # Something is wrong with the typing for DictField key iteration
        return [(key, Scalar) for key in self.units]  # type: ignore

    @staticmethod
    def _sanitizeMetadataName(metadata_name: str) -> str:
        """Sanitize a metadata name to allow it to be pushed to Apache Avro.

        Parameters
        ----------
        metadata_name : `str`
            The metadata name to sanitize.

        Returns
        -------
        sanitized_metadata_name : `str`
            The sanitized metadata name.
        """
        # Spaces and minus signs both become underscores.
        return metadata_name.translate(str.maketrans(" -", "__"))

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Create the measurements for every key configured in ``units``.

        Parameters
        ----------
        data : `KeyedData`
            Input data, looked up with each key after the key has been
            formatted with ``kwargs``.
        **kwargs
            Used to format keys and new names; ``metric_tags``, when
            present, is stored in the notes of every measurement.

        Returns
        -------
        results : `MetricResultType`
            Mapping from sanitized metric name to `Measurement`.

        Raises
        ------
        MissingMetadataError
            Raised if a formatted key is not present in ``data``.
        ValueError
            Raised if a value, or a value inside a mapping, is not a
            `Scalar`.
        """
        measurements = {}
        for key, unit in self.units.items():
            formattedKey = key.format(**kwargs)
            if formattedKey not in data:
                raise MissingMetadataError(formattedKey, repr(data))
            value = data[formattedKey]

            # A configured (non-empty) new name replaces the looked-up key
            # as the basis of the metric name.
            newName = self.newNames.get(key)
            if newName:
                formattedKey = newName.format(**kwargs)
            notes = {"metric_tags": kwargs.get("metric_tags", [])}

            if not isinstance(value, Mapping):
                if not isinstance(value, Scalar):
                    raise ValueError(f"Data for key {key} is not a Scalar type")
                formattedKey = self._sanitizeMetadataName(formattedKey)
                measurements[formattedKey] = Measurement(formattedKey, value * apu.Unit(unit), notes=notes)
                continue

            # Mapping value: one measurement per sub-key, named
            # "<metric name>_<sub-key>".
            for k, subValue in value.items():
                if not isinstance(subValue, Scalar):
                    raise ValueError(f"Data for subkey {k} of key {key} is not a Scalar type")
                formattedSubKey = self._sanitizeMetadataName(f"{formattedKey}_{k}")
                measurements[formattedSubKey] = Measurement(
                    formattedSubKey, subValue * apu.Unit(unit), notes=notes
                )
        return measurements
class BaseProduce(JointAction):
    """Base class for actions which produce data."""

    def setDefaults(self) -> None:
        """Apply the parent defaults, then set ``metric`` and ``plot``.

        ``metric`` is set to a new `BaseMetricAction` instance and ``plot``
        is set to `NoPlot` (assigned as the class, not an instance).
        """
        super().setDefaults()
        self.metric = BaseMetricAction()
        self.plot = NoPlot