Coverage for python / lsst / analysis / tools / interfaces / _stages.py: 19%

156 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 18:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce") 

24 

25import logging 

26from collections import abc 

27from typing import Any, Mapping, cast 

28 

29import astropy.units as apu 

30from healsparse import HealSparseMap 

31from lsst.pex.config import ListField 

32from lsst.pex.config.configurableActions import ConfigurableActionStructField 

33from lsst.pex.config.dictField import DictField 

34from lsst.pipe.base import AlgorithmError 

35from lsst.verify import Measurement 

36 

37from ._actions import ( 

38 AnalysisAction, 

39 JointAction, 

40 KeyedDataAction, 

41 MetricAction, 

42 MetricResultType, 

43 NoPlot, 

44 Tensor, 

45 VectorAction, 

46) 

47from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector 

48 

49_LOG = logging.getLogger(__name__) 

50 

51 

class MissingMetadataError(AlgorithmError):
    """Error signalling that a required key was absent from the input data.

    Parameters
    ----------
    key : `str`
        The key that could not be found.
    data_repr : `str`
        String representation of the input data that lacked the key.
    """

    def __init__(self, key: str, data_repr: str) -> None:
        self._key, self._data_repr = key, data_repr
        # Build the message from the stored attributes so the text and the
        # values reported by ``metadata`` always agree.
        message = f"Key '{self._key}' could not be found in input data {self._data_repr}"
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """The missing key and the input data representation, as a `dict`."""
        return {"metadata_key": self._key, "input_data_repr": self._data_repr}

74 

75 

