Coverage for python/lsst/analysis/tools/interfaces/_stages.py: 17%

119 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-10 10:36 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BasePrep", "BaseProcess", "BaseMetricAction", "BaseProduce") 

24 

25from collections import abc 

26from typing import Any, cast 

27 

28import astropy.units as apu 

29from lsst.pex.config import ListField 

30from lsst.pex.config.configurableActions import ConfigurableActionStructField 

31from lsst.pex.config.dictField import DictField 

32from lsst.verify import Measurement 

33 

34from ._actions import ( 

35 AnalysisAction, 

36 JointAction, 

37 KeyedDataAction, 

38 MetricAction, 

39 MetricResultType, 

40 NoPlot, 

41 VectorAction, 

42) 

43from ._interfaces import KeyedData, KeyedDataSchema, KeyedDataTypes, Scalar, Vector 

44 

45 

class BasePrep(KeyedDataAction):
    """Base class for actions which prepare data for processing.

    Calling the action looks up every configured ``vectorKeys`` entry in the
    input data and, when at least one selector is configured, indexes each
    extracted column with the product of all selector results.
    """

    vectorKeys = ListField[str](doc="Keys to extract from KeyedData and return", default=[])

    selectors = ConfigurableActionStructField[VectorAction](
        doc="Selectors for selecting rows, will be AND together",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the ``(name, type)`` pairs this action reads.

        The configured ``vectorKeys`` are yielded first, followed by the
        input schema of each configured selector.
        """
        yield from ((column, Vector | Scalar) for column in self.vectorKeys)  # type: ignore
        for action in self.selectors:
            yield from action.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return one ``(name, type)`` pair per configured ``vectorKeys``."""
        return ((column, Vector | Scalar) for column in self.vectorKeys)  # type: ignore

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Extract the configured keys, applying the combined selector mask.

        Parameters
        ----------
        data : `KeyedData`
            Input data; passed unchanged to every selector.
        **kwargs
            Forwarded to each selector and used to format each entry of
            ``vectorKeys`` before it is looked up in ``data``.

        Returns
        -------
        result : `KeyedData`
            Mapping of formatted key to the extracted column, indexed by the
            combined mask when any selector is configured.
        """
        mask: Vector | None = None
        for selector in self.selectors:
            subMask = selector(data, **kwargs)
            # Multiply out of place: an in-place ``*=`` would write into the
            # object returned by the first selector, which this action does
            # not own and which may alias data held elsewhere.
            mask = subMask if mask is None else mask * subMask  # type: ignore
        result: dict[str, Any] = {}
        for key in self.vectorKeys:
            formattedKey = key.format_map(kwargs)
            result[formattedKey] = cast(Vector, data[formattedKey])
        if mask is not None:
            return {key: cast(Vector, col)[mask] for key, col in result.items()}
        else:
            return result

    def addInputSchema(self, inputSchema: KeyedDataSchema) -> None:
        """Replace ``vectorKeys`` with the names found in ``inputSchema``.

        Parameters
        ----------
        inputSchema : `KeyedDataSchema`
            Iterable of ``(name, type)`` pairs; only the names are kept.
        """
        self.vectorKeys = [name for name, _ in inputSchema]

82 

83 

84class BaseProcess(KeyedDataAction): 

85 """Base class for actions which process data.""" 

86 

87 buildActions = ConfigurableActionStructField[VectorAction | KeyedDataAction]( 

88 doc="Actions which compute a Vector which will be added to results" 

89 ) 

90 filterActions = ConfigurableActionStructField[VectorAction | KeyedDataAction]( 

91 doc="Actions which filter one or more input or build Vectors into shorter vectors" 

92 ) 

93 calculateActions = ConfigurableActionStructField[AnalysisAction]( 

94 doc="Actions which compute quantities from the input or built data" 

95 ) 

96 

97 def getInputSchema(self) -> KeyedDataSchema: 

98 inputSchema: KeyedDataTypes = {} # type: ignore 

99 buildOutputSchema: KeyedDataTypes = {} # type: ignore 

100 filterOutputSchema: KeyedDataTypes = {} # type: ignore 

101 action: AnalysisAction 

102 

103 for fieldName, action in self.buildActions.items(): 

104 for name, typ in action.getInputSchema(): 

105 inputSchema[name] = typ 

106 if isinstance(action, KeyedDataAction): 

107 buildOutputSchema.update(action.getOutputSchema() or {}) 

108 else: 

109 buildOutputSchema[fieldName] = Vector 

110 

111 for fieldName, action in self.filterActions.items(): 

112 for name, typ in action.getInputSchema(): 

113 if name not in buildOutputSchema: 

114 inputSchema[name] = typ 

115 if isinstance(action, KeyedDataAction): 

116 filterOutputSchema.update(action.getOutputSchema() or {}) 

117 else: 

118 filterOutputSchema[fieldName] = Vector 

119 

120 for calcAction in self.calculateActions: 

121 for name, typ in calcAction.getInputSchema(): 

122 if name not in buildOutputSchema and name not in filterOutputSchema: 

123 inputSchema[name] = typ 

124 return ((name, typ) for name, typ in inputSchema.items()) 

125 

126 def getOutputSchema(self) -> KeyedDataSchema: 

127 for action in self.buildActions: 

128 if isinstance(action, KeyedDataAction): 

129 outSchema = action.getOutputSchema() 

130 if outSchema is not None: 

131 yield from outSchema 

132 

133 def __call__(self, data: KeyedData, **kwargs) -> KeyedData: 

134 action: AnalysisAction 

135 results = {} 

136 data = dict(data) 

137 for name, action in self.buildActions.items(): 

138 match action(data, **kwargs): 

139 case abc.Mapping() as item: 

140 for key, result in item.items(): 

141 results[key] = result 

142 case item: 

143 results[name] = item 

144 view1 = data | results 

145 for name, action in self.filterActions.items(): 

146 match action(view1, **kwargs): 

147 case abc.Mapping() as item: 

148 for key, result in item.items(): 

149 results[key] = result 

150 case item: 

151 results[name] = item 

152 

153 view2 = data | results 

154 for name, calcAction in self.calculateActions.items(): 

155 match calcAction(view2, **kwargs): 

156 case abc.Mapping() as item: 

157 for key, result in item.items(): 

158 results[key] = result 

159 case item: 

160 results[name] = item 

161 return results 

162 

163 

class BaseMetricAction(MetricAction):
    """Base class for actions which compute metrics.

    Each key configured in ``units`` is read from the input data and wrapped
    in a `~lsst.verify.Measurement` carrying the configured unit.
    """

    units = DictField[str, str](doc="Mapping of scalar key to astropy unit string", default={})
    newNames = DictField[str, str](
        doc="Mapping of key to new name if needed prior to creating metric",
        default={},
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a ``(key, Scalar)`` pair for every key in ``units``."""
        # Something is wrong with the typing for DictField key iteration
        return [(key, Scalar) for key in self.units]  # type: ignore

    def __call__(self, data: KeyedData, **kwargs) -> MetricResultType:
        """Create one measurement per configured unit key.

        Parameters
        ----------
        data : `KeyedData`
            Input data holding a scalar for every formatted ``units`` key.
        **kwargs
            Used to format the keys and the replacement names. The value of
            ``metric_tags``, if given, is attached to each measurement's
            notes; otherwise an empty list is used.

        Returns
        -------
        results : `MetricResultType`
            Mapping of metric name to `~lsst.verify.Measurement`. The name is
            the formatted key, or the formatted ``newNames`` entry when one
            is configured for that key.

        Raises
        ------
        ValueError
            Raised if a formatted key is missing from ``data`` or its value
            is not a `Scalar`.
        """
        results = {}
        for key, unit in self.units.items():
            formattedKey = key.format(**kwargs)
            if formattedKey not in data:
                raise ValueError(f"Key: {formattedKey} could not be found in input data")
            value = data[formattedKey]
            if not isinstance(value, Scalar):
                raise ValueError(f"Data for key {key} is not a Scalar type")
            # From here on formattedKey is the name the metric is published
            # under, which may differ from the key used for the lookup above.
            if newName := self.newNames.get(key):
                formattedKey = newName.format(**kwargs)
            notes = {"metric_tags": kwargs.get("metric_tags", [])}
            results[formattedKey] = Measurement(formattedKey, value * apu.Unit(unit), notes=notes)
        return results

191 

192 

class BaseProduce(JointAction):
    """Base class for actions which produce data."""

    def setDefaults(self):
        """Apply the parent defaults, then set ``metric`` and ``plot``.

        ``metric`` is given a new `BaseMetricAction` instance; ``plot`` is
        assigned the `NoPlot` class itself rather than an instance.
        """
        super().setDefaults()
        self.metric = BaseMetricAction()
        self.plot = NoPlot

199 self.plot = NoPlot