Coverage for python / lsst / analysis / tools / interfaces / _stages.py: 19%

157 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 08:55 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce") 

24 

25import logging 

26from collections import abc 

27from collections.abc import Mapping 

28from typing import Any, cast 

29 

30import astropy.units as apu 

31from healsparse import HealSparseMap 

32 

33from lsst.pex.config import ListField 

34from lsst.pex.config.configurableActions import ConfigurableActionStructField 

35from lsst.pex.config.dictField import DictField 

36from lsst.pipe.base import AlgorithmError 

37from lsst.verify import Measurement 

38 

39from ._actions import ( 

40 AnalysisAction, 

41 JointAction, 

42 KeyedDataAction, 

43 MetricAction, 

44 MetricResultType, 

45 NoPlot, 

46 Tensor, 

47 VectorAction, 

48) 

49from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector 

50 

51_LOG = logging.getLogger(__name__) 

52 

53 

class MissingMetadataError(AlgorithmError):
    """Error signalling that a required key is absent from the input data.

    Parameters
    ----------
    key : `str`
        The key that could not be found.
    data_repr : `str`
        String representation of the input data that lacked the key.
    """

    def __init__(self, key, data_repr) -> None:
        # Store the details first; the message below is built from them.
        self._key = key
        self._data_repr = data_repr
        message = f"Key '{self._key}' could not be found in input data {self._data_repr}"
        super().__init__(message)

    @property
    def metadata(self) -> dict:
        """Details of the failure as a `dict` with the entries
        ``metadata_key`` and ``input_data_repr``.
        """
        return dict(metadata_key=self._key, input_data_repr=self._data_repr)

76 

77 

