Coverage for tests/test_sourceSelector.py: 14%

272 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-14 02:24 -0700

1# 

2# LSST Data Management System 

3# 

4# Copyright 2008-2017 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23 

24import unittest 

25import numpy as np 

26import astropy.units as u 

27 

28import lsst.afw.table 

29import lsst.meas.algorithms 

30import lsst.meas.base.tests 

31import lsst.pipe.base 

32import lsst.utils.tests 

33 

34from lsst.meas.algorithms import ColorLimit 

35 

36 

class SourceSelectorTester:
    """Mixin holding the tests shared by the source-selector test cases.

    Concrete subclasses combine this mixin with a ``TestCase`` and set
    ``Task`` to the selector task class under test (in this file,
    ``ScienceSourceSelectorTask`` or ``ReferenceSourceSelectorTask``).
    """
    Task = None

    def setUp(self):
        """Build an empty catalog with the test columns and a default config.

        Room for 10 records is reserved; each test appends the rows it needs.
        """
        # (name, type, doc) for every column added on top of the minimal schema.
        extraFields = (
            ("flux", float, "Flux value"),
            ("flux_flag", "Flag", "Bad flux?"),
            ("other_flux", float, "Flux value 2"),
            ("other_flux_flag", "Flag", "Bad flux 2?"),
            ("other_fluxErr", float, "Flux error 2"),
            ("goodFlag", "Flag", "Flagged if good"),
            ("badFlag", "Flag", "Flagged if bad"),
            ("starGalaxy", float, "0=star, 1=galaxy"),
            ("nChild", np.int32, "Number of children"),
            ("detect_isPrimary", "Flag", "Is primary detection?"),
        )
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()
        for fieldName, fieldType, fieldDoc in extraFields:
            schema.addField(fieldName, fieldType, fieldDoc)
        catalog = lsst.afw.table.SourceCatalog(schema)
        catalog.reserve(10)
        self.catalog = catalog
        self.config = self.Task.ConfigClass()

    def tearDown(self):
        """Drop the per-test catalog."""
        del self.catalog

    def check(self, expected):
        """Run the task on three forms of the catalog and verify the selection.

        The task runs on the ``lsst.afw.table.SourceCatalog`` itself, then on a
        ``pandas.DataFrame`` copy, then on an ``astropy.table.Table`` copy. For
        each run, ``selected`` must equal ``expected``. The returned
        ``sourceCat`` must hold exactly the ids of the rows marked ``True`` in
        ``expected``, in catalog order.

        Parameters
        ----------
        expected : `list` [`bool`]
            Expected selection flag for each row of ``self.catalog``.
        """
        task = self.Task(config=self.config)

        def keptIds():
            return [record.getId() for record, keep in zip(self.catalog, expected) if keep]

        # afw catalog: records are read through getId() rather than an 'id' column.
        outcome = task.run(self.catalog)
        self.assertListEqual(outcome.selected.tolist(), expected)
        self.assertListEqual([record.getId() for record in outcome.sourceCat], keptIds())

        # pandas.DataFrame, then astropy.table.Table; both are indexed by column name.
        # The builders are lazy so each table is made just before its own run.
        tableBuilders = (
            lambda: self.catalog.asAstropy().to_pandas(),
            lambda: self.catalog.asAstropy(),
        )
        for buildTable in tableBuilders:
            outcome = task.run(buildTable())
            self.assertListEqual(outcome.selected.tolist(), expected)
            self.assertListEqual(list(outcome.sourceCat['id']), keptIds())

    def testFlags(self):
        """Only the row with the good flag set and the bad flag clear survives."""
        # (goodFlag, badFlag) for each row, in insertion order.
        for goodValue, badValue in ((False, False), (True, True), (False, True), (True, False)):
            row = self.catalog.addNew()
            row.set("goodFlag", goodValue)
            row.set("badFlag", badValue)
        self.catalog["flux"] = 1.0
        self.catalog["other_flux"] = 1.0
        flagsConfig = self.config.flags
        flagsConfig.good = ["goodFlag"]
        flagsConfig.bad = ["badFlag"]
        self.check([False, False, False, True])

    def testSignalToNoise(self):
        """Only the row whose S/N falls between the configured limits survives."""
        # Flux is 1 for every row, so S/N = 1/err: 1 (too low), 10 (kept), 1000 (too high).
        for fluxErr in (1.0, 0.1, 0.001):
            row = self.catalog.addNew()
            row.set("other_flux", 1.0)
            row.set("other_fluxErr", fluxErr)
        self.config.doSignalToNoise = True
        snConfig = self.config.signalToNoise
        snConfig.fluxField = "other_flux"
        snConfig.errField = "other_fluxErr"
        snConfig.minimum = 5.0
        snConfig.maximum = 100.0
        self.check([False, True, False])

