Coverage for tests/test_contexts.py: 48%

100 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-30 03:04 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23import warnings 

24from typing import cast 

25from unittest import TestCase, main 

26 

27import astropy.units as apu 

28import lsst.utils.tests 

29import numpy as np 

30from lsst.analysis.tools import ( 

31 AnalysisAction, 

32 AnalysisTool, 

33 KeyedData, 

34 KeyedDataAction, 

35 KeyedDataSchema, 

36 KeyedResults, 

37 Scalar, 

38 ScalarAction, 

39 Vector, 

40) 

41from lsst.analysis.tools.actions.scalar import MeanAction, MedianAction 

42from lsst.analysis.tools.contexts import Context 

43from lsst.pex.config import Field 

44from lsst.pex.config.configurableActions import ConfigurableActionField, ConfigurableActionStructField 

45from lsst.verify import Measurement 

46 

47 

class MedianContext(Context):
    """Test-only context associated with the median-based configuration.

    The test actions in this module expose ``medianContext`` methods that
    configure them with ``MedianAction``.
    """

52 

53 

class MeanContext(Context):
    """Test-only context associated with the mean-based configuration.

    The test actions in this module expose ``meanContext`` methods that
    configure them with ``MeanAction``.
    """

58 

59 

class MultiplyContext(Context):
    """Test-only context under which results are multiplied.

    ``TestAction3.multiplyContext`` sets its multiplier to 2.
    """

64 

65 

class DivideContext(Context):
    """Test-only context under which results are divided.

    ``TestAction3.divideContext`` sets its multiplier to 0.5.
    """

70 

71 

class TestAction1(KeyedDataAction):
    """Keyed-data action that evaluates a struct of scalar actions on "a".

    The number and kind of scalar actions depends on which context method
    has been run: three medians, or two means.
    """

    multiple = ConfigurableActionStructField[ScalarAction](doc="Multiple Actions")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the input schema: a single vector stored under "a"."""
        schema = (("a", Vector),)
        return schema

    def medianContext(self) -> None:
        """Populate ``multiple`` with three median actions reading "a"."""
        for fieldName in ("a", "b", "c"):
            setattr(self.multiple, fieldName, MedianAction(vectorKey="a"))

    def meanContext(self) -> None:
        """Populate ``multiple`` with two mean actions reading "a"."""
        for fieldName in ("a", "b"):
            setattr(self.multiple, fieldName, MeanAction(vectorKey="a"))

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run every configured action on ``data`` and collect the results.

        Parameters
        ----------
        data : `KeyedData`
            Input mapping, passed unchanged to each action.
        **kwargs
            Forwarded to each action.

        Returns
        -------
        result : `KeyedData`
            Mapping with one key, "b", holding a numpy array of the
            per-action results in iteration order.
        """
        scalars = [scalarAction(data, **kwargs) for scalarAction in self.multiple]
        stacked = np.array(scalars)
        return cast(KeyedData, {"b": stacked})

90 

91 

class TestAction2(KeyedDataAction):
    """Keyed-data action that reduces "b" to a scalar and rescales it.

    The reduction (``single``) is chosen by a context method, while the
    rescaling factor comes from a nested ``TestAction3``.
    """

    single = ConfigurableActionField[ScalarAction](doc="Single Action")
    multiplier = ConfigurableActionField[AnalysisAction](doc="Multiplier")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the input schema: a single vector stored under "b"."""
        schema = (("b", Vector),)
        return schema

    def setDefaults(self) -> None:
        """Install the nested multiplier action after the base defaults."""
        super().setDefaults()
        # The nested action is the same in every context; it is itself
        # recursively configured with whatever contexts are applied.
        self.multiplier = TestAction3()

    def medianContext(self) -> None:
        """Use the median of "b" as the reduction."""
        self.single = MedianAction(vectorKey="b")

    def meanContext(self) -> None:
        """Use the mean of "b" as the reduction."""
        self.single = MeanAction(vectorKey="b")

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Reduce ``data`` and multiply by the nested action's output value.

        Parameters
        ----------
        data : `KeyedData`
            Input mapping handed to the ``single`` action.
        **kwargs
            Forwarded to both ``single`` and ``multiplier``.

        Returns
        -------
        result : `KeyedData`
            Mapping with one key, "c", holding the rescaled scalar.
        """
        value = self.single(data, **kwargs)
        # Feed the reduced value to the nested action under the key it reads.
        nestedInput = cast(KeyedData, {"c": value})
        measurement = cast(Measurement, self.multiplier(nestedInput, **kwargs)["d"])
        quantity = measurement.quantity
        assert quantity is not None
        value *= cast(Scalar, quantity.value)
        return {"c": value}

