Coverage for tests / test_sourceSelector.py: 13%
334 statements
Created by coverage.py v7.13.5 at 2026-04-28 08:52 +0000
1#
2# LSST Data Management System
3#
4# Copyright 2008-2017 AURA/LSST.
5#
6# This product includes software developed by the
7# LSST Project (http://www.lsst.org/).
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the LSST License Statement and
20# the GNU General Public License along with this program. If not,
21# see <https://www.lsstcorp.org/LegalNotices/>.
22#
24import unittest
25import numpy as np
26import astropy.units as u
27import warnings
29import lsst.afw.image
30import lsst.afw.table
31import lsst.geom
32import lsst.meas.algorithms
33import lsst.meas.base.tests
34import lsst.pipe.base
35import lsst.utils.tests
37from lsst.meas.algorithms import ColorLimit
class SourceSelectorTester:
    """Mixin holding tests shared by the source-selector task test cases.

    This provides a base class for doing tests common to the
    ScienceSourceSelectorTask and ReferenceSourceSelectorTask. A concrete
    test class inherits from this mixin and from a ``TestCase``, and sets
    ``Task`` to the selector task class under test.
    """

    # Selector task class under test; each concrete subclass overrides this.
    Task = None

    def setUp(self):
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()
        # (name, type, doc) for each extra column the tests populate, in the
        # order they are added to the schema.
        extraFields = (
            ("flux", float, "Flux value"),
            ("flux_flag", "Flag", "Bad flux?"),
            ("other_flux", float, "Flux value 2"),
            ("other_flux_flag", "Flag", "Bad flux 2?"),
            ("other_fluxErr", float, "Flux error 2"),
            ("goodFlag", "Flag", "Flagged if good"),
            ("badFlag", "Flag", "Flagged if bad"),
            ("starGalaxy", float, "0=star, 1=galaxy"),
            ("nChild", np.int32, "Number of children"),
            ("detect_isPrimary", "Flag", "Is primary detection?"),
            ("sky_source", "Flag", "Empty sky region."),
        )
        for fieldName, fieldType, fieldDoc in extraFields:
            schema.addField(fieldName, fieldType, fieldDoc)

        self.xCol = "centroid_x"
        self.yCol = "centroid_y"
        schema.addField(self.xCol, float, "Centroid x value.")
        schema.addField(self.yCol, float, "Centroid y value.")

        self.catalog = lsst.afw.table.SourceCatalog(schema)
        self.catalog.reserve(10)
        self.config = self.Task.ConfigClass()
        # No exposure by default; tests that need one assign it themselves.
        self.exposure = None

    def tearDown(self):
        del self.catalog

    def check(self, expected):
        """Run the task on each catalog representation and verify the selection.

        The task is run on the afw ``SourceCatalog`` itself, then on a
        ``pandas.DataFrame`` and an ``astropy.table.Table`` derived from it.
        Each time, the boolean ``selected`` array and the ids of the returned
        ``sourceCat`` must match ``expected``.

        Parameters
        ----------
        expected : `list` [`bool`]
            Expected selection flag for each catalog row, in catalog order.
        """
        task = self.Task(config=self.config)

        afwResult = task.run(self.catalog, exposure=self.exposure)
        self.assertListEqual(afwResult.selected.tolist(), expected)
        keptIds = [record.getId() for record, keep in zip(self.catalog, expected) if keep]
        self.assertListEqual([record.getId() for record in afwResult.sourceCat], keptIds)

        # Tabular inputs are built lazily, one at a time: pandas first, then astropy.
        tableFactories = (
            lambda: self.catalog.asAstropy().to_pandas(),
            lambda: self.catalog.asAstropy(),
        )
        for makeTable in tableFactories:
            tableResult = task.run(makeTable(), exposure=self.exposure)
            self.assertListEqual(tableResult.selected.tolist(), expected)
            self.assertListEqual(list(tableResult.sourceCat['id']), keptIds)

    def testFlags(self):
        # (goodFlag, badFlag) for each row; only the last row is good and not bad.
        flagPairs = [(False, False), (True, True), (False, True), (True, False)]
        for goodValue, badValue in flagPairs:
            record = self.catalog.addNew()
            record.set("goodFlag", goodValue)
            record.set("badFlag", badValue)
        self.catalog["flux"] = 1.0
        self.catalog["other_flux"] = 1.0
        self.config.flags.good = ["goodFlag"]
        self.config.flags.bad = ["badFlag"]
        self.check([False, False, False, True])

    def testSignalToNoise(self):
        # Unit flux with shrinking errors gives S/N of 1, 10 and 1000; only the
        # middle row falls inside [5, 100].
        for fluxErr in (1.0, 0.1, 0.001):
            record = self.catalog.addNew()
            record.set("other_flux", 1.0)
            record.set("other_fluxErr", fluxErr)
        self.config.doSignalToNoise = True
        self.config.signalToNoise.fluxField = "other_flux"
        self.config.signalToNoise.errField = "other_fluxErr"
        self.config.signalToNoise.minimum = 5.0
        self.config.signalToNoise.maximum = 100.0
        self.check([False, True, False])

    def testSignalToNoiseNoWarn(self):
        record = self.catalog.addNew()
        record.set("other_flux", np.nan)
        record.set("other_fluxErr", np.nan)
        self.config.doSignalToNoise = True
        self.config.signalToNoise.fluxField = "other_flux"
        self.config.signalToNoise.errField = "other_fluxErr"
        with warnings.catch_warnings():
            # Promote every warning to an error so any warning fails the test.
            warnings.simplefilter("error")
            self.check([False])
class ScienceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests specific to ScienceSourceSelectorTask (shared tests come from the mixin)."""

    Task = lsst.meas.algorithms.ScienceSourceSelectorTask

    def setUp(self):
        # The next class in the MRO is SourceSelectorTester, which builds the
        # schema, catalog and default config.
        super().setUp()
        cfg = self.config
        cfg.fluxLimit.fluxField = "flux"
        cfg.flags.bad = []
        cfg.doFluxLimit = True
        cfg.doFlags = True
        cfg.doUnresolved = False
        cfg.doIsolated = False

    def _addSources(self, count):
        """Append ``count`` new records to ``self.catalog``."""
        for _ in range(count):
            self.catalog.addNew()

    def testFluxLimit(self):
        goodFlux = 1000.0
        # (flux, flux_flag) rows: too bright, good, flagged, too faint.
        rows = [(1.0e10, False), (goodFlux, False), (goodFlux, True), (1.0, False)]
        for fluxValue, flagValue in rows:
            record = self.catalog.addNew()
            record.set("flux", fluxValue)
            record.set("flux_flag", flagValue)
        self.config.fluxLimit.minimum = 10.0
        self.config.fluxLimit.maximum = 1.0e6
        self.config.fluxLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Dropping the upper bound also admits the too-bright source.
        self.config.fluxLimit.maximum = None
        self.check([True, True, False, False])

        # Dropping the lower bound too admits the too-faint source as well.
        self.config.fluxLimit.minimum = None
        self.check([True, True, False, True])

    def testUnresolved(self):
        nSources = 5
        self._addSources(nSources)
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, nSources, endpoint=False)
        self.catalog["starGalaxy"] = starGalaxy
        lower, upper = 0.3, 0.7
        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        self.config.unresolved.minimum = lower
        self.config.unresolved.maximum = upper
        aboveLower = starGalaxy > lower
        belowUpper = starGalaxy < upper
        self.check((aboveLower & belowUpper).tolist())

        # Lower bound removed: only the upper bound applies.
        self.config.unresolved.minimum = None
        self.check(belowUpper.tolist())

        # Lower bound restored and upper bound removed.
        self.config.unresolved.minimum = lower
        self.config.unresolved.maximum = None
        self.check(aboveLower.tolist())

    def testIsolated(self):
        self._addSources(5)
        self.catalog["flux"] = 1.0
        # Row 0 has children and row 2 has a parent; the others are isolated.
        parent = np.array([0, 0, 10, 0, 0], dtype=int)
        nChild = np.array([2, 0, 0, 0, 0], dtype=int)
        self.catalog["parent"] = parent
        self.catalog["nChild"] = nChild
        self.config.doIsolated = True
        self.config.isolated.parentName = "parent"
        self.config.isolated.nChildName = "nChild"
        isolated = (parent == 0) & (nChild == 0)
        self.check(isolated.tolist())

    def testRequireFiniteRaDec(self):
        self._addSources(5)
        ra = np.array([np.nan, np.nan, 0, 0, 0], dtype=float)
        dec = np.array([2, np.nan, 0, 0, np.nan], dtype=float)
        self.catalog["coord_ra"] = ra
        self.catalog["coord_dec"] = dec
        self.config.doRequireFiniteRaDec = True
        self.config.requireFiniteRaDec.raColName = "coord_ra"
        self.config.requireFiniteRaDec.decColName = "coord_dec"
        bothFinite = np.isfinite(ra) & np.isfinite(dec)
        self.check(bothFinite.tolist())

    def testRequirePrimary(self):
        self._addSources(5)
        isPrimary = np.array([True, True, False, True, False], dtype=bool)
        self.catalog["detect_isPrimary"] = isPrimary
        self.config.doRequirePrimary = True
        self.config.requirePrimary.primaryColName = "detect_isPrimary"
        self.check(isPrimary.tolist())

    def testSkySource(self):
        self._addSources(5)
        isSky = np.array([True, True, False, True, False], dtype=bool)
        self.catalog["sky_source"] = isSky
        # Sky sources are OR-ed into the selection rather than AND-ed, so also
        # enable a selection that would otherwise reject everything.
        self.config.doRequirePrimary = True
        self.config.doSkySources = True
        self.check(isSky.tolist())
class ReferenceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests specific to ReferenceSourceSelectorTask.

    The flag and signal-to-noise tests are inherited from SourceSelectorTester.
    """

    Task = lsst.meas.algorithms.ReferenceSourceSelectorTask

    def setUp(self):
        SourceSelectorTester.setUp(self)
        # Start with the magnitude-limit and flag selections switched on and the
        # unresolved and finite-RA/Dec selections switched off; individual tests
        # enable whatever they exercise.
        self.config.magLimit.fluxField = "flux"
        self.config.doMagLimit = True
        self.config.doFlags = True
        self.config.doUnresolved = False
        self.config.doRequireFiniteRaDec = False

    def _addSources(self, count):
        """Append ``count`` new records to ``self.catalog``."""
        for _ in range(count):
            self.catalog.addNew()

    def testMagnitudeLimit(self):
        tooBright = self.catalog.addNew()
        tooBright.set("flux", 1.0e10)
        tooBright.set("flux_flag", False)
        good = self.catalog.addNew()
        good.set("flux", 1000.0)
        good.set("flux_flag", False)
        bad = self.catalog.addNew()
        bad.set("flux", good.get("flux"))
        bad.set("flux_flag", True)
        tooFaint = self.catalog.addNew()
        tooFaint.set("flux", 1.0)
        tooFaint.set("flux_flag", False)
        # Magnitudes run opposite to flux, so the maximum allowed flux sets the
        # minimum magnitude and the minimum allowed flux sets the maximum one.
        self.config.magLimit.minimum = (1.0e6*u.nJy).to_value(u.ABmag)
        self.config.magLimit.maximum = (10.0*u.nJy).to_value(u.ABmag)
        self.config.magLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Without the bright (minimum-magnitude) limit the too-bright source passes.
        self.config.magLimit.minimum = None
        self.check([True, True, False, False])

        # Without the faint (maximum-magnitude) limit too, the too-faint source passes.
        self.config.magLimit.maximum = None
        self.check([True, True, False, True])

    def testMagErrorLimit(self):
        # Use an existing float field as a stand-in magnitude error so no new
        # schema field is needed.
        field = "other_fluxErr"
        tooFaint = self.catalog.addNew()
        tooFaint.set(field, 0.5)
        tooBright = self.catalog.addNew()
        tooBright.set(field, 0.00001)
        good = self.catalog.addNew()
        good.set(field, 0.2)

        self.config.doMagError = True
        self.config.magError.minimum = 0.01
        self.config.magError.maximum = 0.3
        self.config.magError.magErrField = field
        self.check([False, False, True])

    def testColorLimits(self):
        num = 10
        self._addSources(num)
        color = np.linspace(-0.5, 0.5, num, True)
        flux = 1000.0*u.nJy
        # Definition: color = mag(flux) - mag(otherFlux)
        otherFlux = (flux.to(u.ABmag) - color*u.mag).to_value(u.nJy)
        self.catalog["flux"] = flux.value
        self.catalog["other_flux"] = otherFlux
        minimum, maximum = -0.1, 0.2
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum, maximum=maximum)}
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Works with no minimum set?
        self.config.colorLimits["test"].minimum = None
        self.check((color < maximum).tolist())

        # Works with no maximum set?
        self.config.colorLimits["test"].maximum = None
        self.config.colorLimits["test"].minimum = minimum
        self.check((color > minimum).tolist())

        # Multiple limits, each supplying one bound; the expected selection is
        # the intersection of the two.
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum),
                                   "other": ColorLimit(primary="flux", secondary="other_flux",
                                                       maximum=maximum)}
        # Sanity check on the test inputs (a bare assert would vanish under -O).
        self.assertGreater(maximum, minimum, "Color limits must overlap for this check.")
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Multiple mutually-exclusive limits reject every source.
        self.config.colorLimits["test"] = ColorLimit(primary="flux", secondary="other_flux", maximum=-0.1)
        self.config.colorLimits["other"] = ColorLimit(primary="flux", secondary="other_flux", minimum=0.1)
        self.check([False]*num)

    def testUnresolved(self):
        num = 5
        self._addSources(num)
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, num, False)
        self.catalog["starGalaxy"] = starGalaxy
        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        minimum, maximum = 0.3, 0.7
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = maximum
        self.check(((starGalaxy > minimum) & (starGalaxy < maximum)).tolist())

        # Works with no minimum set?
        self.config.unresolved.minimum = None
        self.check((starGalaxy < maximum).tolist())

        # Works with no maximum set?
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = None
        self.check((starGalaxy > minimum).tolist())

    def testFiniteRaDec(self):
        """Test that non-finite RA and Dec values are caught."""
        self._addSources(5)
        self.catalog["coord_ra"][:] = 1.0
        self.catalog["coord_dec"][:] = 1.0
        self.catalog["coord_ra"][0] = np.nan
        self.catalog["coord_dec"][1] = np.inf
        self.config.doRequireFiniteRaDec = True

        self.check([False, False, True, True, True])

    def testCullFromMaskedRegion(self):
        """Sources whose centroids land on a pixel in a bad mask plane are culled."""
        maskNames = ["NO_DATA", "BLAH"]
        self._addSources(5)

        # The selection config does not depend on the image origin, so set it once.
        self.config.doCullFromMaskedRegion = True
        self.config.cullFromMaskedRegion.xColName = self.xCol
        self.config.cullFromMaskedRegion.yColName = self.yCol
        self.config.cullFromMaskedRegion.badMaskNames = maskNames

        # Run once with the image origin at (0, 0) and once offset from it, so
        # the pixel lookups have to account for XY0.
        for x0, y0 in [(0, 0), (3, 8)]:
            self.exposure = lsst.afw.image.ExposureF(5, 5)
            self.exposure.setXY0(lsst.geom.Point2I(x0, y0))
            mask = self.exposure.mask
            for maskName in maskNames:
                if maskName not in mask.getMaskPlaneDict():
                    mask.addMaskPlane(maskName)
            # Place every source at (x0 + 5, y0 + 5), one pixel beyond the far
            # corner of the 5x5 image; the expected result below keeps these.
            self.catalog[self.xCol][:] = x0 + 5.0
            self.catalog[self.yCol][:] = y0 + 5.0
            # Move the first two sources onto pixels flagged with the first and
            # second bad mask planes respectively.
            noDataPoints = [(x0, y0), (x0 + 3, y0 + 2)]
            for i, (x, y) in enumerate(noDataPoints):
                planeName = maskNames[min(i, len(maskNames) - 1)]
                # Numpy image arrays are indexed [row, column], i.e. [y, x],
                # relative to the image origin.
                mask.array[y - y0, x - x0] = mask.getPlaneBitMask(planeName)
                self.catalog[self.xCol][i] = x
                self.catalog[self.yCol][i] = y
            self.check([False, False, True, True, True])

        # Leave the fixture in its default state (setUp rebuilds both for each
        # test in any case).
        self.config.doCullFromMaskedRegion = False
        self.exposure = None