class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    Every configured selector is evaluated against the input data and the
    results are multiplied together into one mask. All configured keys are
    copied to the output, and the keys listed in ``vectorKeys`` are then
    indexed with that mask.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the keys this action reads.

        Yields
        ------
        schema : `tuple`
            ``(name, type)`` pairs: first one for each distinct key in
            ``keysToLoad`` and ``vectorKeys`` (in sorted order), then
            everything reported by each selector's own input schema.
        """
        yield from (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in sorted(set(self.keysToLoad).union(self.vectorKeys))
        )
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return ``(name, type)`` pairs for each distinct key in
        ``keysToLoad`` and ``vectorKeys``, in sorted order.
        """
        return (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in sorted(set(self.keysToLoad).union(self.vectorKeys))
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys and apply the selectors.

        Parameters
        ----------
        data : `KeyedData`
            The input data; it is indexed with each configured key after the
            key has been formatted with ``kwargs``.
        **kwargs
            Passed to every selector and used to format the configured keys.

        Returns
        -------
        result : `KeyedData`
            Mapping of formatted key to value. When at least one selector is
            configured, the entries for ``vectorKeys`` are indexed with the
            combined selector mask.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            if mask is None:
                mask = subMask
            else:
                # Build a new object instead of using ``*=``. The running
                # mask starts out as the very object returned by the first
                # selector, which may alias a column of ``data`` or state
                # held by the selector; an in-place update would silently
                # modify it.
                mask = mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in sorted(set(self.keysToLoad).union(self.vectorKeys)):
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            # Convert to ordered set to avoid formatting duplicate keys
            # multiple times
            formattedKeys = list({key.format_map(kwargs): None for key in self.vectorKeys}.keys())
            for tempFormat in formattedKeys:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                result[tempFormat] = cast(Vector, result[tempFormat])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Merge the names from a schema into the configured keys.

        Parameters
        ----------
        inputSchema : `KeyedDataSchema`
            Iterable of ``(name, type)`` pairs. Every name is added to
            ``keysToLoad``; names whose type compares equal to `Vector` are
            also added to ``vectorKeys``. Both fields are stored sorted and
            without duplicates.
        """
        existing = list(self.keysToLoad)
        existingVectors = list(self.vectorKeys)
        for name, typ in inputSchema:
            existing.append(name)
            if typ == Vector:
                existingVectors.append(name)
        self.keysToLoad = sorted(set(existing))
        self.vectorKeys = sorted(set(existingVectors))

135 

136 

class BaseProcess(KeyedDataAction):
    """Base class for actions which process data.

    Processing runs in three stages -- build, filter, then calculate. The
    filter stage sees the input data plus the build results, and the
    calculate stage sees the input data plus the build and filter results.
    """

    buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which compute a Vector which will be added to results"
    )
    filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which filter one or more input or build Vectors into shorter vectors"
    )
    calculateActions = ConfigurableActionStructField[AnalysisAction](
        doc="Actions which compute quantities from the input or built data"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Report the keys that have to be supplied by the caller.

        Keys that an earlier stage declares as output are not requested
        again by a later stage.

        Returns
        -------
        schema : `KeyedDataSchema`
            Iterable of ``(name, type)`` pairs.
        """
        required: KeyedDataTypes = {}  # type: ignore
        builtKeys: KeyedDataTypes = {}  # type: ignore
        filteredKeys: KeyedDataTypes = {}  # type: ignore
        action: AnalysisAction

        # Build stage: every input is required from the caller.
        for fieldName, action in self.buildActions.items():
            required.update(action.getInputSchema())
            if not isinstance(action, KeyedDataAction):
                builtKeys[fieldName] = Vector
            else:
                builtKeys.update(action.getOutputSchema() or {})

        # Filter stage: skip anything the build stage already provides.
        for fieldName, action in self.filterActions.items():
            required.update(
                (name, typ) for name, typ in action.getInputSchema() if name not in builtKeys
            )
            if not isinstance(action, KeyedDataAction):
                filteredKeys[fieldName] = Vector
            else:
                filteredKeys.update(action.getOutputSchema() or {})

        # Calculate stage: skip anything either earlier stage provides.
        for calcAction in self.calculateActions:
            for name, typ in calcAction.getInputSchema():
                if not (name in builtKeys or name in filteredKeys):
                    required[name] = typ
        return ((name, typ) for name, typ in required.items())

    def getOutputSchema(self) -> KeyedDataSchema:
        """Yield the output schema entries of every build action that is a
        `KeyedDataAction` and reports a schema.
        """
        for action in self.buildActions:
            if not isinstance(action, KeyedDataAction):
                continue
            if (schema := action.getOutputSchema()) is not None:
                yield from schema

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the build, filter and calculate stages in that order.

        Parameters
        ----------
        data : `KeyedData`
            The input data; a shallow `dict` copy is taken before use.
        **kwargs
            Passed through to every action.

        Returns
        -------
        results : `KeyedData`
            Everything produced by the three stages. A mapping returned by
            an action is merged key by key; any other return value is stored
            under the action's field name.
        """
        collected = {}
        data = dict(data)

        def runStage(actions, view) -> None:
            # Invoke each action of one stage on ``view`` and merge what it
            # returns into ``collected``.
            for name, action in actions.items():
                output = action(view, **kwargs)
                if isinstance(output, abc.Mapping):
                    collected.update(output.items())
                else:
                    collected[name] = output

        # Build actions only ever see the input data itself.
        runStage(self.buildActions, data)
        # Later stages get a snapshot taken once, before the stage starts.
        runStage(self.filterActions, data | collected)
        runStage(self.calculateActions, data | collected)
        return collected

215 

216 

def _newNameChecker(value: str) -> bool:
    """Check that a remapped metric name is acceptable.

    Parameters
    ----------
    value : `str`
        The candidate name.

    Returns
    -------
    valid : `bool`
        `False` if the name contains a ``-`` character, `True` otherwise.
    """
    isValid = "-" not in value
    if not isValid:
        # Yes this should be a log here, as pex config provides no other
        # useful way to get info into the exception.
        _LOG.error("Remapped metric names must not have a - character in them.")
    return isValid

224 

225 

class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key configured in ``units`` is looked up in the input data and its
    value, multiplied by the configured unit, is wrapped in a `Measurement`.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc=(
            "Mapping of key to new name if needed prior to creating metric, "
            "cannot contain a minus character in the name."
        ),
        default={},
        itemCheck=_newNameChecker,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return one ``(key, Scalar)`` pair per key configured in ``units``."""
        # Something is wrong with the typing for DictField key iteration
        return [(key, Scalar) for key in self.units]  # type: ignore

    @staticmethod
    def _sanitizeMetadataName(metadata_name: str) -> str:
        """Sanitize a metadata name to allow it to be pushed to Apache Avro.

        Parameters
        ----------
        metadata_name : `str`
            The metadata name to sanitize.

        Returns
        -------
        sanitized_metadata_name : `str`
            The name with every space and ``-`` replaced by ``_``.
        """
        return metadata_name.translate(str.maketrans({" ": "_", "-": "_"}))

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Create a `Measurement` for every key configured in ``units``.

        Parameters
        ----------
        data : `KeyedData`
            The input data; each configured key is formatted with ``kwargs``
            before it is looked up.
        **kwargs
            Used to format keys and new names. The value of ``metric_tags``,
            if present, is attached to each measurement's notes.

        Returns
        -------
        results : `MetricResultType`
            Mapping of sanitized metric name to `Measurement`. A mapping
            value contributes one entry per sub-key, named
            ``<name>_<subkey>``.

        Raises
        ------
        MissingMetadataError
            Raised if a formatted key is not present in ``data``.
        ValueError
            Raised if a value, or a sub-value of a mapping, is not a
            `Scalar`.
        """
        measurements = {}
        for key, unit in self.units.items():
            sourceKey = key.format(**kwargs)
            if sourceKey not in data:
                raise MissingMetadataError(sourceKey, repr(data))
            value = data[sourceKey]

            # A configured (non-empty) new name replaces the formatted key.
            newName = self.newNames.get(key)
            metricName = newName.format(**kwargs) if newName else sourceKey
            notes = {"metric_tags": kwargs.get("metric_tags", [])}

            if isinstance(value, Mapping):
                for subKey, subValue in value.items():
                    if not isinstance(subValue, Scalar):
                        raise ValueError(f"Data for subkey {subKey} of key {key} is not a Scalar type")
                    subName = self._sanitizeMetadataName(f"{metricName}_{subKey}")
                    measurements[subName] = Measurement(subName, subValue * apu.Unit(unit), notes=notes)
                continue

            if not isinstance(value, Scalar):
                raise ValueError(f"Data for key {key} is not a Scalar type")
            metricName = self._sanitizeMetadataName(metricName)
            measurements[metricName] = Measurement(metricName, value * apu.Unit(unit), notes=notes)
        return measurements

279 

280 

class BaseProduce(JointAction):
    """Base class for actions which produce data.

    The defaults install a `BaseMetricAction` as ``metric`` and `NoPlot` as
    ``plot``.
    """

    def setDefaults(self):
        """Apply the parent defaults, then set ``metric`` and ``plot``."""
        super().setDefaults()
        self.metric = BaseMetricAction()
        # ``NoPlot`` is assigned as imported, without being called.
        self.plot = NoPlot