Coverage for tests / test_scatterPlot.py: 25%

102 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 09:19 +0000

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23import os 

24import shutil 

25import tempfile 

26import unittest 

27 

28import matplotlib 

29import matplotlib.pyplot as plt 

30import numpy as np 

31import pandas as pd 

32 

33import lsst.utils.tests 

34from lsst.analysis.tools.actions.plot.plotUtils import get_and_remove_figure_text 

35from lsst.analysis.tools.actions.plot.scatterplotWithTwoHists import ( 

36 ScatterPlotStatsAction, 

37 ScatterPlotWithTwoHists, 

38) 

39from lsst.analysis.tools.actions.vector.mathActions import ConstantValue, DivideVector, SubtractVector 

40from lsst.analysis.tools.actions.vector.selectors import ( 

41 GalaxySelector, 

42 SnSelector, 

43 StarSelector, 

44 VectorSelector, 

45) 

46from lsst.analysis.tools.actions.vector.vectorActions import ConvertFluxToMag, DownselectVector, LoadVector 

47from lsst.analysis.tools.interfaces import AnalysisTool 

48from lsst.analysis.tools.math import sqrt 

49 

# Select the non-interactive Agg backend so the plots can be rendered headless.
matplotlib.use("Agg")

# Directory holding this test file; reference outputs live in its "data" subdirectory.
ROOT = os.path.abspath(os.path.dirname(__file__))
_DATA_DIR = os.path.join(ROOT, "data")
# Reference text labels, one per line (newlines within a label stored as "[newline]").
filename_texts_ref = os.path.join(_DATA_DIR, "test_scatterPlot_texts.txt")
# Directory of reference line data, one "line_<idx>.txt" file per plotted line.
path_lines_ref = os.path.join(_DATA_DIR, "test_scatterPlot_lines")

55 

56 

