Coverage for python/lsst/analysis/tools/interfaces/_stages.py: 17%

137 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 04:38 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce") 

24 

25import logging 

26from collections import abc 

27from typing import Any, cast 

28 

29import astropy.units as apu 

30from healsparse import HealSparseMap 

31from lsst.pex.config import ListField 

32from lsst.pex.config.configurableActions import ConfigurableActionStructField 

33from lsst.pex.config.dictField import DictField 

34from lsst.verify import Measurement 

35 

36from ._actions import ( 

37 AnalysisAction, 

38 JointAction, 

39 KeyedDataAction, 

40 MetricAction, 

41 MetricResultType, 

42 NoPlot, 

43 Tensor, 

44 VectorAction, 

45) 

46from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector 

47 

48_LOG = logging.getLogger(__name__) 

49 

50 

class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    Every key listed in ``keysToLoad`` or ``vectorKeys`` is copied out of the
    input data, after its ``str.format`` placeholders are filled from the
    call's keyword arguments. When ``selectors`` are configured, their masks
    are multiplied together and applied to each ``vectorKeys`` entry.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def _allKeys(self) -> list[str]:
        """Return ``keysToLoad`` followed by ``vectorKeys`` without duplicates.

        First-seen order is kept so that schemas and outputs are ordered
        reproducibly. Iterating a plain ``set`` of strings yields an order
        that depends on the interpreter's hash seed.
        """
        return list(dict.fromkeys([*self.keysToLoad, *self.vectorKeys]))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the (unformatted) keys this action reads, then the inputs
        required by each configured selector."""
        yield from ((column, Vector | Scalar | HealSparseMap | Tensor) for column in self._allKeys())
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the (unformatted) keys this action emits."""
        return ((column, Vector | Scalar | HealSparseMap | Tensor) for column in self._allKeys())

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys from ``data``, applying selectors.

        Parameters
        ----------
        data : `KeyedData`
            Input data to read the configured keys from.
        **kwargs
            Values used to fill ``str.format`` placeholders in the keys; also
            forwarded to every selector.

        Returns
        -------
        result : `KeyedData`
            New mapping from formatted key to value. ``vectorKeys`` entries
            are indexed by the combined selector mask when any selector is
            configured.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            if mask is None:
                mask = subMask
            else:
                # Combine out of place. A selector may return an array that
                # is (or views) a column of ``data``; an in-place ``*=`` would
                # then silently overwrite the caller's input, and it fails
                # outright on read-only arrays.
                mask = mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in self._allKeys():
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            for key in self.vectorKeys:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                tempFormat = key.format_map(kwargs)
                result[tempFormat] = cast(Vector, result[tempFormat])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Add the names in ``inputSchema`` to the keys this action loads.

        Every name is appended to ``keysToLoad``. Names whose type compares
        equal to `Vector` are also appended to ``vectorKeys``, so selectors
        are applied to them.
        """
        existing = list(self.keysToLoad)
        existingVectors = list(self.vectorKeys)
        for name, typ in inputSchema:
            existing.append(name)
            if typ == Vector:
                existingVectors.append(name)
        # Deduplicate while keeping order. Assigning a ``set`` here would
        # make the stored config order depend on the string hash seed and
        # therefore differ between otherwise identical runs.
        self.keysToLoad = list(dict.fromkeys(existing))
        self.vectorKeys = list(dict.fromkeys(existingVectors))

106 

107 

class BaseProcess(KeyedDataAction):
    """Base class for actions which process data.

    Processing runs in three stages:

    - ``buildActions`` are called with a copy of the input data.
    - ``filterActions`` are called with the input data overlaid with
      everything produced by the build stage.
    - ``calculateActions`` are called with the input data overlaid with the
      build and filter outputs.

    A stage action that returns a mapping contributes each of its items to
    the result. Any other return value is stored under the action's field
    name.
    """

    buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which compute a Vector which will be added to results"
    )
    filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which filter one or more input or build Vectors into shorter vectors"
    )
    calculateActions = ConfigurableActionStructField[AnalysisAction](
        doc="Actions which compute quantities from the input or built data"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys this action needs from its caller.

        Inputs of filter actions that a build action declares as output are
        omitted. Inputs of calculate actions that a build or filter action
        declares as output are also omitted.
        """
        required: KeyedDataTypes = {}  # type: ignore
        built: KeyedDataTypes = {}  # type: ignore
        filtered: KeyedDataTypes = {}  # type: ignore

        def recordOutputs(label: str, stageAction: AnalysisAction, sink: KeyedDataTypes) -> None:
            # A KeyedDataAction reports its own outputs; any other action is
            # treated as producing a single Vector under its field name.
            if isinstance(stageAction, KeyedDataAction):
                sink.update(stageAction.getOutputSchema() or {})
            else:
                sink[label] = Vector

        for label, stageAction in self.buildActions.items():
            required.update({name: typ for name, typ in stageAction.getInputSchema()})
            recordOutputs(label, stageAction, built)

        for label, stageAction in self.filterActions.items():
            required.update(
                {name: typ for name, typ in stageAction.getInputSchema() if name not in built}
            )
            recordOutputs(label, stageAction, filtered)

        for stageAction in self.calculateActions:
            required.update(
                {
                    name: typ
                    for name, typ in stageAction.getInputSchema()
                    if not (name in built or name in filtered)
                }
            )

        return ((name, typ) for name, typ in required.items())

    def getOutputSchema(self) -> KeyedDataSchema:
        """Yield the output schema declared by each keyed build action.

        Only build actions that are `KeyedDataAction` instances (and return a
        non-None schema) contribute.
        """
        keyedBuilders = (act for act in self.buildActions if isinstance(act, KeyedDataAction))
        for builder in keyedBuilders:
            if (schema := builder.getOutputSchema()) is None:
                continue
            yield from schema

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the build, filter and calculate stages in order.

        Parameters
        ----------
        data : `KeyedData`
            Input data. It is copied into a plain `dict` before use.
        **kwargs
            Forwarded unchanged to every stage action.

        Returns
        -------
        results : `KeyedData`
            Everything produced by the three stages. The input data itself is
            not included.
        """
        merged = {}
        snapshot = dict(data)

        def absorb(label, produced):
            # Mapping outputs are merged key by key; anything else is stored
            # under the producing action's field name.
            if isinstance(produced, abc.Mapping):
                for key, value in produced.items():
                    merged[key] = value
            else:
                merged[label] = produced

        for label, builder in self.buildActions.items():
            absorb(label, builder(snapshot, **kwargs))

        afterBuild = snapshot | merged
        for label, filterer in self.filterActions.items():
            absorb(label, filterer(afterBuild, **kwargs))

        afterFilter = snapshot | merged
        for label, calculator in self.calculateActions.items():
            absorb(label, calculator(afterFilter, **kwargs))

        return merged

186 

187 

def _newNameChecker(value: str) -> bool:
    """Validate a remapped metric name for ``BaseMetricAction.newNames``.

    Parameters
    ----------
    value : `str`
        Candidate metric name.

    Returns
    -------
    valid : `bool`
        `True` when ``value`` contains no ``-`` character, `False` otherwise.
    """
    if "-" not in value:
        return True
    # pex_config offers no way to attach a reason to a failed item check, so
    # the explanation is emitted through the module logger instead.
    _LOG.error("Remapped metric names must not have a - character in them.")
    return False

195 

196 

class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each entry of ``units`` maps a key in the input data to an astropy unit
    string. The key may contain ``str.format`` placeholders filled from the
    call's keyword arguments. ``newNames`` optionally renames an entry, keyed
    by the same unformatted key, before its measurement is created.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc=(
            "Mapping of key to new name if needed prior to creating metric, "
            "cannot contain a minus character in the name."
        ),
        default={},
        itemCheck=_newNameChecker,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return each (unformatted) key of ``units`` paired with `Scalar`."""
        # Something is wrong with the typing for DictField key iteration
        return [(key, Scalar) for key in self.units]  # type: ignore

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Create one `~lsst.verify.Measurement` per configured unit entry.

        Parameters
        ----------
        data : `KeyedData`
            Data holding the scalars named by ``units``.
        **kwargs
            Values used to fill placeholders in keys and new names. An
            optional ``metric_tags`` entry is stored in each measurement's
            notes (an empty list when absent).

        Returns
        -------
        results : `MetricResultType`
            Mapping from (possibly renamed) formatted key to measurement.

        Raises
        ------
        ValueError
            Raised if a formatted key is missing from ``data`` or its value
            is not a `Scalar`.
        """
        results = {}
        for key, unit in self.units.items():
            formattedKey = key.format(**kwargs)
            if formattedKey not in data:
                raise ValueError(f"Key: {formattedKey} could not be found in input data")
            value = data[formattedKey]
            if not isinstance(value, Scalar):
                # Report the key that was actually looked up, not the
                # unformatted template.
                raise ValueError(f"Data for key {formattedKey} is not a Scalar type")
            # Renames are keyed by the unformatted key; the new name is then
            # formatted with the same kwargs.
            if newName := self.newNames.get(key):
                formattedKey = newName.format(**kwargs)
            notes = {"metric_tags": kwargs.get("metric_tags", [])}
            results[formattedKey] = Measurement(formattedKey, value * apu.Unit(unit), notes=notes)
        return results

228 

229 

class BaseProduce(JointAction):
    """Base class for actions which produce data.

    By default the metric stage is a fresh `BaseMetricAction`, and the plot
    stage is set to `NoPlot` (presumably a placeholder meaning "no plot";
    confirm against ``_actions``).
    """

    def setDefaults(self):
        super().setDefaults()
        # The two defaults are independent of each other.
        self.plot = NoPlot
        self.metric = BaseMetricAction()