Coverage for tests/test_scatterPlot.py: 20%
102 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-06 13:57 +0000
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
23import unittest
24import lsst.utils.tests
26from lsst.analysis.drp.calcFunctors import MagDiff
27from lsst.analysis.drp.dataSelectors import GalaxyIdentifier
28from lsst.analysis.drp.plotUtils import get_and_remove_figure_text
29from lsst.analysis.drp.scatterPlot import ScatterPlotWithTwoHistsTask, ScatterPlotWithTwoHistsTaskConfig
30from lsst.daf.butler import DataCoordinate, DimensionUniverse
32import matplotlib
33import matplotlib.pyplot as plt
34import numpy as np
35from numpy.random import default_rng
36import os
37import pandas as pd
38import shutil
39import tempfile
# Render figures off-screen with the non-interactive Agg backend.
matplotlib.use("Agg")

# Reference data lives in the ``data`` directory next to this test module.
ROOT = os.path.abspath(os.path.dirname(__file__))
_DATA_DIR = os.path.join(ROOT, "data")
filename_texts_ref = os.path.join(_DATA_DIR, "test_scatterPlot_texts.txt")
path_lines_ref = os.path.join(_DATA_DIR, "test_scatterPlot_lines")
class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    Runs the task on a synthetic catalog and compares the lines and text
    labels of the resulting figure against the reference files at
    ``path_lines_ref`` and ``filename_texts_ref``.
    """

    def setUp(self):
        """Create the task and a synthetic input catalog for it."""
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")

        # Set up a quasi-plausible measurement catalog
        mag = 12.5 + 2.5*np.log10(np.arange(10, 100000))
        flux = 10**(-0.4*(mag - (mag[-1] + 1)))
        rng = default_rng(0)
        # 0./1. float flag; the chance of a source being extended rises
        # linearly from 0 at the brightest mag to 0.99 at the faintest.
        extendedness = 0. + (rng.uniform(size=len(mag)) < 0.99*(mag - mag[0])/(mag[-1] - mag[0]))
        flux_meas = flux + rng.normal(scale=np.sqrt(flux*(1 + extendedness)))
        # Keep sources whose measured flux exceeds 3x the true-flux noise.
        good = (flux_meas/np.sqrt(flux * (1 + extendedness))) > 3
        extendedness = extendedness[good]
        flux = flux[good]
        flux_meas = flux_meas[good]
        # Errors are computed after the cut: rejected measured fluxes can be
        # negative, and their square root would emit invalid-value warnings.
        # The retained values are the same as computing them before the cut.
        flux_err = np.sqrt(flux_meas * (1 + extendedness))
        sky_object = np.full(len(flux), False)

        # Configure the plot to show observed vs true mags
        config = ScatterPlotWithTwoHistsTaskConfig(
            axisLabels={"x": "mag", "y": "mag meas - ref", "mag": "mag"},
        )
        config.selectorActions.flagSelector.bands = ["i"]
        # col1 is a placeholder; it is pointed at the measured flux column below.
        config.axisActions.yAction = MagDiff(col1="refcat_flux", col2="refcat_flux")
        config.nonBandColumnPrefixes.append("refcat")
        config.selectorActions.catSnSelector.threshold = -1e12
        config.sourceSelectorActions.galaxySelector = GalaxyIdentifier
        config.highSnStatisticSelectorActions.statSelector.threshold = 50
        config.lowSnStatisticSelectorActions.statSelector.threshold = 20
        self.task = ScatterPlotWithTwoHistsTask(config=config)

        n = len(flux)
        self.bands, columns = config.get_requirements()
        data = {
            "refcat_flux": flux,
            "patch": np.zeros(n, dtype=int),
        }

        # Assign values to columns based on their unchanged default names.
        # Columns matching none of these patterns are left out of the table.
        for column in columns:
            if column not in data:
                if column.startswith("detect"):
                    data[column] = np.ones(n, dtype=bool)
                elif column.endswith("_flag") or "Flag" in column:
                    data[column] = np.zeros(n, dtype=bool)
                elif column.endswith("Flux"):
                    # NOTE(review): the task was already built from ``config``
                    # above; this relies on it keeping a reference to the same
                    # config object rather than a copy -- confirm.
                    config.axisActions.yAction.col1 = column
                    data[column] = flux_meas
                elif column.endswith("FluxErr"):
                    data[column] = flux_err
                elif column.endswith("_extendedness"):
                    data[column] = extendedness
                elif column.startswith("sky_"):
                    data[column] = sky_object

        self.data = pd.DataFrame(data)

    def tearDown(self):
        """Remove the temporary output directory and drop per-test state."""
        if os.path.exists(self.testDir):
            shutil.rmtree(self.testDir, ignore_errors=True)
        del self.bands
        del self.data
        del self.task

    def test_ScatterPlotWithTwoHistsTask(self):
        """Run the task and compare the figure's lines and labels to references."""
        plt.rcParams.update(plt.rcParamsDefault)
        universe = DimensionUniverse()
        result = self.task.run(self.data,
                               dataId=DataCoordinate.make_empty(universe),
                               runName="test",
                               skymap=None,
                               tableName="test",
                               bands=self.bands,
                               plotName="test")

        self.assertIsInstance(result.scatterPlot, plt.Figure)
        # Release the figure once the test finishes so figures do not
        # accumulate in pyplot's figure registry across test runs.
        self.addCleanup(plt.close, result.scatterPlot)

        # Set to true to save plots as PNGs
        # Use matplotlib.testing.compare.compare_images if needed
        save_images = False
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot.png"))

        texts, lines = get_and_remove_figure_text(result.scatterPlot)
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot_unlabeled.png"))

        # Set to true to re-generate reference data
        resave = False
        if resave:
            os.makedirs(path_lines_ref, exist_ok=True)

        # Compare line values
        for idx, line in enumerate(lines):
            filename = os.path.join(path_lines_ref, f"line_{idx}.txt")
            if resave:
                np.savetxt(filename, line)
            arr = np.loadtxt(filename)
            # Differences of order 1e-12 possible between MacOS and Linux
            # Plots are generally not expected to be that precise
            # Differences to 1e-3 should not be visible with this test data
            self.assertFloatsAlmostEqual(arr, line, atol=1e-3, rtol=1e-4)

        # Ensure that newlines within labels are replaced by a sentinel
        newline = '\n'
        newline_replace = "[newline]"
        # Compare text labels. Use an explicit encoding: labels may contain
        # non-ASCII characters and must not depend on the platform locale.
        if resave:
            with open(filename_texts_ref, 'w', encoding="utf-8") as f:
                f.writelines(f'{text.strip().replace(newline, newline_replace)}\n' for text in texts)

        with open(filename_texts_ref, 'r', encoding="utf-8") as f:
            texts_ref = {x.strip() for x in f}
        texts_set = {x.strip().replace(newline, newline_replace) for x in texts}

        # Every reference label must still be present; on failure the
        # assertion reports exactly which labels are missing.
        self.assertEqual(texts_ref - texts_set, set())
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Reuse the checks provided by ``lsst.utils.tests.MemoryTestCase`` for this module."""
def setup_module(module):
    """Initialize ``lsst.utils.tests`` before this module's tests run.

    pytest calls this xunit-style hook once per module.

    Parameters
    ----------
    module : `module`
        The test module; unused here but required by the hook signature.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When run directly, initialize the LSST test helpers, then defer to unittest.
    lsst.utils.tests.init()
    unittest.main()