Coverage for tests/test_scatterPlot.py: 19%
100 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-01 01:28 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-01 01:28 -0700
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
23import unittest
24import lsst.utils.tests
26from lsst.analysis.drp.calcFunctors import MagDiff
27from lsst.analysis.drp.dataSelectors import GalaxyIdentifier
28from lsst.analysis.drp.plotUtils import get_and_remove_figure_text
29from lsst.analysis.drp.scatterPlot import ScatterPlotWithTwoHistsTask, ScatterPlotWithTwoHistsTaskConfig
31import matplotlib
32import matplotlib.pyplot as plt
33import numpy as np
34from numpy.random import default_rng
35import os
36import pandas as pd
37import shutil
38import tempfile
# Select the Agg backend before any figure is created so the tests do not
# depend on a display being available.
matplotlib.use("Agg")

# Directory containing this test module; reference data lives in "data" below it.
ROOT = os.path.abspath(os.path.dirname(__file__))
_DATA_DIR = os.path.join(ROOT, "data")
# Text file holding the reference text labels, one per line.
filename_texts_ref = os.path.join(_DATA_DIR, "test_scatterPlot_texts.txt")
# Directory holding one ``line_<idx>.txt`` reference file per plotted line.
path_lines_ref = os.path.join(_DATA_DIR, "test_scatterPlot_lines")
class ScatterPlotWithTwoHistsTaskTestCase(lsst.utils.tests.TestCase):
    """ScatterPlotWithTwoHistsTask test case.

    A synthetic catalog is built in `setUp`, plotted by the task, and the
    text labels and line data extracted from the resulting figure are
    compared against reference files stored under ``data``.
    """

    def setUp(self):
        """Create the task under test and a synthetic input catalog.

        Sets ``self.testDir``, ``self.task``, ``self.bands`` and
        ``self.data`` for use by the test method.
        """
        self.testDir = tempfile.mkdtemp(dir=ROOT, prefix="test_output")
        # Register removal right away: unittest does not call tearDown when
        # setUp raises, so a later failure here would otherwise leave the
        # directory behind.
        self.addCleanup(shutil.rmtree, self.testDir, ignore_errors=True)

        # Set up a quasi-plausible measurement catalog
        mag = 12.5 + 2.5*np.log10(np.arange(10, 100000))
        flux = 10**(-0.4*(mag - (mag[-1] + 1)))
        # Fixed seed keeps the catalog (and hence the reference data) reproducible
        rng = default_rng(0)
        extendedness = 0. + (rng.uniform(size=len(mag)) < 0.99*(mag - mag[0])/(mag[-1] - mag[0]))
        flux_meas = flux + rng.normal(scale=np.sqrt(flux*(1 + extendedness)))
        # Keep only rows measured above 3 sigma of the true-flux noise model
        good = (flux_meas/np.sqrt(flux * (1 + extendedness))) > 3
        extendedness = extendedness[good]
        flux = flux[good]
        flux_meas = flux_meas[good]
        # Computed after the cut: every surviving flux_meas is positive, so
        # the square root never sees a negative value (which would yield NaN
        # and a RuntimeWarning). The values for the kept rows are unchanged.
        flux_err = np.sqrt(flux_meas * (1 + extendedness))
        sky_object = np.full(len(flux), False)

        # Configure the plot to show observed vs true mags
        config = ScatterPlotWithTwoHistsTaskConfig(
            axisLabels={"x": "mag", "y": "mag meas - ref", "mag": "mag"},
        )
        config.selectorActions.flagSelector.bands = ["i"]
        # col1 is a placeholder here; it is pointed at the measured flux
        # column once that column's name is known (see the loop below).
        config.axisActions.yAction = MagDiff(col1="refcat_flux", col2="refcat_flux")
        config.nonBandColumnPrefixes.append("refcat")
        config.selectorActions.catSnSelector.threshold = -1e12
        config.sourceSelectorActions.galaxySelector = GalaxyIdentifier
        config.highSnStatisticSelectorActions.statSelector.threshold = 50
        config.lowSnStatisticSelectorActions.statSelector.threshold = 20
        self.task = ScatterPlotWithTwoHistsTask(config=config)

        n = len(flux)
        self.bands, columns = config.get_requirements()
        data = {
            "refcat_flux": flux,
            "patch": np.zeros(n, dtype=int),
        }

        # Assign values to columns based on their unchanged default names
        for column in columns:
            if column in data:
                continue
            if column.startswith("detect"):
                data[column] = np.ones(n, dtype=bool)
            elif column.endswith("_flag") or "Flag" in column:
                data[column] = np.zeros(n, dtype=bool)
            elif column.endswith("Flux"):
                config.axisActions.yAction.col1 = column
                data[column] = flux_meas
            elif column.endswith("FluxErr"):
                data[column] = flux_err
            elif column.endswith("_extendedness"):
                data[column] = extendedness
            elif column.startswith("sky_"):
                data[column] = sky_object

        self.data = pd.DataFrame(data)

    def tearDown(self):
        """Release per-test state.

        The temporary directory is removed by the cleanup registered in
        `setUp`, which runs after this method.
        """
        del self.bands
        del self.data
        del self.task

    def test_ScatterPlotWithTwoHistsTask(self):
        """Run the task and compare the figure's lines and texts to references."""
        plt.rcParams.update(plt.rcParamsDefault)
        result = self.task.run(self.data,
                               dataId={},
                               runName="test",
                               skymap=None,
                               tableName="test",
                               bands=self.bands,
                               plotName="test")

        self.assertIsInstance(result.scatterPlot, plt.Figure)

        # Set to true to save plots as PNGs
        # Use matplotlib.testing.compare.compare_images if needed
        save_images = False
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot.png"))

        # NOTE(review): presumably this strips the text artists from the
        # figure and returns them with the plotted line data - confirm
        # against plotUtils.
        texts, lines = get_and_remove_figure_text(result.scatterPlot)
        if save_images:
            result.scatterPlot.savefig(os.path.join(ROOT, "data", "test_scatterPlot_unlabeled.png"))

        # Set to true to re-generate reference data
        resave = False

        # Compare line values
        for idx, line in enumerate(lines):
            filename = os.path.join(path_lines_ref, f"line_{idx}.txt")
            if resave:
                np.savetxt(filename, line)
            arr = np.loadtxt(filename)
            # Differences of order 1e-12 possible between MacOS and Linux
            # Plots are generally not expected to be that precise
            # Differences to 1e-3 should not be visible with this test data
            self.assertFloatsAlmostEqual(arr, line, atol=1e-3, rtol=1e-4)

        # Ensure that newlines within labels are replaced by a sentinel
        newline = '\n'
        newline_replace = "[newline]"
        # Compare text labels
        if resave:
            with open(filename_texts_ref, 'w', encoding="utf-8") as f:
                f.writelines(f'{text.strip().replace(newline, newline_replace)}\n' for text in texts)

        with open(filename_texts_ref, 'r', encoding="utf-8") as f:
            texts_ref = {x.strip() for x in f}
        texts_set = {x.strip().replace(newline, newline_replace) for x in texts}

        # Every reference label must be present; extra labels in the figure
        # are tolerated. Comparing the difference to the empty set makes a
        # failure list exactly which labels are missing.
        missing = texts_ref - texts_set
        self.assertEqual(missing, set(), "Reference text labels missing from the figure")
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``MemoryTestCase``; nothing is added here."""
def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Parameters
    ----------
    module : `module`
        Not used by this function; presumably supplied by the test runner
        as part of the module-level setup hook.
    """
    lsst.utils.tests.init()
# Entry point for running this file directly as a script. The coverage run
# recorded this condition as never true, i.e. the module was only imported.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()