Coverage for tests/test_propertyMapPlot.py: 16%

144 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-29 11:33 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import os 

22import unittest 

23 

24import healsparse as hsp 

25import lsst.utils.tests 

26import matplotlib 

27import matplotlib.pyplot as plt 

28import numpy as np 

29import skyproj 

30from lsst.analysis.tools.atools.propertyMap import PropertyMapTool 

31from lsst.analysis.tools.tasks.propertyMapTractAnalysis import ( 

32 PropertyMapConfig, 

33 PropertyMapTractAnalysisConfig, 

34 PropertyMapTractAnalysisTask, 

35) 

36from lsst.daf.butler import Butler, DataCoordinate, DatasetType, DeferredDatasetHandle 

37from lsst.daf.butler.tests.utils import makeTestTempDir, removeTestTempDir 

38from lsst.skymap.discreteSkyMap import DiscreteSkyMap 

39from mpl_toolkits import axisartist 

40 

# No display needed: use the non-interactive Agg backend so figures render
# off-screen and their pixel buffers can be inspected by the tests below.
matplotlib.use("Agg")

# Directory where this file is located.
ROOT = os.path.abspath(os.path.dirname(__file__))

46 

47 

class PropertyMapTractAnalysisTaskTestCase(lsst.utils.tests.TestCase):
    """PropertyMapTractAnalysisTask test case.

    Notes
    -----
    While definitive tests are conducted in `ci_hsc` and `ci_imsim` using real
    and simulated datasets to ensure thorough coverage, this test case is
    designed to catch foundational issues like syntax errors or logical
    inconsistencies in the way the plots are generated.
    """

    def setUp(self):
        """Build a throwaway butler repository holding mock HealSparse
        property maps, and prepare the tool, config, tract info and plot
        metadata used by the test.
        """
        # Create a temporary directory to test in. Its removal is registered
        # as a cleanup immediately: if any later step of setUp raises,
        # tearDown is not run, but registered cleanups still are, so the
        # directory does not leak.
        self.testDir = makeTestTempDir(ROOT)
        self.addCleanup(removeTestTempDir, self.testDir)

        # Create a butler in the test directory.
        Butler.makeRepo(self.testDir)
        butler = Butler(self.testDir, run="testrun")

        # Make a dummy dataId.
        dataId = {"band": "i", "skymap": "hsc_rings_v1", "tract": 1915}
        dataId = DataCoordinate.standardize(dataId, universe=butler.dimensions)

        # Configure the maps to be plotted.
        config = PropertyMapTractAnalysisConfig()
        config.zoomFactors = [3, 6]

        # Set configurations for the first property.
        config.properties["prop1"] = PropertyMapConfig
        config.properties["prop1"].coaddName = "deep"
        config.properties["prop1"].operations = ["weighted_mean", "sum"]
        config.properties["prop1"].nBinsHist = 100

        # Set configurations for the second property.
        config.properties["prop2"] = PropertyMapConfig
        config.properties["prop2"].coaddName = "goodSeeing"
        config.properties["prop2"].operations = ["min", "max", "mean"]
        config.properties["prop2"].nBinsHist = 40

        # Generate dataset type names from the config and populate the
        # propertyNameLookup dictionary.
        names = []
        self.propertyNameLookup = {}
        for propertyName, propConfig in config.properties.items():
            coaddName = propConfig.coaddName
            for operationName in propConfig.operations:
                name = f"{coaddName}Coadd_{propertyName}_map_{operationName}"
                names.append(name)
                # The keys in propertyNameLookup are derived by removing "_map"
                # from the datasettype name and appending "_PropertyMapPlot".
                key = f"{name.replace('_map', '')}_PropertyMapPlot"
                self.propertyNameLookup[key] = propertyName

        # Mock up corresponding HealSparseMaps and register them with the
        # butler. Multiple maps allow us to check that we can generate multiple
        # plots using a single tool.
        mapsDict = {}
        for name, value in zip(names, np.linspace(1, 10, len(names))):
            hspMap = hsp.HealSparseMap.make_empty(nside_coverage=32, nside_sparse=4096, dtype=np.float32)
            hspMap[0:10000] = value
            hspMap[100000:110000] = value + 1
            hspMap[500000:510000] = value + 2
            datasetType = DatasetType(name, [], "HealSparseMap", universe=butler.dimensions)
            butler.registry.registerDatasetType(datasetType)
            dataRef = butler.put(hspMap, datasetType)
            # Keys in mapsDict are designed to reflect the task's connection
            # names, which are akin to datasettype names minus "_map".
            mapsDict[name.replace("_map", "")] = DeferredDatasetHandle(
                butler=butler, ref=dataRef, parameters=None
            )

        # Mock up the skymap and tractInfo.
        skyMapConfig = DiscreteSkyMap.ConfigClass()
        coords = [  # From the PS1 Medium-Deep fields.
            (10.6750, 41.2667),  # M31
            (36.2074, -4.5833),  # XMM-LSS
        ]
        skyMapConfig.raList = [c[0] for c in coords]
        skyMapConfig.decList = [c[1] for c in coords]
        skyMapConfig.radiusList = [2] * len(coords)
        skyMapConfig.validate()
        skymap = DiscreteSkyMap(config=skyMapConfig)
        self.tractInfo = skymap.generateTract(0)

        # Initialize the task and set class attributes for subsequent use.
        task = PropertyMapTractAnalysisTask()
        self.plotConfig = config
        self.plotInfo = task.parsePlotInfo(mapsDict, dataId, list(mapsDict.keys()))
        self.data = {"maps": mapsDict}
        self.atool = PropertyMapTool()
        self.atool.produce.plot.plotName = "test"
        self.atool.finalize()

    def tearDown(self):
        """Drop references to the fixtures created in setUp.

        The temporary directory itself is removed by the cleanup registered in
        setUp, which runs after this method.
        """
        del self.propertyNameLookup
        del self.atool
        del self.data
        del self.tractInfo
        del self.plotConfig
        del self.plotInfo
        del self.testDir

    def test_PropertyMapTractAnalysisTask(self):
        """Run the tool on the mock maps and validate every returned figure
        for type, axis structure and overall RGB content.
        """
        plt.rcParams.update(plt.rcParamsDefault)
        result = self.atool(
            data=self.data,
            tractInfo=self.tractInfo,
            plotConfig=self.plotConfig,
            plotInfo=self.plotInfo,
        )

        # Previously computed reference RGB fractions for the plots. They are
        # paired with the figures positionally, so this relies on the tool
        # returning figures in a stable order (presumably the order of the
        # maps in mapsDict; confirm if the tool changes).
        expectedRGBFractions = [
            (0.3195957485702613, 0.3214485564746734, 0.3205870925245099),
            (0.3196810334967318, 0.3215274948937909, 0.32067869434232027),
            (0.31861285794526134, 0.3206199550653596, 0.3196505213439543),
            (0.3185880596405229, 0.3206015032679739, 0.319619406147876),
            (0.31843414266748354, 0.32044758629493475, 0.3194654891748366),
        ]

        # `zip` below stops at the shorter sequence, so without this check a
        # missing or extra figure (or an empty result) would go unnoticed.
        self.assertEqual(
            len(result),
            len(expectedRGBFractions),
            msg=f"Expected {len(expectedRGBFractions)} figures but the tool returned {len(result)}.",
        )

        zoomFactors = self.plotConfig.zoomFactors

        # Unpack the figures from the dictionary and run some checks.
        for (name, fig), expectedRGBFraction in zip(result.items(), expectedRGBFractions):
            propertyName = self.propertyNameLookup[name]
            binsCount = self.plotConfig.properties[propertyName].nBinsHist
            xlabel = propertyName.title().replace("Psf", "PSF")

            # Check that the object is a matplotlib figure.
            self.assertIsInstance(fig, plt.Figure, msg=f"Figure {name} is not a matplotlib figure.")
            # Release the figure from pyplot's registry when the test ends
            # (a no-op if pyplot is not tracking it).
            self.addCleanup(plt.close, fig)

            # Validate the structure of the figure.
            self._validateFigureStructure(fig, binsCount, xlabel, zoomFactors)

            # Validate the RGB fractions of the figure. The tolerance is set
            # empirically.
            self._validateRGBFractions(fig, expectedRGBFraction, rtol=5e-3)

    @staticmethod
    def _isHistogramAxis(ax, binsCount, legendLabels, errors):
        """Checks if a given axis is a histogram axis based on specified
        parameters.

        Despite the predicate-like name, nothing is returned; every failed
        check is reported by appending a message to ``errors``.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axis to be checked.
        binsCount : `int`
            The expected number of bins in the histogram.
        legendLabels : `List` [`str`]
            The expected labels in the histogram legend (order-insensitive).
        errors : `List` [`str`]
            A list to append any errors found during the checks.

        Returns
        -------
        None
            Errors are appended to the provided `errors` list.
        """
        # Count rectangle and polygon patches.
        nRectanglePatches = sum(1 for patch in ax.patches if isinstance(patch, matplotlib.patches.Rectangle))
        nPolygonPatches = sum(1 for patch in ax.patches if isinstance(patch, matplotlib.patches.Polygon))

        # Check for the number of rectangle patches for the filled histogram.
        if nRectanglePatches != binsCount:
            errors.append(
                f"Expected {binsCount} rectangle patches in histogram, but found {nRectanglePatches}."
            )

        # Check for the number of polygon patches, i.e. the step histograms.
        if nPolygonPatches != 2:
            errors.append(f"Expected 2 polygon patches, but found {nPolygonPatches}.")

        # Check for `fill_between` regions, represented by `PolyCollection`
        # objects.
        if len(ax.collections) != 2:
            errors.append(f"Expected 2 `fill_between` regions but found {len(ax.collections)}.")

        # Verify legend labels.
        legend = ax.get_legend()
        if not legend:
            errors.append("Legend is missing in the histogram.")
        else:
            labels = [text.get_text() for text in legend.get_texts()]
            if set(labels) != set(legendLabels):
                errors.append(f"Expected legend labels {legendLabels} but found {labels}.")

    def _validateFigureStructure(self, fig, binsCount, xlabel, zoomFactors):
        """Validates the structure of a given matplotlib figure generated by
        the tool that is being tested.

        All checks are run before failing, so a single failure message lists
        every unmet criterion.

        Parameters
        ----------
        fig : `~matplotlib.figure.Figure`
            The figure to be validated. Its first axis is expected to be the
            histogram.
        binsCount : `int`
            The expected number of bins in the histogram.
        xlabel : `str`
            The expected x-axis label of the histogram.
        zoomFactors : `List` [`float`]
            A list of zoom factors used for the zoomed-in plots.

        Raises
        ------
        AssertionError
            If any of the criteria for figure structure is not met. The error
            message will list all criteria that were not satisfied.
        """
        errors = []
        axes = fig.get_axes()

        # Check the total number of each axis type.
        totalSkyAxes = sum(isinstance(ax, skyproj.skyaxes.SkyAxes) for ax in axes)
        totalAxisArtistAxes = sum(isinstance(ax, axisartist.axislines.Axes) for ax in axes)
        totalColorbarAxes = sum(isinstance(ax, plt.Axes) and ax.get_label() == "<colorbar>" for ax in axes)

        if totalSkyAxes != 3:
            errors.append(f"Expected 3 SkyAxes but got {totalSkyAxes}.")
        if totalAxisArtistAxes != 3:
            errors.append(f"Expected 3 AxisArtist Axes but got {totalAxisArtistAxes}.")
        if totalColorbarAxes != 3:
            errors.append(f"Expected 3 colorbar Axes but got {totalColorbarAxes}.")

        # Check histogram axis: one legend entry for the full tract plus one
        # per zoom level.
        histogramAxis = axes[0]
        expectedLegendLabels = ["Full Tract"] + [
            f"{self.atool.produce.plot.prettyPrintFloat(factor)}x Zoom" for factor in zoomFactors
        ]
        self._isHistogramAxis(histogramAxis, binsCount, expectedLegendLabels, errors)

        # Verify x and y labels for histogram.
        if histogramAxis.get_xlabel() != xlabel:
            errors.append(f"Expected x-label '{xlabel}' for histogram but found '{histogramAxis.get_xlabel()}'.")
        if histogramAxis.get_ylabel() != "Normalized Count":
            errors.append(
                f"Expected y-label 'Normalized Count' for histogram but found '{histogramAxis.get_ylabel()}'."
            )

        if errors:
            self.fail("\n" + "\n".join(errors))

    def _validateRGBFractions(self, fig, RGBFraction, rtol=1e-7):
        """Checks if a matplotlib figure has specified fractions of R, G, and B
        colors.

        Each channel's "fraction" is the sum of that channel's values (scaled
        to [0, 1]) divided by the total number of RGB elements, i.e.
        height * width * 3. A pure-white image therefore yields 1/3 per
        channel. The reference values used by the test are calibrated to this
        definition.

        Parameters
        ----------
        fig : `~matplotlib.figure.Figure`
            The figure to check.
        RGBFraction : `tuple`
            Tuple containing the desired fractions for red, green, and blue in
            the image, respectively. Must have exactly three entries.
        rtol : `float`, optional
            The tolerance allowed for the fractions. Default is 1e-7. Note
            that, despite its name, it is applied as an absolute bound on
            ``|actual - expected|``.

        Raises
        ------
        AssertionError
            If the actual fractions of the RGB colors in the image do not match
            the expected fractions within the given tolerance.
        """
        # Unpack the desired fractions. This also enforces exactly three
        # entries.
        rFraction, gFraction, bFraction = RGBFraction

        # Draw the figure so the renderer can grab the pixel buffer.
        fig.canvas.draw()

        # Convert figure to data array, dropping the alpha channel.
        data = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3] / 255.0

        # Compare every channel and collect all mismatches before failing.
        errors = []
        for channel, (colorName, expected) in enumerate(
            (("red", rFraction), ("green", gFraction), ("blue", bFraction))
        ):
            actual = np.sum(data[:, :, channel]) / data.size
            # Written as `not (... <= ...)` so that a NaN difference also
            # counts as a mismatch.
            if not np.abs(actual - expected) <= rtol:
                errors.append(
                    f"Calculated {colorName} fraction {actual} does not match {expected} within rtol {rtol}."
                )

        if errors:
            self.fail("\n" + "\n".join(errors))

345 

346 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pick up the checks defined by `lsst.utils.tests.MemoryTestCase`
    unchanged; no additional tests are defined here.
    """

349 

350 

def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` naming convention).

    Initializes the LSST test utilities; ``module`` is not used.
    """
    initTestUtilities = lsst.utils.tests.init
    initTestUtilities()

353 

354 

if __name__ == "__main__":
    # Initialize the LSST test utilities, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()