Coverage for tests / test_sourceSelector.py: 13%

334 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 08:25 +0000

1# 

2# LSST Data Management System 

3# 

4# Copyright 2008-2017 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23 

24import unittest 

25import numpy as np 

26import astropy.units as u 

27import warnings 

28 

29import lsst.afw.image 

30import lsst.afw.table 

31import lsst.geom 

32import lsst.meas.algorithms 

33import lsst.meas.base.tests 

34import lsst.pipe.base 

35import lsst.utils.tests 

36 

37from lsst.meas.algorithms import ColorLimit 

38 

39 

class SourceSelectorTester:
    """Mixin with tests shared by the source selector task test cases.

    This provides a base class for tests common to the
    ScienceSourceSelectorTask and ReferenceSourceSelectorTask. Concrete
    subclasses must also derive from a ``unittest.TestCase`` (the helpers
    here call its assertion methods) and set ``Task`` to the selector task
    class under test.
    """

    # Selector task class under test; concrete subclasses override this.
    Task = None

    def setUp(self):
        """Build an empty catalog with the fields the tests use and a default config."""
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()
        schema.addField("flux", float, "Flux value")
        schema.addField("flux_flag", "Flag", "Bad flux?")
        schema.addField("other_flux", float, "Flux value 2")
        schema.addField("other_flux_flag", "Flag", "Bad flux 2?")
        schema.addField("other_fluxErr", float, "Flux error 2")
        schema.addField("goodFlag", "Flag", "Flagged if good")
        schema.addField("badFlag", "Flag", "Flagged if bad")
        schema.addField("starGalaxy", float, "0=star, 1=galaxy")
        schema.addField("nChild", np.int32, "Number of children")
        schema.addField("detect_isPrimary", "Flag", "Is primary detection?")
        schema.addField("sky_source", "Flag", "Empty sky region.")

        self.xCol = "centroid_x"
        self.yCol = "centroid_y"
        schema.addField(self.xCol, float, "Centroid x value.")
        schema.addField(self.yCol, float, "Centroid y value.")

        self.catalog = lsst.afw.table.SourceCatalog(schema)
        self.catalog.reserve(10)
        self.config = self.Task.ConfigClass()
        # Passed to ``Task.run`` by ``check``; tests that need an exposure
        # replace it before calling ``check``.
        self.exposure = None

    def tearDown(self):
        del self.catalog

    def check(self, expected):
        """Run the task and verify which records it selects.

        The task is run on ``self.catalog`` in three forms -- the
        `lsst.afw.table.SourceCatalog` itself, a `pandas.DataFrame` and an
        `astropy.table.Table` -- and every run must match ``expected``.

        Parameters
        ----------
        expected : `list` [`bool`]
            Expected selection flag for each record of ``self.catalog``,
            in catalog order.
        """
        task = self.Task(config=self.config)
        expectedIds = [src.getId() for src, ok in zip(self.catalog, expected) if ok]

        results = task.run(self.catalog, exposure=self.exposure)
        self._assertSelection(results, [src.getId() for src in results.sourceCat],
                              expected, expectedIds, "afw SourceCatalog")

        # Check with pandas.DataFrame version of catalog
        results = task.run(self.catalog.asAstropy().to_pandas(), exposure=self.exposure)
        self._assertSelection(results, list(results.sourceCat["id"]),
                              expected, expectedIds, "pandas.DataFrame")

        # Check with astropy.table.Table version of catalog
        results = task.run(self.catalog.asAstropy(), exposure=self.exposure)
        self._assertSelection(results, list(results.sourceCat["id"]),
                              expected, expectedIds, "astropy.table.Table")

    def _assertSelection(self, results, selectedIds, expected, expectedIds, label):
        """Compare one run's selection mask and output ids with the expectation.

        ``label`` names the input catalog form, so a failure reports which
        of the three inputs produced the wrong selection.
        """
        self.assertListEqual(results.selected.tolist(), expected,
                             msg=f"selection mask mismatch for {label} input")
        self.assertListEqual(selectedIds, expectedIds,
                             msg=f"selected ids mismatch for {label} input")

    def testFlags(self):
        """Select only records with every good flag set and no bad flag set."""
        bad1 = self.catalog.addNew()
        bad1.set("goodFlag", False)
        bad1.set("badFlag", False)
        bad2 = self.catalog.addNew()
        bad2.set("goodFlag", True)
        bad2.set("badFlag", True)
        bad3 = self.catalog.addNew()
        bad3.set("goodFlag", False)
        bad3.set("badFlag", True)
        good = self.catalog.addNew()
        good.set("goodFlag", True)
        good.set("badFlag", False)
        self.catalog["flux"] = 1.0
        self.catalog["other_flux"] = 1.0
        self.config.flags.good = ["goodFlag"]
        self.config.flags.bad = ["badFlag"]
        self.check([False, False, False, True])

    def testSignalToNoise(self):
        """Select only records whose S/N lies between the configured limits.

        The records have S/N of 1, 10 and 1000 against limits of 5 and 100.
        """
        low = self.catalog.addNew()
        low.set("other_flux", 1.0)
        low.set("other_fluxErr", 1.0)
        good = self.catalog.addNew()
        good.set("other_flux", 1.0)
        good.set("other_fluxErr", 0.1)
        high = self.catalog.addNew()
        high.set("other_flux", 1.0)
        high.set("other_fluxErr", 0.001)
        self.config.doSignalToNoise = True
        self.config.signalToNoise.fluxField = "other_flux"
        self.config.signalToNoise.errField = "other_fluxErr"
        self.config.signalToNoise.minimum = 5.0
        self.config.signalToNoise.maximum = 100.0
        self.check([False, True, False])

    def testSignalToNoiseNoWarn(self):
        """Reject a NaN flux/error record without emitting any warning."""
        low = self.catalog.addNew()
        low.set("other_flux", np.nan)
        low.set("other_fluxErr", np.nan)
        self.config.doSignalToNoise = True
        self.config.signalToNoise.fluxField = "other_flux"
        self.config.signalToNoise.errField = "other_fluxErr"
        # Turn every warning into an exception so any warning fails the test.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            self.check([False])

