Coverage for python/lsst/analysis/tools/interfaces/_stages.py: 17%

125 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-01-27 10:59 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

__all__ = (
    "BasePrep",
    "BaseProcess",
    "BaseMetricAction",
    "BaseProduce",
)

24 

25from collections import abc 

26from typing import Any, cast 

27 

28import astropy.units as apu 

29from healsparse import HealSparseMap 

30from lsst.pex.config import ListField 

31from lsst.pex.config.configurableActions import ConfigurableActionStructField 

32from lsst.pex.config.dictField import DictField 

33from lsst.verify import Measurement 

34 

35from ._actions import ( 

36 AnalysisAction, 

37 JointAction, 

38 KeyedDataAction, 

39 MetricAction, 

40 MetricResultType, 

41 NoPlot, 

42 Tensor, 

43 VectorAction, 

44) 

45from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector 

46 

47 

class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    The union of ``keysToLoad`` and ``vectorKeys`` is copied from the input
    data into a new mapping. Keys may contain ``str.format`` fields, which are
    filled from the call's keyword arguments. When ``selectors`` are
    configured, their outputs are combined by elementwise multiplication (a
    logical AND for boolean masks), and the resulting mask is applied to every
    key listed in ``vectorKeys``.
    """

    keysToLoad = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    vectorKeys = ListField[str](doc="Keys from the input data which selectors will be applied", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the (unformatted) keys read by this action and by its selectors."""
        yield from (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in set(self.keysToLoad).union(self.vectorKeys)
        )
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the (unformatted) keys present in the output mapping."""
        return (
            (column, Vector | Scalar | HealSparseMap | Tensor)
            for column in set(self.keysToLoad).union(self.vectorKeys)
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys from ``data`` and apply the selectors.

        Parameters
        ----------
        data : `KeyedData`
            Input mapping to read from.
        **kwargs
            Passed to every selector and used to fill ``str.format`` fields
            in the configured key names.

        Returns
        -------
        result : `KeyedData`
            New mapping keyed by the formatted key names. Entries for
            ``vectorKeys`` are masked when at least one selector is configured.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            # Combine out of place: an in-place ``*=`` would write into the
            # array returned by the first selector, which could be an array
            # shared with ``data``.
            mask = subMask if mask is None else mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in set(self.keysToLoad).union(self.vectorKeys):
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            # Values were stored under their formatted keys, so the mask must
            # be applied through the formatted keys too. Collecting them in a
            # set ensures a key listed more than once is masked exactly once.
            for formattedKey in {key.format_map(kwargs) for key in self.vectorKeys}:
                # ignore type since there is not fully proper mypy support for
                # vector type casting. In the future there will be, and this
                # makes it clearer now what type things should be.
                result[formattedKey] = cast(Vector, result[formattedKey])[mask]  # type: ignore
        return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Append every key name in ``inputSchema`` to ``keysToLoad``.

        The types in ``inputSchema`` are ignored.
        """
        existing = list(self.keysToLoad)
        for name, _ in inputSchema:
            existing.append(name)
        self.keysToLoad = existing

98 

99 

