Coverage for tests/test_scatterPlot.py: 22%
97 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-28 11:26 +0000
# This file is part of analysis_drp.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
23import unittest
24import lsst.utils.tests
26from lsst.analysis.drp.calcFunctors import MagDiff
27from lsst.analysis.drp.dataSelectors import GalaxyIdentifier
28from lsst.analysis.drp.plotUtils import get_and_remove_figure_text
29from lsst.analysis.drp.scatterPlot import ScatterPlotWithTwoHistsTask, ScatterPlotWithTwoHistsTaskConfig
31import matplotlib
32import matplotlib.pyplot as plt
33import numpy as np
34from numpy.random import default_rng
35import os
36import pandas as pd
37import shutil
38import tempfile
# Render with matplotlib's "Agg" backend (non-interactive, draws to files only).
matplotlib.use("Agg")

# Directory containing this test module; reference data lives in its "data" subdirectory.
ROOT = os.path.dirname(os.path.abspath(__file__))
# Reference file holding one expected figure text label per line.
filename_texts_ref = os.path.join(
    ROOT, "data", "test_scatterPlot_texts.txt"
)
# Reference directory holding one line_<idx>.txt file per plotted line.
path_lines_ref = os.path.join(
    ROOT, "data", "test_scatterPlot_lines"
)
class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    Builds a synthetic catalog, runs the task on it, and compares the
    plotted line data and figure text labels against stored reference files.
    """

    def setUp(self):
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")
        # Register removal immediately: if anything below raises, tearDown is
        # not called, but cleanups still run, so the directory never leaks.
        self.addCleanup(shutil.rmtree, self.testDir, ignore_errors=True)

        # Set up a quasi-plausible measurement catalog
        mag = 12.5 + 2.5*np.log10(np.arange(10, 100000))
        flux = 10**(-0.4*(mag - (mag[-1] + 1)))
        rng = default_rng(0)
        # Probability of being extended (1.) rather than point-like (0.) grows with mag.
        extendedness = 0. + (rng.uniform(size=len(mag)) < 0.99*(mag - mag[0])/(mag[-1] - mag[0]))
        flux_meas = flux + rng.normal(scale=np.sqrt(flux*(1 + extendedness)))
        # Keep only sources detected at > 3 sigma relative to the true-flux noise.
        good = (flux_meas/np.sqrt(flux*(1 + extendedness))) > 3
        extendedness = extendedness[good]
        flux = flux[good]
        flux_meas = flux_meas[good]
        # Errors are computed after the cut: rejected sources can have
        # negative flux_meas, whose sqrt would emit RuntimeWarnings. All
        # retained flux_meas are positive, so the values are unchanged.
        flux_err = np.sqrt(flux_meas*(1 + extendedness))

        # Configure the plot to show observed vs true mags
        config = ScatterPlotWithTwoHistsTaskConfig(
            axisLabels={"x": "mag", "y": "mag meas - ref", "mag": "mag"},
        )
        config.selectorActions.flagSelector.bands = ["i"]
        config.axisActions.yAction = MagDiff(col1="refcat_flux", col2="refcat_flux")
        config.nonBandColumnPrefixes.append("refcat")
        config.sourceSelectorActions.galaxySelector = GalaxyIdentifier
        config.highSnStatisticSelectorActions.statSelector.threshold = 50
        config.lowSnStatisticSelectorActions.statSelector.threshold = 20

        n = len(flux)
        self.bands, columns = config.get_requirements()
        data = {
            "refcat_flux": flux,
            "patch": np.zeros(n, dtype=int),
        }

        # Assign values to columns based on their unchanged default names
        for column in columns:
            if column in data:
                continue
            if column.startswith("detect"):
                data[column] = np.ones(n, dtype=bool)
            elif column.endswith("_flag") or "Flag" in column:
                data[column] = np.zeros(n, dtype=bool)
            elif column.endswith("Flux"):
                # Point the y-axis action at the measured flux column so the
                # y axis shows measured minus reference magnitude.
                config.axisActions.yAction.col1 = column
                data[column] = flux_meas
            elif column.endswith("FluxErr"):
                data[column] = flux_err
            elif column.endswith("_extendedness"):
                data[column] = extendedness
            else:
                raise RuntimeError(f"Unexpected column {column} in ScatterPlotWithTwoHistsTaskConfig")

        self.data = pd.DataFrame(data)
        # Build the task only once the config is complete (yAction.col1 is
        # set in the loop above), rather than mutating it after construction.
        self.task = ScatterPlotWithTwoHistsTask(config=config)

    def tearDown(self):
        # self.testDir is removed by the cleanup registered in setUp.
        del self.bands
        del self.data
        del self.task

    def test_ScatterPlotWithTwoHistsTask(self):
        plt.rcParams.update(plt.rcParamsDefault)
        result = self.task.run(self.data,
                               dataId={},
                               runName="test",
                               skymap=None,
                               tableName="test",
                               bands=self.bands,
                               plotName="test")

        self.assertIsInstance(result.scatterPlot, plt.Figure)
        # Release the figure even if a later assertion fails.
        self.addCleanup(plt.close, result.scatterPlot)

        # Set to true to save plots as PNGs
        # Use matplotlib.testing.compare.compare_images if needed
        save_images = False
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot.png"))

        texts, lines = get_and_remove_figure_text(result.scatterPlot)
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot_unlabeled.png"))

        # Set to true to re-generate reference data
        resave = False

        # Compare line values
        for idx, line in enumerate(lines):
            filename = os.path.join(path_lines_ref, f"line_{idx}.txt")
            if resave:
                np.savetxt(filename, line)
            arr = np.loadtxt(filename)
            # Differences of order 1e-12 possible between MacOS and Linux
            # Plots are generally not expected to be that precise
            # Differences to 1e-3 should not be visible with this test data
            self.assertFloatsAlmostEqual(arr, line, atol=1e-3, rtol=1e-4)

        # Ensure that newlines within labels are replaced by a sentinel, so
        # each label occupies exactly one line of the reference file.
        newline = '\n'
        newline_replace = "[newline]"
        # Compare text labels
        if resave:
            with open(filename_texts_ref, 'w', encoding="utf-8") as f:
                f.writelines(f'{text.strip().replace(newline, newline_replace)}\n' for text in texts)

        with open(filename_texts_ref, 'r', encoding="utf-8") as f:
            texts_ref = {x.strip() for x in f}
        texts_set = {x.strip().replace(newline, newline_replace) for x in texts}

        # Every reference label must appear in the figure; extra labels are allowed.
        missing = texts_ref - texts_set
        self.assertFalse(missing, msg=f"Reference texts missing from the figure: {sorted(missing)}")
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Apply the checks inherited from ``lsst.utils.tests.MemoryTestCase`` to this module."""
def setup_module(module):
    """Initialise ``lsst.utils.tests`` for this module.

    Presumably invoked by pytest as its module-level setup hook; the
    ``module`` argument is accepted but not used.
    """
    utils_tests = lsst.utils.tests
    utils_tests.init()
if __name__ == "__main__":
    # Direct execution: initialise LSST test support, then hand control to
    # unittest to run the test cases defined in this module.
    lsst.utils.tests.init()
    unittest.main()