141 

142 

class ScienceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests for `lsst.meas.algorithms.ScienceSourceSelectorTask`.

    The shared tests come from `SourceSelectorTester`. ``setUp`` turns on
    the flux-limit and flag stages (with no bad flags) and turns off the
    unresolved and isolated stages. The tests below switch on the stage
    they exercise.
    """
    Task = lsst.meas.algorithms.ScienceSourceSelectorTask

    def setUp(self):
        super().setUp()
        config = self.config
        config.doFluxLimit = True
        config.fluxLimit.fluxField = "flux"
        config.doFlags = True
        config.flags.bad = []
        config.doUnresolved = False
        config.doIsolated = False

    def _appendRecords(self, count):
        """Append ``count`` new records to ``self.catalog``."""
        for _ in range(count):
            self.catalog.addNew()

    def testFluxLimit(self):
        """Select only unflagged records whose flux is inside the limits."""
        # (flux, flux_flag) per record: too bright, good, flagged, too faint.
        rows = [(1.0e10, False), (1000.0, False), (1000.0, True), (1.0, False)]
        for fluxValue, flagged in rows:
            record = self.catalog.addNew()
            record.set("flux", fluxValue)
            record.set("flux_flag", flagged)

        self.config.fluxLimit.minimum = 10.0
        self.config.fluxLimit.maximum = 1.0e6
        self.config.fluxLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Without an upper limit the very bright record is kept too.
        self.config.fluxLimit.maximum = None
        self.check([True, True, False, False])

        # Without either limit only the flagged record is rejected.
        self.config.fluxLimit.minimum = None
        self.check([True, True, False, True])

    def testUnresolved(self):
        """Select records by star/galaxy classification with optional bounds."""
        count = 5
        self._appendRecords(count)
        self.catalog["flux"] = 1.0
        classification = np.linspace(0.0, 1.0, count, endpoint=False)
        self.catalog["starGalaxy"] = classification

        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        low, high = 0.3, 0.7
        self.config.unresolved.minimum = low
        self.config.unresolved.maximum = high
        self.check(np.logical_and(classification > low, classification < high).tolist())

        # Upper bound only.
        self.config.unresolved.minimum = None
        self.check((classification < high).tolist())

        # Lower bound only.
        self.config.unresolved.minimum = low
        self.config.unresolved.maximum = None
        self.check((classification > low).tolist())

    def testIsolated(self):
        """Select only records with no parent and no children."""
        self._appendRecords(5)
        self.catalog["flux"] = 1.0
        parentIds = np.array([0, 0, 10, 0, 0], dtype=int)
        childCounts = np.array([2, 0, 0, 0, 0], dtype=int)
        self.catalog["parent"] = parentIds
        self.catalog["nChild"] = childCounts

        self.config.doIsolated = True
        self.config.isolated.parentName = "parent"
        self.config.isolated.nChildName = "nChild"
        self.check(np.logical_and(parentIds == 0, childCounts == 0).tolist())

    def testRequireFiniteRaDec(self):
        """Reject records whose RA or Dec is not finite."""
        self._appendRecords(5)
        raValues = np.array([np.nan, np.nan, 0, 0, 0], dtype=float)
        decValues = np.array([2, np.nan, 0, 0, np.nan], dtype=float)
        self.catalog["coord_ra"] = raValues
        self.catalog["coord_dec"] = decValues

        self.config.doRequireFiniteRaDec = True
        self.config.requireFiniteRaDec.raColName = "coord_ra"
        self.config.requireFiniteRaDec.decColName = "coord_dec"
        self.check(np.logical_and(np.isfinite(raValues), np.isfinite(decValues)).tolist())

    def testRequirePrimary(self):
        """Select exactly the records flagged as primary detections."""
        self._appendRecords(5)
        isPrimary = np.array([True, True, False, True, False], dtype=bool)
        self.catalog["detect_isPrimary"] = isPrimary

        self.config.doRequirePrimary = True
        self.config.requirePrimary.primaryColName = "detect_isPrimary"
        self.check(isPrimary.tolist())

    def testSkySource(self):
        """Select sky sources even when another stage would reject them."""
        self._appendRecords(5)
        isSky = np.array([True, True, False, True, False], dtype=bool)
        self.catalog["sky_source"] = isSky

        # Sky-source selection is OR-ed with the other stages rather than
        # AND-ed, so enable the primary requirement (which no record passes)
        # to show that the sky sources are still kept.
        self.config.doRequirePrimary = True
        self.config.doSkySources = True
        self.check(isSky.tolist())

252 

253 

class ReferenceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests for `lsst.meas.algorithms.ReferenceSourceSelectorTask`.

    The shared tests come from `SourceSelectorTester`. ``setUp`` turns on
    the magnitude-limit and flag stages and turns off the unresolved and
    finite-RA/Dec stages; individual tests switch stages on as needed.
    """
    Task = lsst.meas.algorithms.ReferenceSourceSelectorTask

    def setUp(self):
        SourceSelectorTester.setUp(self)
        self.config.magLimit.fluxField = "flux"
        self.config.doMagLimit = True
        self.config.doFlags = True
        self.config.doUnresolved = False
        self.config.doRequireFiniteRaDec = False

    def testMagnitudeLimit(self):
        """Select only unflagged records whose AB magnitude is inside the limits."""
        tooBright = self.catalog.addNew()
        tooBright.set("flux", 1.0e10)
        tooBright.set("flux_flag", False)
        good = self.catalog.addNew()
        good.set("flux", 1000.0)
        good.set("flux_flag", False)
        bad = self.catalog.addNew()
        bad.set("flux", good.get("flux"))
        bad.set("flux_flag", True)
        tooFaint = self.catalog.addNew()
        tooFaint.set("flux", 1.0)
        tooFaint.set("flux_flag", False)
        # Magnitudes run opposite to flux: the magnitude minimum comes from the
        # largest allowed flux, and the magnitude maximum from the smallest.
        self.config.magLimit.minimum = (1.0e6*u.nJy).to_value(u.ABmag)
        self.config.magLimit.maximum = (10.0*u.nJy).to_value(u.ABmag)
        self.config.magLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Works with no minimum set?
        self.config.magLimit.minimum = None
        self.check([True, True, False, False])

        # Works with no maximum set?
        self.config.magLimit.maximum = None
        self.check([True, True, False, True])

    def testMagErrorLimit(self):
        """Select only records whose magnitude error is inside the limits."""
        # Reuse an existing float field as the magnitude error rather than
        # adding a dedicated field to the schema.
        field = "other_fluxErr"
        errorTooLarge = self.catalog.addNew()
        errorTooLarge.set(field, 0.5)
        errorTooSmall = self.catalog.addNew()
        errorTooSmall.set(field, 0.00001)
        good = self.catalog.addNew()
        good.set(field, 0.2)

        self.config.doMagError = True
        self.config.magError.minimum = 0.01
        self.config.magError.maximum = 0.3
        self.config.magError.magErrField = field
        self.check([False, False, True])

    def testColorLimits(self):
        """Select records by color, with single, combined and disjoint limits."""
        num = 10
        for _ in range(num):
            self.catalog.addNew()
        color = np.linspace(-0.5, 0.5, num, True)
        flux = 1000.0*u.nJy
        # Definition: color = mag(flux) - mag(otherFlux)
        otherFlux = (flux.to(u.ABmag) - color*u.mag).to_value(u.nJy)
        self.catalog["flux"] = flux.value
        self.catalog["other_flux"] = otherFlux
        minimum, maximum = -0.1, 0.2
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum, maximum=maximum)}
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Works with no minimum set?
        self.config.colorLimits["test"].minimum = None
        self.check((color < maximum).tolist())

        # Works with no maximum set?
        self.config.colorLimits["test"].maximum = None
        self.config.colorLimits["test"].minimum = minimum
        self.check((color > minimum).tolist())

        # Multiple limits
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum),
                                   "other": ColorLimit(primary="flux", secondary="other_flux",
                                                       maximum=maximum)}
        # The two one-sided limits must overlap for their combination to act
        # as the range below. assertGreater, unlike a bare assert, is not
        # stripped when Python runs with -O.
        self.assertGreater(maximum, minimum, "Color limits must not be mutually exclusive here")
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Multiple mutually-exclusive limits
        self.config.colorLimits["test"] = ColorLimit(primary="flux", secondary="other_flux", maximum=-0.1)
        self.config.colorLimits["other"] = ColorLimit(primary="flux", secondary="other_flux", minimum=0.1)
        self.check([False]*num)

    def testUnresolved(self):
        """Select records by star/galaxy classification with optional bounds."""
        num = 5
        for _ in range(num):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, num, False)
        self.catalog["starGalaxy"] = starGalaxy
        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        minimum, maximum = 0.3, 0.7
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = maximum
        self.check(((starGalaxy > minimum) & (starGalaxy < maximum)).tolist())

        # Works with no minimum set?
        self.config.unresolved.minimum = None
        self.check((starGalaxy < maximum).tolist())

        # Works with no maximum set?
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = None
        self.check((starGalaxy > minimum).tolist())

    def testFiniteRaDec(self):
        """Test that non-finite RA and Dec values are caught."""
        num = 5
        for _ in range(num):
            self.catalog.addNew()
        self.catalog["coord_ra"][:] = 1.0
        self.catalog["coord_dec"][:] = 1.0
        self.catalog["coord_ra"][0] = np.nan
        self.catalog["coord_dec"][1] = np.inf
        self.config.doRequireFiniteRaDec = True

        self.check([False, False, True, True, True])

    def testCullFromMaskedRegion(self):
        """Cull records whose centroid lands on a pixel with a bad mask plane set.

        Repeated for two exposure origins, so the centroid-to-pixel lookup
        must account for the exposure's XY0.
        """
        maskNames = ["NO_DATA", "BLAH"]
        num = 5
        for _ in range(num):
            self.catalog.addNew()

        for x0, y0 in [(0, 0), (3, 8)]:
            self.exposure = lsst.afw.image.ExposureF(5, 5)
            self.exposure.setXY0(lsst.geom.Point2I(x0, y0))
            mask = self.exposure.mask
            for maskName in maskNames:
                if maskName not in mask.getMaskPlaneDict():
                    mask.addMaskPlane(maskName)
            # Place every centroid just outside the 5x5 image; only the
            # records moved onto masked pixels below should be culled.
            self.catalog[self.xCol][:] = x0 + 5.0
            self.catalog[self.yCol][:] = y0 + 5.0
            noDataPoints = [(0 + x0, 0 + y0), (3 + x0, 2 + y0)]
            # Set first two entries in catalog to land in maskNames region.
            for i, (x, y) in enumerate(noDataPoints):
                # numpy indexes [row, column], i.e. [y, x], relative to XY0.
                mask.array[y - y0, x - x0] = mask.getPlaneBitMask(
                    maskNames[min(i, len(maskNames) - 1)]
                )
                self.catalog[self.xCol][i] = x
                self.catalog[self.yCol][i] = y
            self.config.doCullFromMaskedRegion = True
            self.config.cullFromMaskedRegion.xColName = self.xCol
            self.config.cullFromMaskedRegion.yColName = self.yCol
            self.config.cullFromMaskedRegion.badMaskNames = maskNames
            self.check([False, False, True, True, True])
        # No explicit reset is needed: setUp builds a fresh config and sets
        # exposure back to None before every test.