class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    Builds a synthetic reference/measurement catalog, runs an ``AnalysisTool``
    whose plot action is ``ScatterPlotWithTwoHists`` on it, and compares the
    resulting figure's line data and text labels against reference files in
    the ``data`` directory next to this file.
    """

    # Column name of the measured flux; its error column is f"{_KEY_FLUX}Err".
    _KEY_FLUX = "meas_Flux"
    # Suffixes appended to the x/y/stat keys, to check that suffixed outputs work
    # (i.e. that multiple plots could be produced from one tool).
    _SUFFIX_X = "_x"
    _SUFFIX_Y = "_y"
    _SUFFIX_STAT = "_stat"

    def setUp(self):
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")
        self.data = self._make_catalog()
        self.plot = self._make_tool()
        self.plot.finalize()
        plotInfo = {key: "test" for key in ("plotName", "run", "tableName")}
        plotInfo["bands"] = []
        self.plotInfo = plotInfo

    @classmethod
    def _make_catalog(cls):
        """Return a quasi-plausible synthetic measurement catalog.

        Returns
        -------
        data : `pandas.DataFrame`
            Columns ``ref_Flux``, ``meas_Flux``, ``meas_FluxErr`` and
            ``refExtendedness``, restricted to rows whose measured flux exceeds
            3 times the noise sigma implied by the true flux.
        """
        key_flux = cls._KEY_FLUX
        mag = 12.5 + 2.5 * np.log10(np.arange(10, 100000))
        flux = 10 ** (-0.4 * (mag - (mag[-1] + 1)))
        # Fixed seed so the plot (and hence the reference data) is reproducible.
        rng = np.random.default_rng(0)
        # Probability of being extended grows linearly from 0 at the brightest
        # magnitude to 0.99 at the faintest.
        extendedness = 0.0 + (rng.uniform(size=len(mag)) < 0.99 * (mag - mag[0]) / (mag[-1] - mag[0]))
        # Extended objects get twice the noise variance of point sources.
        flux_meas = flux + rng.normal(scale=np.sqrt(flux * (1 + extendedness)))
        flux_err = sqrt(flux_meas * (1 + extendedness))
        good = (flux_meas / sqrt(flux * (1 + extendedness))) > 3
        return pd.DataFrame(
            {
                "ref_Flux": flux[good],
                key_flux: flux_meas[good],
                f"{key_flux}Err": flux_err[good],
                "refExtendedness": extendedness[good],
            }
        )

    @classmethod
    def _make_tool(cls):
        """Return an (unfinalized) ``AnalysisTool`` that plots the catalog.

        The x axis is the reference magnitude. The y axis is the measured minus
        reference magnitude in millimags: reference mags are divided by 1e-3 to
        match ``returnMillimags=True`` on the measured side. Values are split
        into galaxies and stars using ``refExtendedness``, with low/high S/N
        summary statistics computed for each.
        """
        key_flux = cls._KEY_FLUX
        suffix_x, suffix_y, suffix_stat = cls._SUFFIX_X, cls._SUFFIX_Y, cls._SUFFIX_STAT

        # Configure the plot to show observed vs true mags
        action = ScatterPlotWithTwoHists(
            xAxisLabel="mag",
            yAxisLabel="mag meas - ref",
            magLabel="mag",
            plotTypes=[
                "galaxies",
                "stars",
            ],
            xLims=(20, 30),
            yLims=(-1000, 1000),
            addSummaryPlot=False,
            # Make sure adding a suffix works to produce multiple plots
            suffix_x=suffix_x,
            suffix_y=suffix_y,
            suffix_stat=suffix_stat,
        )
        plot = AnalysisTool()
        plot.produce.plot = action

        # Load the relevant columns
        build = plot.process.buildActions
        build.fluxes_meas = LoadVector(vectorKey=key_flux)
        build.fluxes_err = LoadVector(vectorKey=f"{key_flux}Err")
        build.fluxes_ref = LoadVector(vectorKey="ref_Flux")
        build.mags_ref = ConvertFluxToMag(vectorKey=build.fluxes_ref.vectorKey)

        # Compute the y-axis quantity
        build.diff = SubtractVector(
            actionA=ConvertFluxToMag(vectorKey=build.fluxes_meas.vectorKey, returnMillimags=True),
            actionB=DivideVector(
                actionA=build.mags_ref,
                actionB=ConstantValue(value=1e-3),
            ),
        )

        # Filter stars/galaxies, storing quantities separately
        build.galaxySelector = GalaxySelector(vectorKey="refExtendedness")
        build.starSelector = StarSelector(vectorKey="refExtendedness")
        filters = plot.process.filterActions
        for singular, plural in (("galaxy", "Galaxies"), ("star", "Stars")):
            selector_key = f"{singular}Selector"
            setattr(
                filters,
                f"x{plural}{suffix_x}",
                DownselectVector(vectorKey="mags_ref", selector=VectorSelector(vectorKey=selector_key)),
            )
            setattr(
                filters,
                f"y{plural}{suffix_y}",
                DownselectVector(vectorKey="diff", selector=VectorSelector(vectorKey=selector_key)),
            )
            setattr(
                filters,
                f"flux{plural}",
                DownselectVector(vectorKey="fluxes_meas", selector=VectorSelector(vectorKey=selector_key)),
            )
            setattr(
                filters,
                f"fluxErr{plural}",
                DownselectVector(vectorKey="fluxes_err", selector=VectorSelector(vectorKey=selector_key)),
            )

            # Compute low/high SN summary stats
            statAction = ScatterPlotStatsAction(
                vectorKey=f"y{plural}{suffix_y}",
                fluxType=f"flux{plural}",
                highSNSelector=SnSelector(fluxType=f"flux{plural}", threshold=50),
                lowSNSelector=SnSelector(fluxType=f"flux{plural}", threshold=20),
                suffix=suffix_stat,
            )
            setattr(plot.process.calculateActions, plural.lower(), statAction)

        return plot

    def tearDown(self):
        if os.path.exists(self.testDir):
            shutil.rmtree(self.testDir, True)
        del self.data
        del self.plot
        del self.plotInfo
        del self.testDir

    def test_ScatterPlotWithTwoHistsTask(self):
        plt.rcParams.update(plt.rcParamsDefault)
        result = self.plot(
            data=self.data,
            skymap=None,
            plotInfo=self.plotInfo,
        )
        # The tool returns a dict keyed by the plot action's class name.
        result = result[type(self.plot.produce.plot).__name__]
        self.assertIsInstance(result, plt.Figure)
        # Release the figure even if a later assertion fails, so figures do not
        # accumulate in pyplot's registry (a no-op if pyplot does not manage it).
        self.addCleanup(plt.close, result)

        # Set to true to save plots as PNGs
        # Use matplotlib.testing.compare.compare_images if needed
        save_images = False
        if save_images:
            result.savefig(os.path.join(ROOT, "data", "test_scatterPlot.png"))

        texts, lines = get_and_remove_figure_text(result)
        if save_images:
            result.savefig(os.path.join(ROOT, "data", "test_scatterPlot_unlabeled.png"))

        # Set to true to re-generate reference data
        resave = False

        # Compare line values
        for idx, line in enumerate(lines):
            filename = os.path.join(path_lines_ref, f"line_{idx}.txt")
            if resave:
                np.savetxt(filename, line)
            arr = np.loadtxt(filename)
            # Differences of order 1e-12 possible between MacOS and Linux
            # Plots are generally not expected to be that precise
            # Differences to 1e-3 should not be visible with this test data
            self.assertFloatsAlmostEqual(arr, line, atol=1e-3, rtol=1e-4)

        # Newlines within labels are replaced by a sentinel so that each label
        # occupies exactly one line of the reference file.
        newline = "\n"
        newline_replace = "[newline]"
        # Compare text labels
        if resave:
            with open(filename_texts_ref, "w", encoding="utf-8") as f:
                f.writelines(f"{text.strip().replace(newline, newline_replace)}\n" for text in texts)

        with open(filename_texts_ref, encoding="utf-8") as f:
            texts_ref = {line.strip() for line in f}
        texts_set = {text.strip().replace(newline, newline_replace) for text in texts}

        self.assertEqual(texts_ref, texts_set)

231 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull in the checks inherited from ``lsst.utils.tests.MemoryTestCase``; adds no tests of its own."""

234 

235 

def setup_module(module):
    """Run ``lsst.utils.tests.init()`` once before this module's tests.

    pytest xunit-style module setup hook.

    Parameters
    ----------
    module : `module`
        The module being set up (not used).
    """
    lsst.utils.tests.init()

238 

239 

if __name__ == "__main__":
    # Direct execution: perform the same init as setup_module, then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()