118 

119 

class TestAction3(KeyedDataAction):
    """Action that scales "c" by a configured factor into a Measurement."""

    multiplier = Field[float](doc="Number to multiply result by")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the input schema: a single entry stored under "c"."""
        schema = (("c", Vector),)
        return schema

    def multiplyContext(self) -> None:
        """Set the factor to 2."""
        self.multiplier = 2

    def divideContext(self) -> None:
        """Set the factor to 0.5."""
        self.multiplier = 0.5

    def __call__(self, data: KeyedData, **kwargs) -> KeyedResults:
        """Build a "TestMeasurement" from ``data["c"]`` times the factor.

        Parameters
        ----------
        data : `KeyedData`
            Input mapping; only the "c" entry is read.
        **kwargs
            Accepted but not used.

        Returns
        -------
        result : `KeyedResults`
            Mapping with one key, "d", holding the `Measurement`, whose
            quantity carries the "count" unit.
        """
        scaled = cast(Scalar, data["c"]) * self.multiplier
        quantity = scaled * apu.Unit("count")
        measurement = Measurement("TestMeasurement", quantity)
        return {"d": measurement}

135 

136 

class TestAnalysisTool(AnalysisTool):
    """Analysis tool wiring the three test actions into its stages."""

    def setDefaults(self) -> None:
        """Assign the test actions to ``prep``, ``process`` and ``produce``."""
        # prep emits "b", process turns "b" into "c", produce turns "c"
        # into the "d" Measurement.
        self.prep = TestAction1()
        self.process = TestAction2()
        self.produce = TestAction3()

142 

143 

class ContextTestCase(TestCase):
    """Check that applying contexts reconfigures a tool and its actions."""

    def setUp(self) -> None:
        """Build the shared input: the integers 0..19 stored under "a"."""
        super().setUp()
        self.array = np.arange(20)
        self.input = cast(KeyedData, {"a": self.array})

    def _measure(self, tool: TestAnalysisTool) -> Measurement:
        """Run ``tool`` on the shared input with warnings suppressed.

        Returns the `Measurement` stored under "d" in the tool's output.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            measurement = cast(Measurement, tool(self.input)["d"])
        return measurement

    def testContext1(self):
        """Apply two contexts one after the other: multiply, then median."""
        tester = TestAnalysisTool()
        # First context: applied through the normal call syntax.
        tester.applyContext(MultiplyContext())

        # Second context: applied through assignment, which must also work
        # so that contexts can be set from Yaml. In python this would
        # normally be written as a function call.
        tester.applyContext = MedianContext
        result = self._measure(tester)
        assert result.quantity is not None
        self.assertEqual(result.quantity.value, 361)

    def testContext2(self):
        """Apply a compound context built with ``|``: mean and divide."""
        tester2 = TestAnalysisTool()
        compound = MeanContext | DivideContext
        tester2.applyContext(compound)
        result = self._measure(tester2)
        assert result.quantity is not None
        self.assertEqual(result.quantity.value, 22.5625)

173 

174 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Include the checks inherited from ``lsst.utils.tests.MemoryTestCase``."""

177 

178 

def setup_module(module):
    """Initialise ``lsst.utils.tests``; ``module`` itself is not used."""
    lsst.utils.tests.init()

181 

182 

if __name__ == "__main__":
    # Direct execution: initialise lsst.utils.tests, then hand over to the
    # unittest runner.
    lsst.utils.tests.init()
    main()