Coverage for tests/test_scatterPlot.py: 20%

102 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 12:28 +0000

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23import unittest 

24import lsst.utils.tests 

25 

26from lsst.analysis.drp.calcFunctors import MagDiff 

27from lsst.analysis.drp.dataSelectors import GalaxyIdentifier 

28from lsst.analysis.drp.plotUtils import get_and_remove_figure_text 

29from lsst.analysis.drp.scatterPlot import ScatterPlotWithTwoHistsTask, ScatterPlotWithTwoHistsTaskConfig 

30from lsst.daf.butler import DataCoordinate, DimensionUniverse 

31 

32import matplotlib 

33import matplotlib.pyplot as plt 

34import numpy as np 

35from numpy.random import default_rng 

36import os 

37import pandas as pd 

38import shutil 

39import tempfile 

40 

# Use the non-interactive Agg backend so that figures can be created
# without a display.
matplotlib.use("Agg")

# Directory containing this test module; reference data lives in its
# "data" subdirectory.
ROOT = os.path.abspath(os.path.dirname(__file__))
_data_dir = os.path.join(ROOT, "data")
# Text file holding the reference figure labels, one per line.
filename_texts_ref = os.path.join(_data_dir, "test_scatterPlot_texts.txt")
# Directory holding the reference line data files (line_<idx>.txt).
path_lines_ref = os.path.join(_data_dir, "test_scatterPlot_lines")

46 

47 

class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    A synthetic catalog is built in `setUp`, the task is run on it, and the
    lines and text labels extracted from the resulting figure are compared
    with reference data stored under ``data/`` next to this module.
    """

    def setUp(self):
        """Build the synthetic catalog and the configured task."""
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")

        # Set up a quasi-plausible measurement catalog
        mag = 12.5 + 2.5*np.log10(np.arange(10, 100000))
        flux = 10**(-0.4*(mag - (mag[-1] + 1)))
        # Fixed seed keeps the catalog (and hence the reference data)
        # reproducible.
        rng = default_rng(0)
        # Probability of being flagged extended grows linearly with mag.
        extendedness = 0. + (rng.uniform(size=len(mag)) < 0.99*(mag - mag[0])/(mag[-1] - mag[0]))
        flux_meas = flux + rng.normal(scale=np.sqrt(flux*(1 + extendedness)))

        # Keep only sources measured at more than three times the noise
        # scale used to perturb them.
        good = (flux_meas/np.sqrt(flux * (1 + extendedness))) > 3
        extendedness = extendedness[good]
        flux = flux[good]
        flux_meas = flux_meas[good]
        # Derive the errors after the cut: every retained flux is positive,
        # so the square root never sees a negative value (which would yield
        # NaN and a RuntimeWarning). The retained values are unchanged.
        flux_err = np.sqrt(flux_meas * (1 + extendedness))
        sky_object = np.full(len(flux), False)

        # Configure the plot to show observed vs true mags
        config = ScatterPlotWithTwoHistsTaskConfig(
            axisLabels={"x": "mag", "y": "mag meas - ref", "mag": "mag"},
        )
        config.selectorActions.flagSelector.bands = ["i"]
        config.axisActions.yAction = MagDiff(col1="refcat_flux", col2="refcat_flux")
        config.nonBandColumnPrefixes.append("refcat")
        config.selectorActions.catSnSelector.threshold = -1e12
        config.sourceSelectorActions.galaxySelector = GalaxyIdentifier
        config.highSnStatisticSelectorActions.statSelector.threshold = 50
        config.lowSnStatisticSelectorActions.statSelector.threshold = 20
        self.task = ScatterPlotWithTwoHistsTask(config=config)

        n = len(flux)
        self.bands, columns = config.get_requirements()
        data = {
            "refcat_flux": flux,
            "patch": np.zeros(n, dtype=int),
        }

        # Assign values to columns based on their unchanged default names.
        # Branch order matters: flag columns are matched before flux columns.
        for column in columns:
            if column in data:
                continue
            if column.startswith("detect"):
                data[column] = np.ones(n, dtype=bool)
            elif column.endswith("_flag") or "Flag" in column:
                data[column] = np.zeros(n, dtype=bool)
            elif column.endswith("Flux"):
                # Point the y-axis action at the measured flux column.
                config.axisActions.yAction.col1 = column
                data[column] = flux_meas
            elif column.endswith("FluxErr"):
                data[column] = flux_err
            elif column.endswith("_extendedness"):
                data[column] = extendedness
            elif column.startswith("sky_"):
                data[column] = sky_object

        self.data = pd.DataFrame(data)

    def tearDown(self):
        """Remove the temporary directory and drop the fixtures."""
        if os.path.exists(self.testDir):
            shutil.rmtree(self.testDir, True)
        del self.bands
        del self.data
        del self.task

    def test_ScatterPlotWithTwoHistsTask(self):
        """Run the task and compare figure lines and labels to references."""
        plt.rcParams.update(plt.rcParamsDefault)
        universe = DimensionUniverse()
        result = self.task.run(self.data,
                               dataId=DataCoordinate.make_empty(universe),
                               runName="test",
                               skymap=None,
                               tableName="test",
                               bands=self.bands,
                               plotName="test")

        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(result.scatterPlot, plt.Figure)

        # Set to true to save plots as PNGs
        # Use matplotlib.testing.compare.compare_images if needed
        save_images = False
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot.png"))

        texts, lines = get_and_remove_figure_text(result.scatterPlot)
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot_unlabeled.png"))

        # Set to true to re-generate reference data
        resave = False

        # Compare line values
        for idx, line in enumerate(lines):
            filename = os.path.join(path_lines_ref, f"line_{idx}.txt")
            if resave:
                np.savetxt(filename, line)
            arr = np.loadtxt(filename)
            # Differences of order 1e-12 possible between MacOS and Linux
            # Plots are generally not expected to be that precise
            # Differences to 1e-3 should not be visible with this test data
            self.assertFloatsAlmostEqual(arr, line, atol=1e-3, rtol=1e-4)

        # Ensure that newlines within labels are replaced by a sentinel
        newline = '\n'
        newline_replace = "[newline]"
        # Compare text labels
        if resave:
            with open(filename_texts_ref, 'w') as f:
                f.writelines(f'{text.strip().replace(newline, newline_replace)}\n' for text in texts)

        with open(filename_texts_ref, 'r') as f:
            texts_ref = {x.strip() for x in f}
        texts_set = {x.strip().replace(newline, newline_replace) for x in texts}

        # Every reference label must be present in the figure; extra labels
        # in the figure are allowed. Report the missing ones on failure.
        missing = texts_ref - texts_set
        self.assertFalse(missing, msg=f"Reference labels missing from the figure: {sorted(missing)}")

163 

164 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`."""

167 

168 

def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module : module
        The test module being set up; unused.
    """
    lsst.utils.tests.init()

171 

172 

# Allow the tests to be run directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()