Coverage for tests / test_filterDiaSourceCatalog.py: 12%

201 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:31 +0000

# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

21 

22import unittest 

23import numpy as np 

24 

25from lsst.ap.association.filterDiaSourceCatalog import (FilterDiaSourceCatalogConfig, 

26 FilterDiaSourceCatalogTask, 

27 FilterDiaSourceReliabilityConfig, 

28 FilterDiaSourceReliabilityTask) 

29import lsst.geom as geom 

30import lsst.meas.base.tests as measTests 

31import lsst.utils.tests 

32import lsst.afw.image as afwImage 

33import lsst.afw.table as afwTable 

34import lsst.daf.base as dafBase 

35 

36 

class TestFilterDiaSourceCatalogTask(unittest.TestCase):
    """Tests for FilterDiaSourceCatalogTask and FilterDiaSourceReliabilityTask.

    ``setUp`` builds a synthetic DiaSource catalog made of consecutive groups
    of sources, each group exercising one filter: sky sources, CR-center
    flagged sources, sources with a fake bad flag, sources with increasingly
    negative direct-image SNR, and sources with trail measurements.
    """

    def setUp(self):
        """Build the default config, the synthetic catalog, and a visitInfo.

        The expected number of sources removed by the negative-SNR and
        trailed-source filters is computed here alongside the catalog so the
        tests can compare against it.
        """
        self.config = FilterDiaSourceCatalogConfig()

        self.nSkySources = 5
        self.nCrCenterSources = 6
        self.nFakeFlagSources = 4
        self.nNegativeSources = 7
        self.nTrailedSources = 10
        self.pixelScale = 0.2  # arcseconds/pixel
        # Exposure time in seconds. It is used both for the trail-length cut
        # below and for the visitInfo handed to the task, so keep one value.
        self.exposureTime = 30.0
        self.nSources = (self.nSkySources + self.nCrCenterSources + self.nFakeFlagSources
                         + self.nNegativeSources + self.nTrailedSources)
        self.yLoc = 100
        self.expId = 4321
        self.bbox = geom.Box2I(geom.Point2I(0, 0),
                               geom.Extent2I(1024, 1153))
        dataset = measTests.TestDataset(self.bbox)
        for srcIdx in range(self.nSources):
            dataset.addSource(10000.0, geom.Point2D(srcIdx, self.yLoc))
        schema = dataset.makeMinimalSchema()
        schema.addField("sky_source", type="Flag", doc="Sky objects.")
        schema.addField("slot_Centroid_flag", type="Flag", doc="The centroid calculation "
                        "failed. The source position is incorrect.")
        schema.addField("base_PixelFlags_flag_crCenter", type="Flag", doc="A cosmic ray was detected "
                        "and interpolated in this object's center.")
        schema.addField("base_PixelFlags_flag_high_varianceCenterAll", type="Flag", doc="The object was "
                        "detected in a region with exceptionally high template variance.")
        schema.addField("fakeBadFlag", type="Flag", doc="A fake flag to test a badFlagList longer "
                        "than one item.")
        schema.addField("ip_diffim_forced_PsfFlux_instFlux", type="F",
                        doc="Forced photometry flux for a point source model measured on the visit image "
                        "centered at DiaSource position.")
        schema.addField("ip_diffim_forced_PsfFlux_instFluxErr", type="F",
                        doc="Estimated uncertainty of ip_diffim_forced_PsfFlux_instFlux.")
        schema.addField('ext_trailedSources_Naive_flag_off_image', type="Flag",
                        doc="Trail extends off image")
        schema.addField('ext_trailedSources_Naive_flag_suspect_long_trail',
                        type="Flag", doc="Trail length is greater than three times the psf radius")
        schema.addField('ext_trailedSources_Naive_flag_edge', type="Flag",
                        doc="Trail contains edge pixels")
        schema.addField('ext_trailedSources_Naive_flag_nan', type="Flag",
                        doc="One or more trail coordinates are missing")
        schema.addField('ext_trailedSources_Naive_length', type="F",
                        doc="Length of the source trail")
        schema.addField("reliability", type="F",
                        doc="Reliability of the source")
        _, self.diaSourceCat = dataset.realize(10.0, schema, randomSeed=1234)

        # Set the sky_source flag for the first group.
        self.diaSourceCat[0:self.nSkySources]["sky_source"] = True

        # Set the pixelFlags_crCenter flag.
        crCenter_offset = self.nSkySources
        self.diaSourceCat[crCenter_offset:crCenter_offset+self.nCrCenterSources][
            "base_PixelFlags_flag_crCenter"
        ] = True

        # Set the fakeBadFlag flag.
        fakeFlag_offset = crCenter_offset + self.nCrCenterSources
        self.diaSourceCat[fakeFlag_offset:fakeFlag_offset+self.nFakeFlagSources][
            "fakeBadFlag"
        ] = True

        # Create increasingly negative
        # ip_diffim_forced_PsfFlux_instFlux/ip_diffim_forced_PsfFlux_instFluxErr
        # and count how many fall below the configured minimum direct SNR.
        self.nRemovedNegativeSources = 0
        negativeSources_offset = fakeFlag_offset + self.nFakeFlagSources
        for i, srcIdx in enumerate(range(negativeSources_offset,
                                         negativeSources_offset+self.nNegativeSources)):
            self.diaSourceCat[srcIdx]["ip_diffim_forced_PsfFlux_instFlux"] = -0.5 * i
            self.diaSourceCat[srcIdx]["ip_diffim_forced_PsfFlux_instFluxErr"] = 1.01
            if (-0.5 * i)/1.01 < self.config.minAllowedDirectSnr:
                self.nRemovedNegativeSources += 1

        # The last 10 sources all contain trail length measurements, increasing
        # in size by 1.5 arcseconds (stored in pixels via self.pixelScale).
        # The maximum allowed length is 10 degrees/day (36000/3600/24
        # arcsec/second) times the exposure time, i.e. 12.5 arcseconds, so only
        # the last two (13.5 and 15 arcseconds) are too long and will be
        # filtered out.
        # NOTE(review): the rate presumably mirrors the task's default maximum
        # trail length config value -- confirm if that default changes.
        maxTrailLength = 36000/3600.0/24.0 * self.exposureTime
        self.nFilteredTrailedSources = 0
        trail_offset = negativeSources_offset + self.nNegativeSources
        for i, srcIdx in enumerate(range(trail_offset, trail_offset+self.nTrailedSources)):
            trailLength = 1.5*(i+1)  # arcseconds
            self.diaSourceCat[srcIdx]["ext_trailedSources_Naive_length"] = trailLength/self.pixelScale
            if trailLength > maxTrailLength:
                self.nFilteredTrailedSources += 1
        # Set a combination of flags on two short-trail sources for filtering
        # in tests: off_image on one, suspect_long_trail together with edge on
        # the other.
        self.diaSourceCat[trail_offset+1]["ext_trailedSources_Naive_flag_off_image"] = True
        self.diaSourceCat[trail_offset+2]["ext_trailedSources_Naive_flag_suspect_long_trail"] = True
        self.diaSourceCat[trail_offset+2]["ext_trailedSources_Naive_flag_edge"] = True
        # As these flags are set on only two sources, the total number of
        # filtered trailed sources will be self.nFilteredTrailedSources + 2.
        self.nFilteredTrailedSources += 2

        mjd = 57071.0
        self.utc_jd = mjd + 2_400_000.5 - 35.0 / (24.0 * 60.0 * 60.0)

        self.visitInfo = afwImage.VisitInfo(
            # This incomplete visitInfo is sufficient for testing because the
            # Python constructor sets all other required values to some
            # default.
            exposureTime=self.exposureTime,
            darkTime=3.0,
            date=dafBase.DateTime(mjd, system=dafBase.DateTime.MJD),
            boresightRaDec=geom.SpherePoint(0.0, 0.0, geom.degrees),
        )

    def test_run_without_filter(self):
        """Test that when all filters are turned off all sources in the catalog
        are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        self.assertEqual(len(result.filteredDiaSourceCat), len(self.diaSourceCat))
        self.assertEqual(len(result.rejectedDiaSources), 0)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_sky_only(self):
        """Test that when only the sky filter is turned on the first five
        sources which are flagged as sky objects are filtered out of the
        catalog and the rest are returned.
        """
        self.config.doRemoveSkySources = True
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nSkySources
        self.assertEqual(len(result.filteredDiaSourceCat),
                         len(self.diaSourceCat[~self.diaSourceCat['sky_source']]))
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nSkySources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_defaultBadFlagList_only(self):
        """Test that with the default badFlagList (and every other filter off)
        the six sources flagged base_PixelFlags_flag_crCenter are filtered out
        of the catalog and the rest are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nCrCenterSources
        self.assertEqual(len(result.filteredDiaSourceCat),
                         len(self.diaSourceCat[~self.diaSourceCat['base_PixelFlags_flag_crCenter']]))
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nCrCenterSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_nonDefaultBadFlagList_only(self):
        """Test that the badFlagList filters appropriately when it is not the
        default configuration. The six sources flagged
        base_PixelFlags_flag_crCenter and the four sources flagged fakeBadFlag
        should be filtered out of the catalog and the rest are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = ["base_PixelFlags_flag_crCenter", "fakeBadFlag"]
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nCrCenterSources - self.nFakeFlagSources
        self.assertEqual(len(result.filteredDiaSourceCat),
                         len(self.diaSourceCat[~self.diaSourceCat['base_PixelFlags_flag_crCenter']
                                               & ~self.diaSourceCat['fakeBadFlag']]))
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nCrCenterSources + self.nFakeFlagSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_negative_only(self):
        """Test that when only the negative filter is turned on, sources which
        fall below the negative SNR cut are filtered out of the catalog and
        the rest are returned.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = True
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nRemovedNegativeSources
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nRemovedNegativeSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)
        self.assertEqual(np.sum(result.filteredDiaSourceCat['ip_diffim_forced_PsfFlux_instFlux']
                                / result.filteredDiaSourceCat['ip_diffim_forced_PsfFlux_instFluxErr']
                                < self.config.minAllowedDirectSnr), 0)
        self.assertEqual(np.sum(result.rejectedDiaSources['ip_diffim_forced_PsfFlux_instFlux']
                                / result.rejectedDiaSources['ip_diffim_forced_PsfFlux_instFluxErr']
                                < self.config.minAllowedDirectSnr),
                         self.nRemovedNegativeSources)

    def test_run_with_filter_negative_and_sky(self):
        """Test concatenating rejects when both sky and negative filtering
        are on.
        """
        self.config.doRemoveSkySources = True
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = True
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = False
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nSkySources - self.nRemovedNegativeSources
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(result.rejectedDiaSources), self.nSkySources + self.nRemovedNegativeSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_filter_reliability_only(self):
        """Test that when only the reliability filter is turned on,
        sources below the reliability threshold are filtered out.
        """
        reliability_threshold = 0.7
        config = FilterDiaSourceReliabilityConfig()
        config.minReliability = reliability_threshold

        schema = afwTable.SourceTable.makeMinimalSchema()
        schema.addField("score", type="F",
                        doc="Reliability of the source")
        reliabilityCat = afwTable.SourceCatalog(schema)
        reliabilityCat.reserve(self.nSources)

        # Set reliability: first half below threshold, second half above.
        for srcIdx in range(self.nSources):
            reliabilityCat.addNew()
            if srcIdx < self.nSources // 2:
                reliabilityCat[srcIdx]["score"] = 0.25
            else:
                reliabilityCat[srcIdx]["score"] = 0.95
        reliabilityCat['id'] = self.diaSourceCat['id']
        # Count expected rejects against the configured threshold rather than
        # an unrelated hard-coded cut, so changing the threshold keeps the
        # expectation consistent.
        nLowReliability = np.sum(reliabilityCat["score"] < reliability_threshold)
        filterReliabilityTask = FilterDiaSourceReliabilityTask(config=config)
        result = filterReliabilityTask.run(self.diaSourceCat, reliabilityCat)
        self.assertEqual(len(result.filteredDiaSources), self.nSources - nLowReliability)
        self.assertEqual(len(result.rejectedDiaSources), nLowReliability)
        self.assertTrue(np.all(result.filteredDiaSources["reliability"] >= reliability_threshold))
        self.assertTrue(np.all(result.rejectedDiaSources["reliability"] < reliability_threshold))

    def test_run_with_filter_trailed_sources_only(self):
        """Test that when only the trail filter is turned on the correct number
        of sources are filtered out. The filtered sources should be the last
        two sources, which have trails longer than the maximum allowed length,
        one source where both the suspect trail and edge trail flags are set,
        and one source where off_image is set. All sky objects should remain
        in the catalog.
        """
        self.config.doRemoveSkySources = False
        self.config.badFlagList = []
        self.config.doRemoveNegativeDirectImageSources = False
        self.config.doWriteRejectedSkySources = False
        self.config.doTrailedSourceFilter = True
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = self.nSources - self.nFilteredTrailedSources
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_run_with_all_filters(self):
        """Test that enabling every filter together removes the expected
        sources. badFlagList is left at its default, so the fakeBadFlag
        sources are kept; only 15 sources should remain in the catalog after
        filtering.
        """
        self.config.doRemoveSkySources = True
        self.config.doRemoveNegativeDirectImageSources = True
        self.config.doWriteRejectedSkySources = True
        self.config.doTrailedSourceFilter = True
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        result = filterDiaSourceCatalogTask.run(self.diaSourceCat, self.visitInfo)
        nExpectedFilteredSources = (self.nSources - self.nSkySources
                                    - self.nCrCenterSources
                                    - self.nFilteredTrailedSources
                                    - self.nRemovedNegativeSources)
        nExpectedRejectedSources = (self.nSkySources
                                    + self.nCrCenterSources
                                    + self.nRemovedNegativeSources)
        # 32 total sources
        # 5 filtered out sky sources
        # 6 filtered out sources with cosmic ray detections
        # 2 filtered out negative sources
        # 4 filtered out trailed sources, 2 with long trails 2 with flags
        # 15 sources left
        self.assertEqual(len(result.filteredDiaSourceCat), nExpectedFilteredSources)
        # 17 sources are removed in total; the 4 trailed sources are not
        # counted in rejectedDiaSources, leaving 13 rejected sources there.
        self.assertEqual(len(result.rejectedDiaSources), nExpectedRejectedSources)
        self.assertEqual(len(self.diaSourceCat), self.nSources)

    def test_pixelScale_calculation(self):
        """Check the calculation of the pixel scale from the input catalog.
        """
        self.config.doTrailedSourceFilter = True
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        scale = filterDiaSourceCatalogTask._estimate_pixel_scale(self.diaSourceCat)
        # Should be almost but not actually equal.
        self.assertNotEqual(self.config.estimatedPixelScale, scale)
        self.assertAlmostEqual(self.config.estimatedPixelScale, scale, places=6)

        # If the estimatedPixelScale is very different, that value should be
        # used exactly and it should not raise an error.
        self.config.estimatedPixelScale = 1.2
        filterDiaSourceCatalogTask = FilterDiaSourceCatalogTask(config=self.config)
        scale = filterDiaSourceCatalogTask._estimate_pixel_scale(self.diaSourceCat)
        self.assertEqual(self.config.estimatedPixelScale, scale)

344 

345 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Memory-leak check inherited unchanged from lsst.utils.tests.MemoryTestCase."""

348 

349 

def setup_module(module):
    """Module-level pytest hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    initTestUtils = lsst.utils.tests.init
    initTestUtils()

352 

353 

# Allow running this file directly as a script in addition to via pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()