Coverage for tests/test_scatterPlot.py: 19%

100 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-11 03:43 -0800

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23import unittest 

24import lsst.utils.tests 

25 

26from lsst.analysis.drp.calcFunctors import MagDiff 

27from lsst.analysis.drp.dataSelectors import GalaxyIdentifier 

28from lsst.analysis.drp.plotUtils import get_and_remove_figure_text 

29from lsst.analysis.drp.scatterPlot import ScatterPlotWithTwoHistsTask, ScatterPlotWithTwoHistsTaskConfig 

30 

31import matplotlib 

32import matplotlib.pyplot as plt 

33import numpy as np 

34from numpy.random import default_rng 

35import os 

36import pandas as pd 

37import shutil 

38import tempfile 

39 

# Use a non-interactive backend so figures can be rendered without a display.
matplotlib.use("Agg")

# Absolute directory containing this test module.
ROOT = os.path.abspath(os.path.dirname(__file__))
_DATA_DIR = os.path.join(ROOT, "data")
# Reference file listing the text labels expected on the scatter plot.
filename_texts_ref = os.path.join(_DATA_DIR, "test_scatterPlot_texts.txt")
# Directory holding one ``line_<idx>.txt`` reference file per plotted line.
path_lines_ref = os.path.join(_DATA_DIR, "test_scatterPlot_lines")

45 

46 

class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    ``setUp`` builds a synthetic catalog of reference (``refcat_flux``) and
    noisy "measured" fluxes, and a task configured to plot the magnitude
    difference between a measured flux column and the reference flux.
    The test runs the task and compares the plotted line data and the
    figure's text labels against reference files under ``data/``.
    """

    def setUp(self):
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")
        # Register removal right away: unittest does not call tearDown when
        # setUp raises, which would otherwise leave this directory in ROOT.
        self.addCleanup(shutil.rmtree, self.testDir, ignore_errors=True)

        # Set up a quasi-plausible measurement catalog
        mag = 12.5 + 2.5*np.log10(np.arange(10, 100000))
        flux = 10**(-0.4*(mag - (mag[-1] + 1)))
        rng = default_rng(0)
        # The chance of a source being extended rises linearly with magnitude,
        # from 0 at the brightest to 0.99 at the faintest.
        extendedness = 0. + (rng.uniform(size=len(mag)) < 0.99*(mag - mag[0])/(mag[-1] - mag[0]))
        # Gaussian noise with variance flux * (1 + extendedness).
        flux_meas = flux + rng.normal(scale=np.sqrt(flux*(1 + extendedness)))
        # Keep sources whose measured flux exceeds 3x the true noise level.
        good = (flux_meas/np.sqrt(flux*(1 + extendedness))) > 3
        extendedness = extendedness[good]
        flux = flux[good]
        flux_meas = flux_meas[good]
        # Computed after the cut: unfiltered measured fluxes can be negative,
        # and taking their square root would only emit RuntimeWarnings for rows
        # that are discarded anyway. Values for the kept rows are unchanged.
        flux_err = np.sqrt(flux_meas*(1 + extendedness))
        sky_object = np.full(len(flux), False)

        # Configure the plot to show observed vs true mags
        config = ScatterPlotWithTwoHistsTaskConfig(
            axisLabels={"x": "mag", "y": "mag meas - ref", "mag": "mag"},
        )
        config.selectorActions.flagSelector.bands = ["i"]
        # col1 starts as a placeholder; the column loop below replaces it with
        # a measured flux column if the requirements include one.
        config.axisActions.yAction = MagDiff(col1="refcat_flux", col2="refcat_flux")
        config.nonBandColumnPrefixes.append("refcat")
        config.selectorActions.catSnSelector.threshold = -1e12
        config.sourceSelectorActions.galaxySelector = GalaxyIdentifier
        config.highSnStatisticSelectorActions.statSelector.threshold = 50
        config.lowSnStatisticSelectorActions.statSelector.threshold = 20

        n = len(flux)
        self.bands, columns = config.get_requirements()
        data = {
            "refcat_flux": flux,
            "patch": np.zeros(n, dtype=int),
        }

        # Assign values to columns based on their unchanged default names.
        # Columns matching none of these patterns are not added to the table.
        for column in columns:
            if column in data:
                continue
            if column.startswith("detect"):
                data[column] = np.ones(n, dtype=bool)
            elif column.endswith("_flag") or "Flag" in column:
                data[column] = np.zeros(n, dtype=bool)
            elif column.endswith("Flux"):
                config.axisActions.yAction.col1 = column
                data[column] = flux_meas
            elif column.endswith("FluxErr"):
                data[column] = flux_err
            elif column.endswith("_extendedness"):
                data[column] = extendedness
            elif column.startswith("sky_"):
                data[column] = sky_object

        self.data = pd.DataFrame(data)

        # Build the task only after the config is final (yAction.col1 may be
        # updated in the loop above), instead of mutating the config of an
        # already-constructed task.
        self.task = ScatterPlotWithTwoHistsTask(config=config)

    def tearDown(self):
        # The temporary directory is removed by the cleanup registered in setUp.
        del self.bands
        del self.data
        del self.task

    def test_ScatterPlotWithTwoHistsTask(self):
        plt.rcParams.update(plt.rcParamsDefault)
        result = self.task.run(self.data,
                               dataId={},
                               runName="test",
                               skymap=None,
                               tableName="test",
                               bands=self.bands,
                               plotName="test")

        self.assertIsInstance(result.scatterPlot, plt.Figure)
        # If the figure was created through pyplot, pyplot holds a reference
        # to it until it is closed; release it when the test finishes.
        self.addCleanup(plt.close, result.scatterPlot)

        # Set to true to save plots as PNGs
        # Use matplotlib.testing.compare.compare_images if needed
        save_images = False
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot.png"))

        texts, lines = get_and_remove_figure_text(result.scatterPlot)
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot_unlabeled.png"))

        # Set to true to re-generate reference data
        resave = False

        # Compare line values
        for idx, line in enumerate(lines):
            filename = os.path.join(path_lines_ref, f"line_{idx}.txt")
            if resave:
                np.savetxt(filename, line)
            arr = np.loadtxt(filename)
            # Differences of order 1e-12 possible between MacOS and Linux
            # Plots are generally not expected to be that precise
            # Differences to 1e-3 should not be visible with this test data
            self.assertFloatsAlmostEqual(arr, line, atol=1e-3, rtol=1e-4)

        # Ensure that newlines within labels are replaced by a sentinel
        newline = '\n'
        newline_replace = "[newline]"
        # Compare text labels. Use an explicit encoding so the reference file
        # reads the same regardless of the platform's default encoding.
        if resave:
            with open(filename_texts_ref, 'w', encoding="utf-8") as f:
                f.writelines(f'{text.strip().replace(newline, newline_replace)}\n' for text in texts)

        with open(filename_texts_ref, 'r', encoding="utf-8") as f:
            texts_ref = {x.strip() for x in f}
        texts_set = {x.strip().replace(newline, newline_replace) for x in texts}

        # Every reference label must appear on the figure; report which ones don't.
        missing = texts_ref - texts_set
        self.assertFalse(missing, msg=f"Reference texts missing from figure: {sorted(missing)}")

161 

162 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the leak checks inherited from `lsst.utils.tests.MemoryTestCase`."""

165 

166 

def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    Parameters
    ----------
    module : `module`
        The module being set up; unused, but required by the pytest
        ``setup_module`` hook signature.
    """
    lsst.utils.tests.init()

169 

170 

if __name__ == "__main__":
    # Initialize the LSST test utilities, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()