Coverage for python / lsst / analysis / tools / interfaces / _stages.py: 19%
157 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 08:45 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 08:45 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce")
25import logging
26from collections import abc
27from collections.abc import Mapping
28from typing import Any, cast
30import astropy.units as apu
31from healsparse import HealSparseMap
33from lsst.pex.config import ListField
34from lsst.pex.config.configurableActions import ConfigurableActionStructField
35from lsst.pex.config.dictField import DictField
36from lsst.pipe.base import AlgorithmError
37from lsst.verify import Measurement
39from ._actions import (
40 AnalysisAction,
41 JointAction,
42 KeyedDataAction,
43 MetricAction,
44 MetricResultType,
45 NoPlot,
46 Tensor,
47 VectorAction,
48)
49from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector
51_LOG = logging.getLogger(__name__)
class MissingMetadataError(AlgorithmError):
    """Raised when a required metadata key is absent from the input data.

    Parameters
    ----------
    key : `str`
        The key that could not be found.
    data_repr : `str`
        String representation of the input data that lacked ``key``.
    """

    def __init__(self, key: str, data_repr: str) -> None:
        self._key = key
        self._data_repr = data_repr
        message = f"Key '{self._key}' could not be found in input data {self._data_repr}"
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """Diagnostic information describing the failure.

        Returns
        -------
        metadata : `dict`
            ``metadata_key`` holds the missing key and ``input_data_repr``
            holds the representation of the searched input data.
        """
        return dict(metadata_key=self._key, input_data_repr=self._data_repr)