class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    The configured keys are extracted from the input data. When selectors
    are configured, their masks are multiplied together and the combined
    mask is used to index the entries named by ``vectorKeys``.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the (name, type) pairs this action reads.

        Yields
        ------
        schema : `KeyedDataSchema`
            One entry per unique key in ``keysToLoad`` and ``vectorKeys``
            (in sorted order), followed by the input schema of every
            configured selector.
        """
        yield from (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in sorted(set(self.keysToLoad).union(self.vectorKeys))
        )
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the (name, type) pairs this action produces.

        Returns
        -------
        schema : `KeyedDataSchema`
            One entry per unique key in ``keysToLoad`` and ``vectorKeys``,
            in sorted order.
        """
        return (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in sorted(set(self.keysToLoad).union(self.vectorKeys))
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys and apply the selector mask.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain every configured key once the key has
            been formatted with ``kwargs``.
        **kwargs
            Values used to format the configured key templates; also
            forwarded to every selector.

        Returns
        -------
        result : `KeyedData`
            Mapping of formatted key to the extracted value. Entries named
            by ``vectorKeys`` are indexed by the combined selector mask when
            at least one selector is configured.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            if mask is None:
                mask = subMask
            else:
                # Combine out of place. An in-place ``*=`` would write into
                # the object returned by the first selector, which may be a
                # column of ``data`` itself rather than a fresh array.
                mask = mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in sorted(set(self.keysToLoad).union(self.vectorKeys)):
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            # Convert to ordered set to avoid formatting duplicate keys
            # multiple times
            formattedKeys = list({key.format_map(kwargs): None for key in self.vectorKeys}.keys())
            for tempFormat in formattedKeys:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                result[tempFormat] = cast(Vector, result[tempFormat])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Register additional keys to load.

        Parameters
        ----------
        inputSchema : `KeyedDataSchema`
            Iterable of (name, type) pairs. Every name is added to
            ``keysToLoad``; names whose type equals `Vector` are also added
            to ``vectorKeys``. Both fields are stored sorted and without
            duplicates.
        """
        existing = list(self.keysToLoad)
        existingVectors = list(self.vectorKeys)
        for name, typ in inputSchema:
            existing.append(name)
            if typ == Vector:
                existingVectors.append(name)
        self.keysToLoad = sorted(set(existing))
        self.vectorKeys = sorted(set(existingVectors))

133 

134 

class BaseProcess(KeyedDataAction):
    """Base class for actions which process data.

    Processing runs in three stages: build actions, then filter actions,
    then calculate actions. Each later stage sees the input data overlaid
    with the results of the stages before it.
    """

    buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which compute a Vector which will be added to results"
    )
    filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which filter one or more input or build Vectors into shorter vectors"
    )
    calculateActions = ConfigurableActionStructField[AnalysisAction](
        doc="Actions which compute quantities from the input or built data"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (name, type) pairs that must come from the input data.

        Inputs of a filter action that are produced by a build action, and
        inputs of a calculate action that are produced by a build or filter
        action, are not reported.

        Returns
        -------
        schema : `KeyedDataSchema`
            Generator of (name, type) pairs.
        """
        required: KeyedDataTypes = {}  # type: ignore
        builtNames: KeyedDataTypes = {}  # type: ignore
        filteredNames: KeyedDataTypes = {}  # type: ignore

        # Build stage: every declared input is required from the input data.
        for fieldName, step in self.buildActions.items():
            required.update({column: columnType for column, columnType in step.getInputSchema()})
            if not isinstance(step, KeyedDataAction):
                # Non-keyed actions are recorded as one Vector under the field name.
                builtNames[fieldName] = Vector
            else:
                builtNames.update(step.getOutputSchema() or {})

        # Filter stage: skip anything a build action already provides.
        for fieldName, step in self.filterActions.items():
            required.update(
                {
                    column: columnType
                    for column, columnType in step.getInputSchema()
                    if column not in builtNames
                }
            )
            if not isinstance(step, KeyedDataAction):
                filteredNames[fieldName] = Vector
            else:
                filteredNames.update(step.getOutputSchema() or {})

        # Calculate stage: skip anything provided by either earlier stage.
        for calcStep in self.calculateActions:
            for column, columnType in calcStep.getInputSchema():
                if not (column in builtNames or column in filteredNames):
                    required[column] = columnType
        return (pair for pair in required.items())

    def getOutputSchema(self) -> KeyedDataSchema:
        """Yield the output schema entries of the keyed-data build actions.

        Yields
        ------
        schema : `KeyedDataSchema`
            Entries of each build action that is a `KeyedDataAction` and
            reports a non-`None` output schema.
        """
        # NOTE(review): only build-stage KeyedDataAction outputs are reported
        # here, although __call__ also returns filter and calculate results;
        # confirm whether that is intended.
        for step in self.buildActions:
            if not isinstance(step, KeyedDataAction):
                continue
            if (declared := step.getOutputSchema()) is not None:
                yield from declared

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the build, filter and calculate stages in order.

        Parameters
        ----------
        data : `KeyedData`
            Input data; copied into a `dict` before use.
        **kwargs
            Forwarded unchanged to every action.

        Returns
        -------
        results : `KeyedData`
            Everything produced by the three stages. A mapping returned by
            an action is merged key by key; any other return value is
            stored under the action's field name.
        """
        collected: dict[str, Any] = {}
        inputs = dict(data)

        def absorb(label: str, outcome: Any) -> None:
            # Mappings are flattened into the results; anything else is
            # stored under the name of the action that produced it.
            if isinstance(outcome, abc.Mapping):
                for subKey, subValue in outcome.items():
                    collected[subKey] = subValue
            else:
                collected[label] = outcome

        for label, step in self.buildActions.items():
            absorb(label, step(inputs, **kwargs))

        # Filter actions see the inputs overlaid with the build results.
        withBuilt = {**inputs, **collected}
        for label, step in self.filterActions.items():
            absorb(label, step(withBuilt, **kwargs))

        # Calculate actions additionally see the filter results.
        withFiltered = {**inputs, **collected}
        for label, calcStep in self.calculateActions.items():
            absorb(label, calcStep(withFiltered, **kwargs))
        return collected

213 

214 

215def _newNameChecker(value: str) -> bool: 

216 if "-" in value: 

217 # Yes this should be a log here, as pex config provides no other 

218 # useful way to get info into the exception. 

219 _LOG.error("Remapped metric names must not have a - character in them.") 

220 return False 

221 return True 

222 

223 

class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key listed in ``units`` is looked up in the input data and turned
    into one `Measurement`, or into one `Measurement` per sub-key when the
    value is itself a mapping.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc=(
            "Mapping of key to new name if needed prior to creating metric, "
            "cannot contain a minus character in the name."
        ),
        default={},
        itemCheck=_newNameChecker,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return one ``(key, Scalar)`` pair per key configured in ``units``."""
        # Something is wrong with the typing for DictField key iteration
        return [(unitKey, Scalar) for unitKey in self.units]  # type: ignore

    @staticmethod
    def _sanitizeMetadataName(metadata_name: str) -> str:
        """Sanitize a metadata name to allow it to be pushed to Apache Avro.

        Spaces and minus characters are each replaced by an underscore.

        Parameters
        ----------
        metadata_name : `str`
            The metadata name to sanitize.

        Returns
        -------
        sanitized_metadata_name : `str`
            The sanitized metadata name.
        """
        return metadata_name.translate(str.maketrans({" ": "_", "-": "_"}))

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Create measurements for every key configured in ``units``.

        Parameters
        ----------
        data : `KeyedData`
            Input data holding a scalar, or a mapping of scalars, for each
            configured key once the key is formatted with ``kwargs``.
        **kwargs
            Values used to format keys and new names. ``metric_tags``, when
            given, is attached to each measurement's notes.

        Returns
        -------
        results : `MetricResultType`
            Mapping of sanitized metric name to `Measurement`.

        Raises
        ------
        MissingMetadataError
            Raised if a formatted key is not present in ``data``.
        ValueError
            Raised if a value, or a sub-value of a mapping, is not a Scalar.
        """
        measurements = {}
        for key, unit in self.units.items():
            lookupKey = key.format(**kwargs)
            if lookupKey not in data:
                raise MissingMetadataError(lookupKey, repr(data))
            value = data[lookupKey]

            # A non-empty remapping replaces the name used for the metric.
            renamed = self.newNames.get(key)
            metricName = renamed.format(**kwargs) if renamed else lookupKey
            notes = {"metric_tags": kwargs.get("metric_tags", [])}

            if not isinstance(value, Mapping):
                if not isinstance(value, Scalar):
                    raise ValueError(f"Data for key {key} is not a Scalar type")
                metricName = self._sanitizeMetadataName(metricName)
                measurements[metricName] = Measurement(metricName, value * apu.Unit(unit), notes=notes)
                continue

            # Mapping value: one measurement per sub-key, named <name>_<subkey>.
            for k, v in value.items():
                if not isinstance(v, Scalar):
                    raise ValueError(f"Data for subkey {k} of key {key} is not a Scalar type")
                subMetricName = self._sanitizeMetadataName(f"{metricName}_{k}")
                measurements[subMetricName] = Measurement(subMetricName, v * apu.Unit(unit), notes=notes)
        return measurements

277 

278 

class BaseProduce(JointAction):
    """Base class for actions which produce data."""

    def setDefaults(self):
        """Set the default ``metric`` and ``plot`` actions.

        After the parent defaults are applied, ``metric`` is set to a new
        `BaseMetricAction` and ``plot`` is set to `NoPlot`.
        """
        super().setDefaults()
        self.metric = BaseMetricAction()
        # NOTE(review): NoPlot is assigned as a class while metric gets an
        # instance; presumably the field accepts an action class as well as
        # an instance -- confirm against the JointAction field definition.
        self.plot = NoPlot