Coverage for tests / test_filterDiaSourceCatalog.py: 12%
201 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-23 08:46 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-23 08:46 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import unittest
23import numpy as np
25from lsst.ap.association.filterDiaSourceCatalog import (FilterDiaSourceCatalogConfig,
26 FilterDiaSourceCatalogTask,
27 FilterDiaSourceReliabilityConfig,
28 FilterDiaSourceReliabilityTask)
29import lsst.geom as geom
30import lsst.meas.base.tests as measTests
31import lsst.utils.tests
32import lsst.afw.image as afwImage
33import lsst.afw.table as afwTable
34import lsst.daf.base as dafBase
class TestFilterDiaSourceCatalogTask(unittest.TestCase):
    """Tests for FilterDiaSourceCatalogTask and FilterDiaSourceReliabilityTask.

    ``setUp`` builds one synthetic DiaSource catalog whose rows are split
    into consecutive groups, each exercising a different filter: sky
    sources, ``base_PixelFlags_flag_crCenter`` sources, ``fakeBadFlag``
    sources, negative direct-image sources and trailed sources.
    """

    def setUp(self):
        self.config = FilterDiaSourceCatalogConfig()

        # Sizes of the consecutive groups of sources in the catalog.
        self.nSkySources = 5
        self.nCrCenterSources = 6
        self.nFakeFlagSources = 4
        self.nNegativeSources = 7
        self.nTrailedSources = 10
        self.pixelScale = 0.2  # arcseconds/pixel
        # Exposure time in seconds, shared by the VisitInfo below and the
        # conversion of the maximum trail rate into a maximum trail length.
        self.exposureTime = 30.0
        self.nSources = (self.nSkySources + self.nCrCenterSources + self.nFakeFlagSources
                         + self.nNegativeSources + self.nTrailedSources)
        self.yLoc = 100
        self.expId = 4321
        self.bbox = geom.Box2I(geom.Point2I(0, 0),
                               geom.Extent2I(1024, 1153))
        dataset = measTests.TestDataset(self.bbox)
        for srcIdx in range(self.nSources):
            dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc))
        schema = dataset.makeMinimalSchema()
        schema.addField("sky_source", type="Flag", doc="Sky objects.")
        schema.addField("slot_Centroid_flag", type="Flag", doc="The centroid calculation "
                        "failed. The source position is incorrect.")
        schema.addField("base_PixelFlags_flag_crCenter", type="Flag", doc="A cosmic ray was detected "
                        "and interpolated in this object's center.")
        schema.addField("base_PixelFlags_flag_high_varianceCenterAll", type="Flag", doc="The object was "
                        "detected in a region with exceptionally high template variance.")
        schema.addField("fakeBadFlag", type="Flag", doc="A fake flag to test a badFlagList longer "
                        "than one item.")
        schema.addField("ip_diffim_forced_PsfFlux_instFlux", type="F",
                        doc="Forced photometry flux for a point source model measured on the visit image "
                        "centered at DiaSource position.")
        schema.addField("ip_diffim_forced_PsfFlux_instFluxErr", type="F",
                        doc="Estimated uncertainty of ip_diffim_forced_PsfFlux_instFlux.")
        schema.addField('ext_trailedSources_Naive_flag_off_image', type="Flag",
                        doc="Trail extends off image")
        schema.addField('ext_trailedSources_Naive_flag_suspect_long_trail',
                        type="Flag", doc="Trail length is greater than three times the psf radius")
        schema.addField('ext_trailedSources_Naive_flag_edge', type="Flag",
                        doc="Trail contains edge pixels")
        schema.addField('ext_trailedSources_Naive_flag_nan', type="Flag",
                        doc="One or more trail coordinates are missing")
        schema.addField('ext_trailedSources_Naive_length', type="F",
                        doc="Length of the source trail")
        schema.addField("reliability", type="F",
                        doc="Reliability of the source")
        _, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234)

        # Group 1: flag the first nSkySources sources as sky objects.
        self.diaSourceCat[0:self.nSkySources]["sky_source"] = True

        # Group 2: set the base_PixelFlags_flag_crCenter flag.
        crCenter_offset = self.nSkySources
        self.diaSourceCat[crCenter_offset:crCenter_offset + self.nCrCenterSources][
            "base_PixelFlags_flag_crCenter"
        ] = True

        # Group 3: set the fakeBadFlag flag.
        fakeFlag_offset = crCenter_offset + self.nCrCenterSources
        self.diaSourceCat[fakeFlag_offset:fakeFlag_offset + self.nFakeFlagSources][
            "fakeBadFlag"
        ] = True

        # Group 4: increasingly negative direct-image SNR
        # (ip_diffim_forced_PsfFlux_instFlux / ip_diffim_forced_PsfFlux_instFluxErr).
        # Count how many fall below the configured minimum SNR.
        self.nRemovedNegativeSources = 0
        negativeSources_offset = fakeFlag_offset + self.nFakeFlagSources
        for i, srcIdx in enumerate(range(negativeSources_offset,
                                         negativeSources_offset + self.nNegativeSources)):
            flux = -0.5 * i
            fluxErr = 1.01
            self.diaSourceCat[srcIdx]["ip_diffim_forced_PsfFlux_instFlux"] = flux
            self.diaSourceCat[srcIdx]["ip_diffim_forced_PsfFlux_instFluxErr"] = fluxErr
            if flux/fluxErr < self.config.minAllowedDirectSnr:
                self.nRemovedNegativeSources += 1

        # Group 5: the last nTrailedSources sources all contain trail length
        # measurements (stored in pixels), increasing in steps of 1.5
        # arcseconds. The maximum allowed trail rate used here is 36000
        # arcsec/day (10 deg/day) expressed in arcsec/s (presumably mirroring
        # the task default); over the exposure time this gives 12.5 arcsec,
        # so only the last two trails are long enough to be filtered out.
        maxTrailLength = 36000/3600.0/24.0 * self.exposureTime  # arcseconds
        self.nFilteredTrailedSources = 0
        trail_offset = negativeSources_offset + self.nNegativeSources
        for i, srcIdx in enumerate(range(trail_offset, trail_offset + self.nTrailedSources)):
            trailLength = 1.5*(i + 1)  # arcseconds
            self.diaSourceCat[srcIdx]["ext_trailedSources_Naive_length"] = trailLength/self.pixelScale
            if trailLength > maxTrailLength:
                self.nFilteredTrailedSources += 1
        # Flag two of the short (not length-filtered) trails so they are
        # removed by the flag-based trail filtering instead: one extends off
        # the image, the other is both a suspect long trail and has edge
        # pixels. Each is a distinct source, so two more are filtered.
        self.diaSourceCat[trail_offset + 1]["ext_trailedSources_Naive_flag_off_image"] = True
        self.diaSourceCat[trail_offset + 2]["ext_trailedSources_Naive_flag_suspect_long_trail"] = True
        self.diaSourceCat[trail_offset + 2]["ext_trailedSources_Naive_flag_edge"] = True
        self.nFilteredTrailedSources += 2

        mjd = 57071.0
        # Julian date for ``mjd`` shifted by 35 s; not used by the tests
        # in this file.
        self.utc_jd = mjd + 2_400_000.5 - 35.0 / (24.0 * 60.0 * 60.0)

        self.visitInfo = afwImage.VisitInfo(
            # This incomplete visitInfo is sufficient for testing because the
            # Python constructor sets all other required values to some
            # default.
            exposureTime=self.exposureTime,
            darkTime=3.0,
            date=dafBase.DateTime(mjd, system=dafBase.DateTime.MJD),
            boresightRaDec=geom.SpherePoint(0.0, 0.0, geom.degrees),
        )

    def test_run_without_filter(self):
        """Test that when all filters are turned off all sources in the catalog
        are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat))
        self.assertEqual(len(result.rejectedDiaSources), 0)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_sky_only(self):
        """Test that when only the sky filter is turned on the first five
        sources which are flagged as sky objects are filtered out of the
        catalog and the rest are returned.
        """
        self.config.doRemoveSkySources = True
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nSkySources
        self.assertEqual(len(result.filteredDiaSourceCat),
                         len(self.diaSourceCat[~self.diaSourceCat['sky_source']]))
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nSkySources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_defaultBadFlagList_only(self):
        """Test that with the default badFlagList (and all other filters off)
        the six sources flagged base_PixelFlags_flag_crCenter are filtered out
        of the catalog and the rest are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nCrCenterSources
        self.assertEqual(len(result.filteredDiaSourceCat),
                         len(self.diaSourceCat[~self.diaSourceCat['base_PixelFlags_flag_crCenter']]))
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nCrCenterSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_nonDefaultBadFlagList_only(self):
        """Test that the badFlagList filters appropriately when it is not the
        default configuration. The six sources flagged
        base_PixelFlags_flag_crCenter and the four sources flagged fakeBadFlag
        should be filtered out of the catalog and the rest are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = ["base_PixelFlags_flag_crCenter", "fakeBadFlag"]
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nCrCenterSources - self.nFakeFlagSources
        self.assertEqual(len(result.filteredDiaSourceCat),
                         len(self.diaSourceCat[~self.diaSourceCat['base_PixelFlags_flag_crCenter']
                                               & ~self.diaSourceCat['fakeBadFlag']]))
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nCrCenterSources + self.nFakeFlagSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_negative_only(self):
        """Test that when only the negative filter is turned on, sources
        below the negative SNR cut are filtered out of the catalog and the
        rest are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = True
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nRemovedNegativeSources
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nRemovedNegativeSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)
        self.assertEqual(np.sum(result.filteredDiaSourceCat['ip_diffim_forced_PsfFlux_instFlux']
                                / result.filteredDiaSourceCat['ip_diffim_forced_PsfFlux_instFluxErr']
                                < self.config.minAllowedDirectSnr), 0)
        self.assertEqual(np.sum(result.rejectedDiaSources['ip_diffim_forced_PsfFlux_instFlux']
                                / result.rejectedDiaSources['ip_diffim_forced_PsfFlux_instFluxErr']
                                < self.config.minAllowedDirectSnr),
                         self.nRemovedNegativeSources)

    def test_run_with_filter_negative_and_sky(self):
        """Test concatenating rejects when both sky and negative filtering
        are on.
        """
        self.config.doRemoveSkySources = True
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = True
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nSkySources - self.nRemovedNegativeSources
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nSkySources + self.nRemovedNegativeSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_reliability_only(self):
        """Test that when only the reliability filter is turned on,
        sources below the reliability threshold are filtered out.
        """
        reliability_threshold = 0.7
        config = FilterDiaSourceReliabilityConfig()
        config.minReliability = reliability_threshold

        schema = afwTable.SourceTable.makeMinimalSchema()
        schema.addField("score", type="F",
                        doc="Reliability of the source")
        reliabilityCat = afwTable.SourceCatalog(schema)
        reliabilityCat.reserve(self.nSources)

        # Set reliability: first half below threshold, second half above.
        for srcIdx in range(self.nSources):
            record = reliabilityCat.addNew()
            record["score"] = 0.25 if srcIdx < self.nSources // 2 else 0.95
        reliabilityCat['id'] = self.diaSourceCat['id']
        # Derive the expectation from the configured threshold (not a
        # separate hard-coded cut) so it stays correct if the threshold or
        # the scores above change.
        nLowReliability = np.sum(reliabilityCat["score"] < reliability_threshold)
        # Guard against a vacuous test where nothing would be rejected.
        self.assertGreater(nLowReliability, 0)
        filterReliabilityTask = FilterDiaSourceReliabilityTask(config=config)
        result = filterReliabilityTask.run(self.diaSourceCat, reliabilityCat)
        self.assertEqual(len(result.filteredDiaSources), self.nSources - nLowReliability)
        self.assertEqual(len(result.rejectedDiaSources), nLowReliability)
        self.assertTrue(np.all(result.filteredDiaSources["reliability"] >= reliability_threshold))
        self.assertTrue(np.all(result.rejectedDiaSources["reliability"] < reliability_threshold))

    def test_run_with_filter_trailed_sources_only(self):
        """Test that when only the trail filter is turned on the correct number
        of sources are filtered out. The filtered sources should be the last
        two sources, which have long trails, one source where both the suspect
        trail and edge trail flags are set, and one source where off_image is
        set. All sky objects should remain in the catalog.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = True
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nFilteredTrailedSources
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_all_filters(self):
        """Test that all sources are filtered out correctly. Only 15 sources
        should remain in the catalog after filtering.
        """
        # badFlagList is left at its default, which removes the
        # base_PixelFlags_flag_crCenter sources (but not fakeBadFlag ones).
        self.config.doRemoveSkySources = True
        self.config.doRemoveNegativeDirectImageSources = True
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = True
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = (self.nSources - self.nSkySources
                                    - self.nCrCenterSources
                                    - self.nFilteredTrailedSources
                                    - self.nRemovedNegativeSources)
        nExpectedRejectedSources = (self.nSkySources
                                    + self.nCrCenterSources
                                    + self.nRemovedNegativeSources)
        # 32 total sources
        # 5 filtered out sky sources
        # 6 filtered out sources with cosmic ray detections
        # 2 filtered out negative sources
        # 4 filtered out trailed sources, 2 with long trails 2 with flags
        # 15 sources left
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        # 17 sources removed; the 4 trailed sources are not included in the
        # rejected catalog, so 13 rejected sources are in it.
        self.assertEqual(len(result.rejectedDiaSources), nExpectedRejectedSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_pixelScale_calculation(self):
        """Check the calculation of the pixel scale from the input catalog.
        """
        self.config.doTrailedSourceFilter = True
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        scale = filterDiaSourceCatalogTask._estimate_pixel_scale(self.diaSourceCat)
        # Should be almost but not actually equal
        self.assertNotEqual(self.config.estimatedPixelScale, scale)
        self.assertAlmostEqual(self.config.estimatedPixelScale, scale, places=6)

        # If the estimatedPixelScale is very different, that value should be
        # used exactly and it should not raise an error.
        self.config.estimatedPixelScale = 1.2
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        scale = filterDiaSourceCatalogTask._estimate_pixel_scale(self.diaSourceCat)
        self.assertEqual(self.config.estimatedPixelScale, scale)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook this module into lsst.utils.tests.MemoryTestCase; adds no tests of its own."""
def setup_module(module):
    """pytest-style module setup hook: initialise lsst.utils test support.

    The ``module`` argument is required by the hook signature but unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When run directly (not via pytest), initialise lsst.utils test
    # support before handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()