class BaseProcess(KeyedDataAction):
    """Base class for actions which process data.

    Processing runs in three ordered stages:

    - ``buildActions`` are called on the input data;
    - ``filterActions`` are called on the input data merged with the
      build results;
    - ``calculateActions`` are called on the input data merged with the
      build and filter results.

    An action returning a mapping contributes each of its entries to the
    results. Any other return value is stored under the action's field name.
    """

    buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which compute a Vector which will be added to results"
    )
    filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction](
        doc="Actions which filter one or more input or build Vectors into shorter vectors"
    )
    calculateActions = ConfigurableActionStructField[AnalysisAction](
        doc="Actions which compute quantities from the input or built data"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys this action needs from its input data.

        Keys produced by an earlier stage are omitted: build outputs are
        omitted from the filter stage's requirements, and build or filter
        outputs are omitted from the calculate stage's requirements.
        """
        required: KeyedDataTypes = {}  # type: ignore
        built: KeyedDataTypes = {}  # type: ignore
        filtered: KeyedDataTypes = {}  # type: ignore

        for fieldName, buildAction in self.buildActions.items():
            required.update((name, typ) for name, typ in buildAction.getInputSchema())
            # Keyed actions declare their outputs; vector actions yield a
            # single Vector stored under their field name.
            if isinstance(buildAction, KeyedDataAction):
                built.update(buildAction.getOutputSchema() or {})
            else:
                built[fieldName] = Vector

        for fieldName, filterAction in self.filterActions.items():
            required.update(
                (name, typ) for name, typ in filterAction.getInputSchema() if name not in built
            )
            if isinstance(filterAction, KeyedDataAction):
                filtered.update(filterAction.getOutputSchema() or {})
            else:
                filtered[fieldName] = Vector

        for calcAction in self.calculateActions:
            required.update(
                (name, typ)
                for name, typ in calcAction.getInputSchema()
                if not (name in built or name in filtered)
            )

        return (entry for entry in required.items())

    def getOutputSchema(self) -> KeyedDataSchema:
        """Yield the output schemas declared by keyed build actions.

        NOTE(review): outputs of non-keyed build actions and of the filter and
        calculate stages are not reported here; confirm this is intended.
        """
        keyedBuilders = (act for act in self.buildActions if isinstance(act, KeyedDataAction))
        for builder in keyedBuilders:
            if (schema := builder.getOutputSchema()) is not None:
                yield from schema

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the build, filter and calculate stages in order.

        Returns
        -------
        results : `KeyedData`
            Only the values produced by the actions; the input data itself is
            not included.
        """
        results: dict[str, Any] = {}
        inputs = dict(data)

        def absorb(name: str, output: Any) -> None:
            # Mapping outputs contribute each of their entries; any other
            # output is stored under the action's field name.
            if isinstance(output, abc.Mapping):
                for key, value in output.items():
                    results[key] = value
            else:
                results[name] = output

        for name, buildAction in self.buildActions.items():
            absorb(name, buildAction(inputs, **kwargs))

        # Each later stage sees the inputs overlaid with everything produced
        # so far.
        afterBuild = inputs | results
        for name, filterAction in self.filterActions.items():
            absorb(name, filterAction(afterBuild, **kwargs))

        afterFilter = inputs | results
        for name, calcAction in self.calculateActions.items():
            absorb(name, calcAction(afterFilter, **kwargs))

        return results

178 

179 

class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key of ``units`` names a scalar in the input data. The key may
    contain ``str.format`` fields, which are filled from the call's keyword
    arguments. The scalar is multiplied by the configured astropy unit and
    wrapped in a `lsst.verify.Measurement`. The measurement is optionally
    renamed through ``newNames``.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc="Mapping of key to new name if needed prior to creating metric",
        default={},
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (unformatted) ``units`` keys, each expected to be a Scalar."""
        # Something is wrong with the typing for DictField key iteration
        return [(key, Scalar) for key in self.units]  # type: ignore

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Build one `Measurement` per configured ``units`` entry.

        Parameters
        ----------
        data : `KeyedData`
            Mapping that must contain every formatted ``units`` key.
        **kwargs
            Used to fill ``str.format`` fields in the key and rename
            templates. An optional ``metric_tags`` entry is attached to each
            measurement's notes.

        Returns
        -------
        results : `MetricResultType`
            Mapping of measurement name to `Measurement`.

        Raises
        ------
        ValueError
            Raised if a key is missing from ``data`` or its value is not a
            Scalar.
        """
        results = {}
        for key, unit in self.units.items():
            formattedKey = key.format(**kwargs)
            if formattedKey not in data:
                raise ValueError(f"Key: {formattedKey} could not be found in input data")
            value = data[formattedKey]
            if not isinstance(value, Scalar):
                # Report the key actually looked up, not the unformatted template.
                raise ValueError(f"Data for key {formattedKey} is not a Scalar type")
            # Renames are looked up by the unformatted key, then formatted.
            if newName := self.newNames.get(key):
                formattedKey = newName.format(**kwargs)
            # Built per measurement so no notes dict is shared between them.
            notes = {"metric_tags": kwargs.get("metric_tags", [])}
            results[formattedKey] = Measurement(formattedKey, value * apu.Unit(unit), notes=notes)
        return results

207 

208 

class BaseProduce(JointAction):
    """Base class for actions which produce data.

    The defaults set the metric stage to a `BaseMetricAction` and the plot
    stage to `NoPlot`.
    """

    def setDefaults(self):
        # Parent defaults are established first; the plot and metric stages
        # are then overridden independently of one another.
        super().setDefaults()
        self.plot = NoPlot
        self.metric = BaseMetricAction()