class TestBaseSourceSelector(lsst.utils.tests.TestCase):
    """Exercise the abstract source-selector base-class API via a trivial selector.

    NullSourceSelectorTask is used as the concrete implementation; the tests
    below expect it to select every source.
    """

    def setUp(self):
        self.selectedKeyName = "is_selected"
        schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()
        schema.addField(self.selectedKeyName, type="Flag")
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        for _ in range(4):
            self.catalog.addNew()
        self.sourceSelector = lsst.meas.algorithms.NullSourceSelectorTask()

    def testRun(self):
        """run() returns an output catalog and a per-source boolean selection."""
        result = self.sourceSelector.run(self.catalog)
        outputIds = result.sourceCat['id']
        for index, sourceId in enumerate(self.catalog['id']):
            self.assertIn(sourceId, outputIds)
            self.assertTrue(result.selected[index])

    def testRunSourceSelectedField(self):
        """Passing sourceSelectedField sets that flag in the input catalog."""
        self.sourceSelector.run(self.catalog, sourceSelectedField=self.selectedKeyName)
        selectedFlags = self.catalog[self.selectedKeyName]
        np.testing.assert_array_equal(selectedFlags, True)

    def testRunNonContiguousRaises(self):
        """Selecting from a non-contiguous catalog raises RuntimeError."""
        # Deleting a record from the middle makes the catalog non-contiguous,
        # which the assertion below confirms before the real check.
        del self.catalog[1]
        self.assertFalse(self.catalog.isContiguous(), "Catalog is contiguous: the test won't work.")

        with self.assertRaises(RuntimeError):
            self.sourceSelector.run(self.catalog)
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Inherit all checks from lsst.utils.tests.MemoryTestCase; nothing to add here."""
455def setup_module(module):
456 lsst.utils.tests.init()
if __name__ == "__main__":
    import sys

    # pytest calls setup_module automatically; when run as a script, call it
    # by hand before handing control to unittest.
    thisModule = sys.modules[__name__]
    setup_module(thisModule)
    unittest.main()