118 

119 

class ScienceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests specific to `lsst.meas.algorithms.ScienceSourceSelectorTask`."""
    Task = lsst.meas.algorithms.ScienceSourceSelectorTask

    def setUp(self):
        """Enable the flux-limit and flag selections and disable the others."""
        super().setUp()
        config = self.config
        config.fluxLimit.fluxField = "flux"
        config.flags.bad = []
        config.doFluxLimit = True
        config.doFlags = True
        config.doUnresolved = False
        config.doIsolated = False

    def _addRows(self, count):
        """Append ``count`` default-initialized records to the catalog."""
        for _ in range(count):
            self.catalog.addNew()

    def testFluxLimit(self):
        """Flux limits drop too-bright and too-faint rows.

        A row with ``flux_flag`` set is never selected, whatever the limits.
        """
        # (flux, flux_flag): too bright, good, flagged (same flux as good), too faint.
        for fluxValue, fluxFlag in ((1.0e10, False), (1000.0, False), (1000.0, True), (1.0, False)):
            row = self.catalog.addNew()
            row.set("flux", fluxValue)
            row.set("flux_flag", fluxFlag)
        limit = self.config.fluxLimit
        limit.minimum = 10.0
        limit.maximum = 1.0e6
        limit.fluxField = "flux"
        self.check([False, True, False, False])

        # Removing the upper bound lets the too-bright row through.
        limit.maximum = None
        self.check([True, True, False, False])

        # Removing the lower bound as well lets the too-faint row through.
        limit.minimum = None
        self.check([True, True, False, True])

    def testUnresolved(self):
        """Select rows whose star/galaxy value falls between the limits."""
        count = 5
        self._addRows(count)
        self.catalog["flux"] = 1.0
        classification = np.linspace(0.0, 1.0, count, endpoint=False)
        self.catalog["starGalaxy"] = classification
        self.config.doUnresolved = True
        unresolved = self.config.unresolved
        unresolved.name = "starGalaxy"
        low, high = 0.3, 0.7
        unresolved.minimum = low
        unresolved.maximum = high
        self.check(((classification > low) & (classification < high)).tolist())

        # Lower bound removed: only the upper bound applies.
        unresolved.minimum = None
        self.check((classification < high).tolist())

        # Lower bound restored and upper bound removed.
        unresolved.minimum = low
        unresolved.maximum = None
        self.check((classification > low).tolist())

    def testIsolated(self):
        """Keep only rows that have neither a parent nor children."""
        count = 5
        self._addRows(count)
        self.catalog["flux"] = 1.0
        parentIds = np.array([0, 0, 10, 0, 0], dtype=int)
        childCounts = np.array([2, 0, 0, 0, 0], dtype=int)
        self.catalog["parent"] = parentIds
        self.catalog["nChild"] = childCounts
        self.config.doIsolated = True
        isolated = self.config.isolated
        isolated.parentName = "parent"
        isolated.nChildName = "nChild"
        self.check(((parentIds == 0) & (childCounts == 0)).tolist())

    def testRequireFiniteRaDec(self):
        """Keep only rows where both RA and Dec are finite."""
        count = 5
        self._addRows(count)
        raValues = np.array([np.nan, np.nan, 0, 0, 0], dtype=float)
        decValues = np.array([2, np.nan, 0, 0, np.nan], dtype=float)
        self.catalog["coord_ra"] = raValues
        self.catalog["coord_dec"] = decValues
        self.config.doRequireFiniteRaDec = True
        finite = self.config.requireFiniteRaDec
        finite.raColName = "coord_ra"
        finite.decColName = "coord_dec"
        self.check((np.isfinite(raValues) & np.isfinite(decValues)).tolist())

    def testRequirePrimary(self):
        """Keep only rows flagged as primary detections."""
        count = 5
        self._addRows(count)
        isPrimary = np.array([True, True, False, True, False], dtype=bool)
        self.catalog["detect_isPrimary"] = isPrimary
        self.config.doRequirePrimary = True
        self.config.requirePrimary.primaryColName = "detect_isPrimary"
        self.check(isPrimary.tolist())

217 

218 

class ReferenceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests specific to `lsst.meas.algorithms.ReferenceSourceSelectorTask`."""
    Task = lsst.meas.algorithms.ReferenceSourceSelectorTask

    def setUp(self):
        """Enable the magnitude-limit and flag selections and disable unresolved."""
        SourceSelectorTester.setUp(self)
        self.config.magLimit.fluxField = "flux"
        self.config.doMagLimit = True
        self.config.doFlags = True
        self.config.doUnresolved = False

    def testMagnitudeLimit(self):
        """Magnitude limits drop too-bright and too-faint rows.

        A row with ``flux_flag`` set is never selected, whatever the limits.
        """
        tooBright = self.catalog.addNew()
        tooBright.set("flux", 1.0e10)
        tooBright.set("flux_flag", False)
        good = self.catalog.addNew()
        good.set("flux", 1000.0)
        good.set("flux_flag", False)
        bad = self.catalog.addNew()
        bad.set("flux", good.get("flux"))
        bad.set("flux_flag", True)
        tooFaint = self.catalog.addNew()
        tooFaint.set("flux", 1.0)
        tooFaint.set("flux_flag", False)
        # Magnitudes run opposite to flux: the brightest allowed flux (1e6 nJy)
        # sets the minimum magnitude and the faintest (10 nJy) sets the maximum.
        self.config.magLimit.minimum = (1.0e6*u.nJy).to_value(u.ABmag)
        self.config.magLimit.maximum = (10.0*u.nJy).to_value(u.ABmag)
        self.config.magLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Without a minimum magnitude the too-bright row is accepted.
        self.config.magLimit.minimum = None
        self.check([True, True, False, False])

        # Without either limit the too-faint row is accepted as well.
        self.config.magLimit.maximum = None
        self.check([True, True, False, True])

    def testMagErrorLimit(self):
        """Keep only rows whose magnitude error lies between the limits."""
        # Using an arbitrary field as if it was a magnitude error to save adding a new field
        field = "other_fluxErr"
        tooFaint = self.catalog.addNew()
        tooFaint.set(field, 0.5)
        tooBright = self.catalog.addNew()
        tooBright.set(field, 0.00001)
        good = self.catalog.addNew()
        good.set(field, 0.2)

        self.config.doMagError = True
        self.config.magError.minimum = 0.01
        self.config.magError.maximum = 0.3
        self.config.magError.magErrField = field
        self.check([False, False, True])

    def testColorLimits(self):
        """Color cuts, alone and combined, select the expected rows."""
        num = 10
        for _ in range(num):
            self.catalog.addNew()
        color = np.linspace(-0.5, 0.5, num, True)
        flux = 1000.0*u.nJy
        # Definition: color = mag(flux) - mag(otherFlux)
        otherFlux = (flux.to(u.ABmag) - color*u.mag).to_value(u.nJy)
        self.catalog["flux"] = flux.value
        self.catalog["other_flux"] = otherFlux
        minimum, maximum = -0.1, 0.2
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum, maximum=maximum)}
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Works with no minimum set?
        self.config.colorLimits["test"].minimum = None
        self.check((color < maximum).tolist())

        # Works with no maximum set?
        self.config.colorLimits["test"].maximum = None
        self.config.colorLimits["test"].minimum = minimum
        self.check((color > minimum).tolist())

        # Multiple limits
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum),
                                   "other": ColorLimit(primary="flux", secondary="other_flux",
                                                       maximum=maximum)}
        # The two one-sided limits must overlap for the combined cut to be non-empty.
        # Use a unittest assertion: a bare ``assert`` is stripped under ``python -O``.
        self.assertGreater(maximum, minimum, "Color limits must not be mutually exclusive here")
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Multiple mutually-exclusive limits
        self.config.colorLimits["test"] = ColorLimit(primary="flux", secondary="other_flux", maximum=-0.1)
        self.config.colorLimits["other"] = ColorLimit(primary="flux", secondary="other_flux", minimum=0.1)
        self.check([False]*num)

    def testUnresolved(self):
        """Select rows whose star/galaxy value falls between the limits."""
        num = 5
        for _ in range(num):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, num, False)
        self.catalog["starGalaxy"] = starGalaxy
        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        minimum, maximum = 0.3, 0.7
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = maximum
        self.check(((starGalaxy > minimum) & (starGalaxy < maximum)).tolist())

        # Works with no minimum set?
        self.config.unresolved.minimum = None
        self.check((starGalaxy < maximum).tolist())

        # Works with no maximum set?
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = None
        self.check((starGalaxy > minimum).tolist())