416 

417 

class TestBaseSourceSelector(lsst.utils.tests.TestCase):
    """Test the API of the Abstract Base Class with a trivial example.

    ``NullSourceSelectorTask`` is the concrete selector; the tests below
    expect it to select every record.
    """
    def setUp(self):
        schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()
        # Flag field that run() is asked to set on selected records.
        self.selectedKeyName = "is_selected"
        schema.addField(self.selectedKeyName, type="Flag")
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        for _ in range(4):
            self.catalog.addNew()

        self.sourceSelector = lsst.meas.algorithms.NullSourceSelectorTask()

    def testRun(self):
        """Test that run() returns a catalog and boolean selected array."""
        result = self.sourceSelector.run(self.catalog)
        for i, x in enumerate(self.catalog['id']):
            self.assertIn(x, result.sourceCat['id'])
            self.assertTrue(result.selected[i])

    def testRunSourceSelectedField(self):
        """Test that the selected flag is set in the original catalog."""
        self.sourceSelector.run(self.catalog, sourceSelectedField=self.selectedKeyName)
        np.testing.assert_array_equal(self.catalog[self.selectedKeyName], True)

    def testRunNonContiguousRaises(self):
        """Cannot do source selection on non-contiguous catalogs."""
        del self.catalog[1]  # take one out of the middle to make it non-contiguous.
        self.assertFalse(self.catalog.isContiguous(), "Catalog is contiguous: the test won't work.")

        with self.assertRaises(RuntimeError):
            self.sourceSelector.run(self.catalog)

449 

450 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`; adds none of its own."""

453 

454 

def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up. It is not used here; the argument exists
        because the ``__main__`` guard passes this module in.
    """
    lsst.utils.tests.init()

457 

458 

if __name__ == "__main__":
    import sys

    # Run the module-level setup explicitly when executed as a script, then
    # hand control to unittest's command-line runner.
    currentModule = sys.modules[__name__]
    setup_module(currentModule)
    unittest.main()