class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    Every key in the union of ``keysToLoad`` and ``vectorKeys`` is formatted
    with the call's keyword arguments and copied from the input data.  If
    any ``selectors`` are configured, their outputs are multiplied together
    (an AND for boolean masks) into a single row mask, which is then
    applied to each (formatted) ``vectorKeys`` entry.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def _allKeys(self) -> list[str]:
        """Return the sorted, de-duplicated union of ``keysToLoad`` and
        ``vectorKeys`` (still unformatted templates).
        """
        return sorted(set(self.keysToLoad).union(self.vectorKeys))

    def getInputSchema(self) -> KeyedDataSchema:
        yield from ((column, Vector | Scalar | HealSparseMap | Tensor) for column in self._allKeys())
        # Selectors may read additional columns beyond the ones returned.
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        return ((column, Vector | Scalar | HealSparseMap | Tensor) for column in self._allKeys())

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys from ``data`` and apply the selectors.

        Parameters
        ----------
        data : `KeyedData`
            Input data to extract from; it is not modified.
        **kwargs
            Values used to format the key templates (via ``format_map``)
            and passed through to every selector.

        Returns
        -------
        result : `KeyedData`
            The formatted keys mapped to their data, with ``vectorKeys``
            entries reduced by the combined selector mask when any
            selectors are configured.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            # Combine out of place.  An in-place ``*=`` would write into
            # the array returned by the first selector; if a selector hands
            # back a column of ``data`` unchanged, that would silently
            # corrupt the caller's input data.
            mask = subMask if mask is None else mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in self._allKeys():
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            # dict.fromkeys acts as an ordered set, so a formatted key is
            # masked only once even if several templates format to it.
            formattedKeys = list(dict.fromkeys(key.format_map(kwargs) for key in self.vectorKeys))
            for tempFormat in formattedKeys:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                result[tempFormat] = cast(Vector, result[tempFormat])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Merge ``inputSchema`` into the keys this action loads.

        Every name is added to ``keysToLoad``; names whose type is exactly
        ``Vector`` are also added to ``vectorKeys`` so selectors apply to
        them.  Both lists are stored sorted and de-duplicated.
        """
        existing = list(self.keysToLoad)
        existingVectors = list(self.vectorKeys)
        for name, typ in inputSchema:
            existing.append(name)
            if typ == Vector:
                existingVectors.append(name)
        self.keysToLoad = sorted(set(existing))
        self.vectorKeys = sorted(set(existingVectors))
137class BaseProcess(KeyedDataAction):
138 """Base class for actions which process data."""
140 buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
141 doc="Actions which compute a Vector which will be added to results"
142 )
143 filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
144 doc="Actions which filter one or more input or build Vectors into shorter vectors"
145 )
146 calculateActions = ConfigurableActionStructField[AnalysisAction](
147 doc="Actions which compute quantities from the input or built data"
148 )
150 def getInputSchema(self) -> KeyedDataSchema:
151 inputSchema: KeyedDataTypes = {} # type: ignore
152 buildOutputSchema: KeyedDataTypes = {} # type: ignore
153 filterOutputSchema: KeyedDataTypes = {} # type: ignore
154 action: AnalysisAction
156 for fieldName, action in self.buildActions.items():
157 for name, typ in action.getInputSchema():
158 inputSchema[name] = typ
159 if isinstance(action, KeyedDataAction):
160 buildOutputSchema.update(action.getOutputSchema() or {})
161 else:
162 buildOutputSchema[fieldName] = Vector
164 for fieldName, action in self.filterActions.items():
165 for name, typ in action.getInputSchema():
166 if name not in buildOutputSchema:
167 inputSchema[name] = typ
168 if isinstance(action, KeyedDataAction):
169 filterOutputSchema.update(action.getOutputSchema() or {})
170 else:
171 filterOutputSchema[fieldName] = Vector
173 for calcAction in self.calculateActions:
174 for name, typ in calcAction.getInputSchema():
175 if name not in buildOutputSchema and name not in filterOutputSchema:
176 inputSchema[name] = typ
177 return ((name, typ) for name, typ in inputSchema.items())
179 def getOutputSchema(self) -> KeyedDataSchema:
180 for action in self.buildActions:
181 if isinstance(action, KeyedDataAction):
182 outSchema = action.getOutputSchema()
183 if outSchema is not None:
184 yield from outSchema
186 def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
187 action: AnalysisAction
188 results = {}
189 data = dict(data)
190 for name, action in self.buildActions.items():
191 match action(data, **kwargs):
192 case abc.Mapping() as item:
193 for key, result in item.items():
194 results[key] = result
195 case item:
196 results[name] = item
197 view1 = data | results
198 for name, action in self.filterActions.items():
199 match action(view1, **kwargs):
200 case abc.Mapping() as item:
201 for key, result in item.items():
202 results[key] = result
203 case item:
204 results[name] = item
206 view2 = data | results
207 for name, calcAction in self.calculateActions.items():
208 match calcAction(view2, **kwargs):
209 case abc.Mapping() as item:
210 for key, result in item.items():
211 results[key] = result
212 case item:
213 results[name] = item
214 return results
def _newNameChecker(value: str) -> bool:
    """Check that a remapped metric name contains no ``-`` character.

    Used as the ``itemCheck`` of ``BaseMetricAction.newNames``.

    Parameters
    ----------
    value : `str`
        Candidate new metric name.

    Returns
    -------
    valid : `bool`
        `True` if ``value`` is acceptable, `False` otherwise.
    """
    if "-" in value:
        # Yes this should be a log here, as pex config provides no other
        # useful way to get info into the exception.  Include the rejected
        # name so the user can tell which entry is at fault.
        _LOG.error("Remapped metric names must not have a - character in them: %r", value)
        return False
    return True
class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key of ``units`` names an entry of the input data which is turned
    into one or more `lsst.verify.Measurement` objects with that unit.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc=(
            "Mapping of key to new name if needed prior to creating metric, "
            "cannot contain a minus character in the name."
        ),
        default={},
        itemCheck=_newNameChecker,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Each configured unit key is requested as a Scalar input.  The
        # ignore works around imprecise typing of DictField key iteration.
        return [(name, Scalar) for name in self.units]  # type: ignore

    @staticmethod
    def _sanitizeMetadataName(metadata_name: str) -> str:
        """Make a metadata name safe to push to Apache Avro.

        Parameters
        ----------
        metadata_name : `str`
            The metadata name to sanitize.

        Returns
        -------
        sanitized_metadata_name : `str`
            ``metadata_name`` with every space and ``-`` replaced by ``_``.
        """
        return metadata_name.translate(str.maketrans(" -", "__"))

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Turn the configured entries of ``data`` into measurements.

        Parameters
        ----------
        data : `KeyedData`
            Input data holding the values named by ``units``.
        **kwargs
            Used to format key templates and new names; an optional
            ``metric_tags`` entry is attached to every measurement's notes.

        Returns
        -------
        results : `MetricResultType`
            Sanitized metric names mapped to their measurements.  A mapping
            value yields one measurement per sub-key, named
            ``<name>_<subkey>``.

        Raises
        ------
        MissingMetadataError
            Raised if a formatted ``units`` key is absent from ``data``.
        ValueError
            Raised if a value, or any sub-value of a mapping, is not a
            `Scalar`.
        """
        measurements = {}
        for key, unit in self.units.items():
            dataKey = key.format(**kwargs)
            if dataKey not in data:
                raise MissingMetadataError(dataKey, repr(data))
            value = data[dataKey]
            newName = self.newNames.get(key)
            baseName = newName.format(**kwargs) if newName else dataKey
            notes = {"metric_tags": kwargs.get("metric_tags", [])}

            if not isinstance(value, Mapping):
                if not isinstance(value, Scalar):
                    raise ValueError(f"Data for key {key} is not a Scalar type")
                metricName = self._sanitizeMetadataName(baseName)
                measurements[metricName] = Measurement(metricName, value * apu.Unit(unit), notes=notes)
                continue

            for subKey, subValue in value.items():
                if not isinstance(subValue, Scalar):
                    raise ValueError(f"Data for subkey {subKey} of key {key} is not a Scalar type")
                metricName = self._sanitizeMetadataName(f"{baseName}_{subKey}")
                measurements[metricName] = Measurement(metricName, subValue * apu.Unit(unit), notes=notes)
        return measurements
class BaseProduce(JointAction):
    """Base class for actions which produce data.

    Defaults the metric stage to a plain `BaseMetricAction` and the plot
    stage to ``NoPlot``.
    """

    def setDefaults(self):
        super().setDefaults()
        # NOTE(review): ``NoPlot`` is assigned as a class, presumably meaning
        # "no plot is produced" — confirm against the JointAction contract.
        self.plot = NoPlot
        self.metric = BaseMetricAction()