Coverage for python/lsst/analysis/tools/interfaces/_stages.py: 17%

130 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-15 13:33 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce") 

24 

25from collections import abc 

26from typing import Any, cast 

27 

28import astropy.units as apu 

29from healsparse import HealSparseMap 

30from lsst.pex.config import ListField 

31from lsst.pex.config.configurableActions import ConfigurableActionStructField 

32from lsst.pex.config.dictField import DictField 

33from lsst.verify import Measurement 

34 

35from ._actions import ( 

36 AnalysisAction, 

37 JointAction, 

38 KeyedDataAction, 

39 MetricAction, 

40 MetricResultType, 

41 NoPlot, 

42 Tensor, 

43 VectorAction, 

44) 

45from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector 

46 

47 

class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    The action copies every key in ``keysToLoad`` and ``vectorKeys`` out of
    the input `KeyedData` (after filling ``{}`` placeholders from the call's
    keyword arguments). When ``selectors`` are configured, their outputs are
    multiplied together into a single mask, which is then applied to every
    key listed in ``vectorKeys``.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the keys this action reads from its input.

        Yields
        ------
        schema : `tuple` [`str`, `type`]
            One entry per distinct key in ``keysToLoad`` and ``vectorKeys``
            (typed permissively), followed by every selector's own input
            schema.
        """
        yield from (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in set(self.keysToLoad).union(self.vectorKeys)
        )
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return one permissively-typed entry per distinct loaded key."""
        return (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in set(self.keysToLoad).union(self.vectorKeys)
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys and apply the combined selector mask.

        Parameters
        ----------
        data : `KeyedData`
            Input data to read from.
        **kwargs
            Values used to fill ``{}`` placeholders in the configured keys;
            also forwarded to every selector.

        Returns
        -------
        result : `KeyedData`
            A new mapping of formatted key to value. Entries from
            ``vectorKeys`` are masked when any selector is configured.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            if mask is None:
                mask = subMask
            else:
                # Build a new array instead of multiplying in place: the first
                # selector's return value may be an array owned by the caller
                # (e.g. a boolean column taken straight from ``data``), and an
                # in-place update would silently modify the input data.
                mask = mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in set(self.keysToLoad).union(self.vectorKeys):
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            for key in self.vectorKeys:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                formattedKey = key.format_map(kwargs)
                result[formattedKey] = cast(Vector, result[formattedKey])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Extend the keys to load with those named in ``inputSchema``.

        Every name is added to ``keysToLoad``; names whose type equals
        ``Vector`` are also added to ``vectorKeys``. Duplicates are removed.

        Parameters
        ----------
        inputSchema : `KeyedDataSchema`
            Iterable of ``(name, type)`` pairs to merge in.
        """
        existing = list(self.keysToLoad)
        existingVectors = list(self.vectorKeys)
        for name, typ in inputSchema:
            existing.append(name)
            if typ == Vector:
                existingVectors.append(name)
        # ``dict.fromkeys`` deduplicates while keeping first-seen order. A set
        # would make the stored config order depend on per-process string hash
        # randomisation, so the persisted config would not be reproducible.
        self.keysToLoad = list(dict.fromkeys(existing))
        self.vectorKeys = list(dict.fromkeys(existingVectors))

103 

104 

class BaseProcess(KeyedDataAction):
    """Base class for actions which process data.

    Processing runs in three stages:

    1. ``buildActions`` run against a copy of the input data.
    2. ``filterActions`` run against the input merged with the build results.
    3. ``calculateActions`` run against the input merged with the build and
       filter results.

    An action returning a mapping contributes each of its entries to the
    results; any other return value is stored under the action's field name.
    """

    buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which compute a Vector which will be added to results"
    )
    filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which filter one or more input or build Vectors into shorter vectors"
    )
    calculateActions = ConfigurableActionStructField[AnalysisAction](
        doc="Actions which compute quantities from the input or built data"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the ``(name, type)`` pairs this action needs from its caller.

        Names produced by the build stage are not reported as inputs of the
        filter or calculate stages. Names produced by the filter stage are not
        reported as inputs of the calculate stage. Build/filter actions that
        are not `KeyedDataAction` instances are recorded as producing a
        ``Vector`` under their field name.
        """
        required: KeyedDataTypes = {}  # type: ignore
        built: KeyedDataTypes = {}  # type: ignore
        filtered: KeyedDataTypes = {}  # type: ignore

        for fieldName, builder in self.buildActions.items():
            required.update((name, typ) for name, typ in builder.getInputSchema())
            if not isinstance(builder, KeyedDataAction):
                built[fieldName] = Vector
            else:
                built.update(builder.getOutputSchema() or {})

        for fieldName, filterer in self.filterActions.items():
            required.update(
                (name, typ) for name, typ in filterer.getInputSchema() if name not in built
            )
            if not isinstance(filterer, KeyedDataAction):
                filtered[fieldName] = Vector
            else:
                filtered.update(filterer.getOutputSchema() or {})

        for calculator in self.calculateActions:
            required.update(
                (name, typ)
                for name, typ in calculator.getInputSchema()
                if name not in built and name not in filtered
            )

        return (entry for entry in required.items())

    def getOutputSchema(self) -> KeyedDataSchema:
        """Yield the output schema of each keyed-data build action.

        Build actions that are not `KeyedDataAction` instances, and keyed ones
        whose output schema is ``None``, contribute nothing.
        """
        for builder in self.buildActions:
            if not isinstance(builder, KeyedDataAction):
                continue
            schema = builder.getOutputSchema()
            if schema is None:
                continue
            yield from schema

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the build, filter and calculate stages in order.

        Parameters
        ----------
        data : `KeyedData`
            Input data; it is copied into a plain `dict` before use.
        **kwargs
            Forwarded unchanged to every action.

        Returns
        -------
        results : `KeyedData`
            Only the entries produced by the three stages (the input data is
            not included).
        """
        data = dict(data)
        results = {}

        def runStage(stage, view) -> None:
            # Every action in a stage sees the same snapshot ``view``; their
            # outputs are merged into ``results`` in iteration order.
            for fieldName, stageAction in stage.items():
                produced = stageAction(view, **kwargs)
                if isinstance(produced, abc.Mapping):
                    results.update(produced.items())
                else:
                    results[fieldName] = produced

        runStage(self.buildActions, data)
        runStage(self.filterActions, data | results)
        runStage(self.calculateActions, data | results)
        return results

183 

184 

class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key in ``units`` names a scalar in the input data. That scalar is
    multiplied by the configured astropy unit and wrapped in a `Measurement`,
    optionally under a new name taken from ``newNames``.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc="Mapping of key to new name if needed prior to creating metric",
        default={},
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return one ``(key, Scalar)`` entry per key configured in ``units``."""
        # Something is wrong with the typing for DictField key iteration
        return [(key, Scalar) for key in self.units]  # type: ignore

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Build a `Measurement` for every key configured in ``units``.

        Parameters
        ----------
        data : `KeyedData`
            Input data containing the scalar values.
        **kwargs
            Values used to fill ``{}`` placeholders in the configured keys and
            new names. An optional ``metric_tags`` entry is attached to each
            measurement's notes (an empty list when absent).

        Returns
        -------
        results : `MetricResultType`
            Mapping of (possibly renamed) formatted key to `Measurement`.

        Raises
        ------
        ValueError
            Raised if a formatted key is missing from ``data`` or its value is
            not a ``Scalar``.
        """
        results = {}
        for key, unit in self.units.items():
            formattedKey = key.format(**kwargs)
            if formattedKey not in data:
                raise ValueError(f"Key: {formattedKey} could not be found in input data")
            value = data[formattedKey]
            if not isinstance(value, Scalar):
                # Report the key that was actually looked up, not the template.
                raise ValueError(f"Data for key {formattedKey} is not a Scalar type")
            # Renames are looked up by the unformatted (template) key.
            if newName := self.newNames.get(key):
                formattedKey = newName.format(**kwargs)
            notes = {"metric_tags": kwargs.get("metric_tags", [])}
            results[formattedKey] = Measurement(formattedKey, value * apu.Unit(unit), notes=notes)
        return results

212 

213 

class BaseProduce(JointAction):
    """Base class for actions which produce data.

    Defaults ``metric`` to a `BaseMetricAction` instance and ``plot`` to
    ``NoPlot``.
    """

    def setDefaults(self):
        # Apply JointAction's defaults first so the overrides below win.
        super().setDefaults()
        self.metric = BaseMetricAction()
        # NOTE(review): ``plot`` is assigned the ``NoPlot`` object itself
        # (no call), whereas ``metric`` gets an instance; presumably the
        # underlying config field accepts either form — confirm.
        self.plot = NoPlot