331 

332 

class TrivialSourceSelector(lsst.meas.algorithms.BaseSourceSelectorTask):
    """Selector that accepts every source; it exists only for testing."""

    def selectSources(self, sourceCat, matches=None, exposure=None):
        """Mark every row of ``sourceCat`` as selected.

        ``matches`` and ``exposure`` are accepted for interface compatibility
        and ignored.
        """
        acceptAll = np.full(len(sourceCat), True, dtype=bool)
        return lsst.pipe.base.Struct(selected=acceptAll)

337 

338 

class TestBaseSourceSelector(lsst.utils.tests.TestCase):
    """Test the API of the Abstract Base Class with a trivial example."""

    def setUp(self):
        """Create a 4-row catalog with an extra flag column and a trivial selector."""
        schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()
        self.selectedKeyName = "is_selected"
        schema.addField(self.selectedKeyName, type="Flag")
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        # The loop index is not needed; only the number of rows matters.
        for _ in range(4):
            self.catalog.addNew()

        self.sourceSelector = TrivialSourceSelector()

    def testRun(self):
        """Test that run() returns a catalog and boolean selected array."""
        result = self.sourceSelector.run(self.catalog)
        for i, x in enumerate(self.catalog['id']):
            self.assertIn(x, result.sourceCat['id'])
            self.assertTrue(result.selected[i])

    def testRunSourceSelectedField(self):
        """Test that the selected flag is set in the original catalog."""
        self.sourceSelector.run(self.catalog, sourceSelectedField=self.selectedKeyName)
        np.testing.assert_array_equal(self.catalog[self.selectedKeyName], True)

    def testRunNonContiguousRaises(self):
        """Cannot do source selection on non-contiguous catalogs."""
        del self.catalog[1]  # take one out of the middle to make it non-contiguous.
        self.assertFalse(self.catalog.isContiguous(), "Catalog is contiguous: the test won't work.")

        with self.assertRaises(RuntimeError):
            self.sourceSelector.run(self.catalog)

370 

371 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Collect the checks defined by `lsst.utils.tests.MemoryTestCase`; adds none of its own."""

374 

375 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The module being set up; accepted for the hook signature, unused here.
    """
    lsst.utils.tests.init()

378 

379 

if __name__ == "__main__":
    import sys

    # Run the module-level setup by hand before handing control to unittest.
    thisModule = sys.modules[__name__]
    setup_module(thisModule)
    unittest.main()