Coverage for tests/test_sourceSelector.py: 15%
262 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-13 02:27 -0700
1#
2# LSST Data Management System
3#
4# Copyright 2008-2017 AURA/LSST.
5#
6# This product includes software developed by the
7# LSST Project (http://www.lsst.org/).
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the LSST License Statement and
20# the GNU General Public License along with this program. If not,
21# see <https://www.lsstcorp.org/LegalNotices/>.
22#
24import unittest
25import numpy as np
26import astropy.units as u
28import lsst.afw.table
29import lsst.meas.algorithms
30import lsst.meas.base.tests
31import lsst.pipe.base
32import lsst.utils.tests
34from lsst.meas.algorithms import ColorLimit
class SourceSelectorTester:
    """Mixin holding tests shared by the source-selector test cases.

    This provides a base class for doing tests common to the
    ScienceSourceSelectorTask and ReferenceSourceSelectorTask. Concrete
    subclasses must also inherit from a ``unittest.TestCase`` subclass (for
    the ``assert*``/``subTest`` methods used here) and set ``Task`` to the
    selector task class under test.
    """
    # Selector task class under test; each concrete subclass overrides this.
    Task = None

    def setUp(self):
        """Create an empty catalog with the test fields and a default config."""
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()
        schema.addField("flux", float, "Flux value")
        schema.addField("flux_flag", "Flag", "Bad flux?")
        schema.addField("other_flux", float, "Flux value 2")
        schema.addField("other_flux_flag", "Flag", "Bad flux 2?")
        schema.addField("other_fluxErr", float, "Flux error 2")
        schema.addField("goodFlag", "Flag", "Flagged if good")
        schema.addField("badFlag", "Flag", "Flagged if bad")
        schema.addField("starGalaxy", float, "0=star, 1=galaxy")
        schema.addField("nChild", np.int32, "Number of children")
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        self.catalog.reserve(10)
        self.config = self.Task.ConfigClass()

    def tearDown(self):
        del self.catalog

    def _checkRun(self, task, catalog, expected, expectedIds, idsOf):
        """Run ``task`` on one catalog representation and verify the output.

        Parameters
        ----------
        task
            Configured instance of ``self.Task``.
        catalog
            ``self.catalog`` in one of the representations passed to
            ``task.run`` by `check`.
        expected : `list` [`bool`]
            Expected selection flag for each record, in catalog order.
        expectedIds : `list` [`int`]
            IDs of the records expected to be selected, in catalog order.
        idsOf : callable
            Extracts the list of record IDs from ``results.sourceCat``.
        """
        results = task.run(catalog)
        self.assertListEqual(results.selected.tolist(), expected)
        self.assertListEqual(idsOf(results.sourceCat), expectedIds)

    def check(self, expected):
        """Run the task with ``self.config`` and verify the selection.

        The check is repeated on the afw catalog, on a pandas.DataFrame
        conversion and on an astropy.table.Table conversion. Each is wrapped
        in ``subTest`` so a failure reports which representation broke.

        Parameters
        ----------
        expected : `list` [`bool`]
            Expected selection flag for each record in ``self.catalog``.
        """
        task = self.Task(config=self.config)
        # The expected selected IDs are the same for every representation.
        expectedIds = [src.getId() for src, ok in zip(self.catalog, expected) if ok]

        # afw output records are read one at a time via getId(), as before.
        with self.subTest(catalogType="afw"):
            self._checkRun(task, self.catalog, expected, expectedIds,
                           lambda cat: [src.getId() for src in cat])

        # Check with pandas.DataFrame version of catalog
        with self.subTest(catalogType="pandas"):
            self._checkRun(task, self.catalog.asAstropy().to_pandas(), expected, expectedIds,
                           lambda cat: list(cat['id']))

        # Check with astropy.table.Table version of catalog
        with self.subTest(catalogType="astropy"):
            self._checkRun(task, self.catalog.asAstropy(), expected, expectedIds,
                           lambda cat: list(cat['id']))

    def testFlags(self):
        """Only a record with every 'good' flag set and no 'bad' flag survives."""
        # Neither flag set: rejected (the good flag is missing).
        bad1 = self.catalog.addNew()
        bad1.set("goodFlag", False)
        bad1.set("badFlag", False)
        # Both flags set: rejected (the bad flag is present).
        bad2 = self.catalog.addNew()
        bad2.set("goodFlag", True)
        bad2.set("badFlag", True)
        # Only the bad flag set: rejected.
        bad3 = self.catalog.addNew()
        bad3.set("goodFlag", False)
        bad3.set("badFlag", True)
        # Only the good flag set: selected.
        good = self.catalog.addNew()
        good.set("goodFlag", True)
        good.set("badFlag", False)
        # Finite fluxes for every record; presumably so flux-based cuts
        # enabled in a subclass setUp do not reject them -- TODO confirm.
        self.catalog["flux"] = 1.0
        self.catalog["other_flux"] = 1.0
        self.config.flags.good = ["goodFlag"]
        self.config.flags.bad = ["badFlag"]
        self.check([False, False, False, True])

    def testSignalToNoise(self):
        """Records with S/N of 1, 10 and 1000 against limits of 5 and 100."""
        low = self.catalog.addNew()
        low.set("other_flux", 1.0)
        low.set("other_fluxErr", 1.0)
        good = self.catalog.addNew()
        good.set("other_flux", 1.0)
        good.set("other_fluxErr", 0.1)
        high = self.catalog.addNew()
        high.set("other_flux", 1.0)
        high.set("other_fluxErr", 0.001)
        self.config.doSignalToNoise = True
        self.config.signalToNoise.fluxField = "other_flux"
        self.config.signalToNoise.errField = "other_fluxErr"
        self.config.signalToNoise.minimum = 5.0
        self.config.signalToNoise.maximum = 100.0
        self.check([False, True, False])
class ScienceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """ScienceSourceSelectorTask tests, in addition to the shared mixin tests."""
    Task = lsst.meas.algorithms.ScienceSourceSelectorTask

    def setUp(self):
        """Enable the flux-limit and flag cuts; disable unresolved/isolated cuts."""
        SourceSelectorTester.setUp(self)
        config = self.config
        config.fluxLimit.fluxField = "flux"
        config.flags.bad = []
        config.doFluxLimit = True
        config.doFlags = True
        config.doUnresolved = False
        config.doIsolated = False

    def testFluxLimit(self):
        """Flux cuts with both bounds, upper bound removed, then both removed."""
        goodFlux = 1000.0
        # Rows, in order: too bright, good, flagged-bad, too faint.
        rows = (
            (1.0e10, False),
            (goodFlux, False),
            (goodFlux, True),
            (1.0, False),
        )
        for fluxValue, isFlagged in rows:
            record = self.catalog.addNew()
            record.set("flux", fluxValue)
            record.set("flux_flag", isFlagged)

        fluxLimit = self.config.fluxLimit
        fluxLimit.minimum = 10.0
        fluxLimit.maximum = 1.0e6
        fluxLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Works with no maximum set?
        fluxLimit.maximum = None
        self.check([True, True, False, False])

        # Works with no minimum set?
        fluxLimit.minimum = None
        self.check([True, True, False, True])

    def testUnresolved(self):
        """starGalaxy cuts with both bounds, and with each bound removed."""
        nRecords = 5
        for _ in range(nRecords):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, num=nRecords, endpoint=False)
        self.catalog["starGalaxy"] = starGalaxy

        unresolved = self.config.unresolved
        self.config.doUnresolved = True
        unresolved.name = "starGalaxy"
        lower, upper = 0.3, 0.7
        unresolved.minimum = lower
        unresolved.maximum = upper
        self.check(((starGalaxy > lower) & (starGalaxy < upper)).tolist())

        # Works with no minimum set?
        unresolved.minimum = None
        self.check((starGalaxy < upper).tolist())

        # Works with no maximum set?
        unresolved.minimum = lower
        unresolved.maximum = None
        self.check((starGalaxy > lower).tolist())

    def testIsolated(self):
        """Only records with parent == 0 and nChild == 0 are selected."""
        nRecords = 5
        for _ in range(nRecords):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        parent = np.array([0, 0, 10, 0, 0], dtype=int)
        nChild = np.array([2, 0, 0, 0, 0], dtype=int)
        self.catalog["parent"] = parent
        self.catalog["nChild"] = nChild

        isolated = self.config.isolated
        self.config.doIsolated = True
        isolated.parentName = "parent"
        isolated.nChildName = "nChild"
        expected = (parent == 0) & (nChild == 0)
        self.check(expected.tolist())

    def testRequireFiniteRaDec(self):
        """Records with a NaN in either coordinate are rejected."""
        nRecords = 5
        for _ in range(nRecords):
            self.catalog.addNew()
        ra = np.array([np.nan, np.nan, 0, 0, 0], dtype=float)
        dec = np.array([2, np.nan, 0, 0, np.nan], dtype=float)
        self.catalog["coord_ra"] = ra
        self.catalog["coord_dec"] = dec

        raDecConfig = self.config.requireFiniteRaDec
        self.config.doRequireFiniteRaDec = True
        raDecConfig.raColName = "coord_ra"
        raDecConfig.decColName = "coord_dec"
        bothFinite = np.isfinite(ra) & np.isfinite(dec)
        self.check(bothFinite.tolist())
class ReferenceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """ReferenceSourceSelectorTask tests, in addition to the shared mixin tests."""
    Task = lsst.meas.algorithms.ReferenceSourceSelectorTask

    def setUp(self):
        """Enable the magnitude-limit and flag cuts; disable the unresolved cut."""
        SourceSelectorTester.setUp(self)
        self.config.magLimit.fluxField = "flux"
        self.config.doMagLimit = True
        self.config.doFlags = True
        self.config.doUnresolved = False

    def testMagnitudeLimit(self):
        """Magnitude cuts with both bounds, then with each bound removed."""
        tooBright = self.catalog.addNew()
        tooBright.set("flux", 1.0e10)
        tooBright.set("flux_flag", False)
        good = self.catalog.addNew()
        good.set("flux", 1000.0)
        good.set("flux_flag", False)
        bad = self.catalog.addNew()
        bad.set("flux", good.get("flux"))
        bad.set("flux_flag", True)
        tooFaint = self.catalog.addNew()
        tooFaint.set("flux", 1.0)
        tooFaint.set("flux_flag", False)
        # Note: magnitudes are backwards, so the minimum flux is the maximum magnitude
        self.config.magLimit.minimum = (1.0e6*u.nJy).to_value(u.ABmag)
        self.config.magLimit.maximum = (10.0*u.nJy).to_value(u.ABmag)
        self.config.magLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Works with no minimum set?
        self.config.magLimit.minimum = None
        self.check([True, True, False, False])

        # Works with no maximum set?
        self.config.magLimit.maximum = None
        self.check([True, True, False, True])

    def testMagErrorLimit(self):
        """Only the record whose error lies between the minimum and maximum survives."""
        # Using an arbitrary field as if it was a magnitude error to save adding a new field
        field = "other_fluxErr"
        tooFaint = self.catalog.addNew()
        tooFaint.set(field, 0.5)
        tooBright = self.catalog.addNew()
        tooBright.set(field, 0.00001)
        good = self.catalog.addNew()
        good.set(field, 0.2)

        self.config.doMagError = True
        self.config.magError.minimum = 0.01
        self.config.magError.maximum = 0.3
        self.config.magError.magErrField = field
        self.check([False, False, True])

    def testColorLimits(self):
        """Single, one-sided, combined, and mutually exclusive color limits."""
        num = 10
        for _ in range(num):
            self.catalog.addNew()
        color = np.linspace(-0.5, 0.5, num, True)
        flux = 1000.0*u.nJy
        # Definition: color = mag(flux) - mag(otherFlux)
        otherFlux = (flux.to(u.ABmag) - color*u.mag).to_value(u.nJy)
        self.catalog["flux"] = flux.value
        self.catalog["other_flux"] = otherFlux
        minimum, maximum = -0.1, 0.2
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum, maximum=maximum)}
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Works with no minimum set?
        self.config.colorLimits["test"].minimum = None
        self.check((color < maximum).tolist())

        # Works with no maximum set?
        self.config.colorLimits["test"].maximum = None
        self.config.colorLimits["test"].minimum = minimum
        self.check((color > minimum).tolist())

        # Multiple limits
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum),
                                   "other": ColorLimit(primary="flux", secondary="other_flux",
                                                       maximum=maximum)}
        # Use a TestCase assertion rather than a bare assert: a bare assert is
        # stripped under ``python -O`` and gives no diagnostic on failure.
        self.assertGreater(maximum, minimum, "limits must overlap (non-mutually-exclusive)")
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Multiple mutually-exclusive limits
        self.config.colorLimits["test"] = ColorLimit(primary="flux", secondary="other_flux", maximum=-0.1)
        self.config.colorLimits["other"] = ColorLimit(primary="flux", secondary="other_flux", minimum=0.1)
        self.check([False]*num)

    def testUnresolved(self):
        """starGalaxy cuts with both bounds, and with each bound removed."""
        num = 5
        for _ in range(num):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, num, False)
        self.catalog["starGalaxy"] = starGalaxy
        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        minimum, maximum = 0.3, 0.7
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = maximum
        self.check(((starGalaxy > minimum) & (starGalaxy < maximum)).tolist())

        # Works with no minimum set?
        self.config.unresolved.minimum = None
        self.check((starGalaxy < maximum).tolist())

        # Works with no maximum set?
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = None
        self.check((starGalaxy > minimum).tolist())
class TrivialSourceSelector(lsst.meas.algorithms.BaseSourceSelectorTask):
    """Selector that accepts every input source; exists only for testing."""

    def selectSources(self, sourceCat, matches=None, exposure=None):
        """Mark every source in ``sourceCat`` as selected.

        ``matches`` and ``exposure`` are accepted to satisfy the
        ``selectSources`` signature and are ignored.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``selected``: boolean array of length ``len(sourceCat)``, all True.
        """
        allSelected = np.full(len(sourceCat), True, dtype=bool)
        return lsst.pipe.base.Struct(selected=allSelected)
class TestBaseSourceSelector(lsst.utils.tests.TestCase):
    """Test the API of the Abstract Base Class with a trivial example."""

    def setUp(self):
        """Build a 4-record catalog with an extra 'is_selected' flag field."""
        schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()
        self.selectedKeyName = "is_selected"
        schema.addField(self.selectedKeyName, type="Flag")
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        for _ in range(4):
            self.catalog.addNew()

        self.sourceSelector = TrivialSourceSelector()

    def testRun(self):
        """Test that run() returns a catalog and boolean selected array."""
        result = self.sourceSelector.run(self.catalog)
        # The trivial selector accepts everything, so both outputs must cover
        # exactly the input records: no rows dropped and none added.
        self.assertEqual(len(result.selected), len(self.catalog))
        self.assertEqual(len(result.sourceCat), len(self.catalog))
        selectedIds = result.sourceCat['id']
        for i, sourceId in enumerate(self.catalog['id']):
            self.assertIn(sourceId, selectedIds)
            self.assertTrue(result.selected[i])

    def testRunSourceSelectedField(self):
        """Test that the selected flag is set in the original catalog."""
        self.sourceSelector.run(self.catalog, sourceSelectedField=self.selectedKeyName)
        np.testing.assert_array_equal(self.catalog[self.selectedKeyName], True)

    def testRunNonContiguousRaises(self):
        """Cannot do source selection on non-contiguous catalogs."""
        del self.catalog[1]  # take one out of the middle to make it non-contiguous.
        self.assertFalse(self.catalog.isContiguous(), "Catalog is contiguous: the test won't work.")

        with self.assertRaises(RuntimeError):
            self.sourceSelector.run(self.catalog)
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Inherit the memory checks from lsst.utils.tests.MemoryTestCase unchanged."""
def setup_module(module):
    """Initialize lsst.utils test support once per module.

    The name matches pytest's xunit-style module setup hook. It is also
    called explicitly when the file is run as a script. ``module`` is
    accepted for that hook signature and is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Allow running this file directly rather than through a test runner.
    import sys
    thisModule = sys.modules[__name__]
    setup_module(thisModule)
    unittest.main()