Coverage for tests/test_sourceSelector.py: 15%

250 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-05 18:13 -0800

1# 

2# LSST Data Management System 

3# 

4# Copyright 2008-2017 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23 

24import unittest 

25import numpy as np 

26import astropy.units as u 

27 

28import lsst.afw.table 

29import lsst.meas.algorithms 

30import lsst.meas.base.tests 

31import lsst.pipe.base 

32import lsst.utils.tests 

33 

34from lsst.meas.algorithms import ColorLimit 

35 

36 

class SourceSelectorTester:
    """Mixin holding the tests shared by the source-selector test cases.

    A concrete test case combines this mixin with a ``TestCase`` base and sets
    ``Task`` to the selector task class under test (the subclasses in this
    module use ``ScienceSourceSelectorTask`` and
    ``ReferenceSourceSelectorTask``).
    """
    Task = None

    def setUp(self):
        # Minimal source schema plus every extra column the selector tests use.
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()
        extraFields = (
            ("flux", float, "Flux value"),
            ("flux_flag", "Flag", "Bad flux?"),
            ("other_flux", float, "Flux value 2"),
            ("other_flux_flag", "Flag", "Bad flux 2?"),
            ("other_fluxErr", float, "Flux error 2"),
            ("goodFlag", "Flag", "Flagged if good"),
            ("badFlag", "Flag", "Flagged if bad"),
            ("starGalaxy", float, "0=star, 1=galaxy"),
            ("nChild", np.int32, "Number of children"),
        )
        for fieldName, fieldType, fieldDoc in extraFields:
            schema.addField(fieldName, fieldType, fieldDoc)
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        self.catalog.reserve(10)
        self.config = self.Task.ConfigClass()

    def tearDown(self):
        del self.catalog

    def check(self, expected):
        """Run the task on three representations of ``self.catalog``.

        Each input is passed to ``Task.run`` in turn: the afw catalog, then a
        ``pandas.DataFrame`` copy, then an ``astropy.table.Table`` copy. For
        each one, the ``selected`` mask must equal ``expected``. The returned
        ``sourceCat`` must hold exactly the IDs of the selected rows, in order.

        Parameters
        ----------
        expected : `list` [`bool`]
            Expected selection flag for each row of ``self.catalog``.
        """
        selector = self.Task(config=self.config)
        wantedIds = [record.getId() for record, keep in zip(self.catalog, expected) if keep]

        # Native afw catalog input.
        afwResult = selector.run(self.catalog)
        self.assertListEqual(afwResult.selected.tolist(), expected)
        self.assertListEqual([record.getId() for record in afwResult.sourceCat], wantedIds)

        # Table-like inputs: pandas.DataFrame first, then astropy.table.Table.
        tableFactories = (
            lambda: self.catalog.asAstropy().to_pandas(),
            lambda: self.catalog.asAstropy(),
        )
        for makeTable in tableFactories:
            tableResult = selector.run(makeTable())
            self.assertListEqual(tableResult.selected.tolist(), expected)
            self.assertListEqual(list(tableResult.sourceCat['id']), wantedIds)

    def testFlags(self):
        # (goodFlag, badFlag) per row; only the last row is good and not bad.
        flagRows = [(False, False), (True, True), (False, True), (True, False)]
        for isGood, isBad in flagRows:
            record = self.catalog.addNew()
            record.set("goodFlag", isGood)
            record.set("badFlag", isBad)
        self.catalog["flux"] = 1.0
        self.catalog["other_flux"] = 1.0
        self.config.flags.good = ["goodFlag"]
        self.config.flags.bad = ["badFlag"]
        self.check([False, False, False, True])

    def testSignalToNoise(self):
        # flux / fluxErr gives S/N of 1 (too low), 10 (inside [5, 100]) and 1000 (too high).
        for fluxErr in (1.0, 0.1, 0.001):
            record = self.catalog.addNew()
            record.set("other_flux", 1.0)
            record.set("other_fluxErr", fluxErr)
        self.config.doSignalToNoise = True
        snConfig = self.config.signalToNoise
        snConfig.fluxField = "other_flux"
        snConfig.errField = "other_fluxErr"
        snConfig.minimum = 5.0
        snConfig.maximum = 100.0
        self.check([False, True, False])

117 

118 

class ScienceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests for `lsst.meas.algorithms.ScienceSourceSelectorTask`."""
    Task = lsst.meas.algorithms.ScienceSourceSelectorTask

    def setUp(self):
        SourceSelectorTester.setUp(self)
        cfg = self.config
        cfg.fluxLimit.fluxField = "flux"
        cfg.flags.bad = []
        cfg.doFluxLimit = True
        cfg.doFlags = True
        cfg.doUnresolved = False
        cfg.doIsolated = False

    def testFluxLimit(self):
        # (flux, flux_flag) rows: too bright, good, flagged, too faint.
        goodFlux = 1000.0
        rows = ((1.0e10, False), (goodFlux, False), (goodFlux, True), (1.0, False))
        for flux, isFlagged in rows:
            record = self.catalog.addNew()
            record.set("flux", flux)
            record.set("flux_flag", isFlagged)
        limits = self.config.fluxLimit
        limits.minimum = 10.0
        limits.maximum = 1.0e6
        limits.fluxField = "flux"
        self.check([False, True, False, False])

        # Dropping the upper limit lets the too-bright source through.
        limits.maximum = None
        self.check([True, True, False, False])

        # Dropping the lower limit as well lets the too-faint source through.
        limits.minimum = None
        self.check([True, True, False, True])

    def testUnresolved(self):
        nSources = 5
        for _ in range(nSources):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        # Evenly spaced values on [0, 1) excluding the endpoint: 0.0, 0.2, ..., 0.8.
        starGalaxy = np.linspace(0.0, 1.0, nSources, False)
        self.catalog["starGalaxy"] = starGalaxy
        lower, upper = 0.3, 0.7
        unresolved = self.config.unresolved
        self.config.doUnresolved = True
        unresolved.name = "starGalaxy"
        unresolved.minimum = lower
        unresolved.maximum = upper
        self.check(np.logical_and(starGalaxy > lower, starGalaxy < upper).tolist())

        # Only the upper bound applies once the minimum is cleared.
        unresolved.minimum = None
        self.check((starGalaxy < upper).tolist())

        # Only the lower bound applies once the maximum is cleared instead.
        unresolved.minimum = lower
        unresolved.maximum = None
        self.check((starGalaxy > lower).tolist())

    def testIsolated(self):
        nSources = 5
        for _ in range(nSources):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        # Row 0 reports children and row 2 has a nonzero parent ID; the rest are isolated.
        parentIds = np.array([0, 0, 10, 0, 0], dtype=int)
        childCounts = np.array([2, 0, 0, 0, 0], dtype=int)
        self.catalog["parent"] = parentIds
        self.catalog["nChild"] = childCounts
        isolated = self.config.isolated
        self.config.doIsolated = True
        isolated.parentName = "parent"
        isolated.nChildName = "nChild"
        expected = np.logical_and(parentIds == 0, childCounts == 0)
        self.check(expected.tolist())

193 

194 

class ReferenceSourceSelectorTaskTest(SourceSelectorTester, lsst.utils.tests.TestCase):
    """Tests for `lsst.meas.algorithms.ReferenceSourceSelectorTask`."""
    Task = lsst.meas.algorithms.ReferenceSourceSelectorTask

    def setUp(self):
        SourceSelectorTester.setUp(self)
        self.config.magLimit.fluxField = "flux"
        self.config.doMagLimit = True
        self.config.doFlags = True
        self.config.doUnresolved = False

    def testMagnitudeLimit(self):
        tooBright = self.catalog.addNew()
        tooBright.set("flux", 1.0e10)
        tooBright.set("flux_flag", False)
        good = self.catalog.addNew()
        good.set("flux", 1000.0)
        good.set("flux_flag", False)
        bad = self.catalog.addNew()
        bad.set("flux", good.get("flux"))
        bad.set("flux_flag", True)
        tooFaint = self.catalog.addNew()
        tooFaint.set("flux", 1.0)
        tooFaint.set("flux_flag", False)
        # Magnitudes run backwards, so the minimum flux sets the maximum magnitude.
        # The limits are built by converting nJy to AB magnitudes, so the catalog
        # fluxes are effectively interpreted as nJy here.
        self.config.magLimit.minimum = (1.0e6*u.nJy).to_value(u.ABmag)
        self.config.magLimit.maximum = (10.0*u.nJy).to_value(u.ABmag)
        self.config.magLimit.fluxField = "flux"
        self.check([False, True, False, False])

        # Without a minimum magnitude (bright limit), the too-bright source passes.
        self.config.magLimit.minimum = None
        self.check([True, True, False, False])

        # Without a maximum magnitude (faint limit) as well, the too-faint source passes.
        self.config.magLimit.maximum = None
        self.check([True, True, False, True])

    def testMagErrorLimit(self):
        # Using an arbitrary field as if it was a magnitude error to save adding a new field
        field = "other_fluxErr"
        tooFaint = self.catalog.addNew()
        tooFaint.set(field, 0.5)  # error above the 0.3 maximum
        tooBright = self.catalog.addNew()
        tooBright.set(field, 0.00001)  # error below the 0.01 minimum
        good = self.catalog.addNew()
        good.set(field, 0.2)

        self.config.doMagError = True
        self.config.magError.minimum = 0.01
        self.config.magError.maximum = 0.3
        self.config.magError.magErrField = field
        self.check([False, False, True])

    def testColorLimits(self):
        num = 10
        for _ in range(num):
            self.catalog.addNew()
        color = np.linspace(-0.5, 0.5, num, True)
        flux = 1000.0*u.nJy
        # Definition: color = mag(flux) - mag(otherFlux)
        otherFlux = (flux.to(u.ABmag) - color*u.mag).to_value(u.nJy)
        self.catalog["flux"] = flux.value
        self.catalog["other_flux"] = otherFlux
        minimum, maximum = -0.1, 0.2
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum, maximum=maximum)}
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Works with no minimum set?
        self.config.colorLimits["test"].minimum = None
        self.check((color < maximum).tolist())

        # Works with no maximum set?
        self.config.colorLimits["test"].maximum = None
        self.config.colorLimits["test"].minimum = minimum
        self.check((color > minimum).tolist())

        # Multiple limits, each applying one side of the range; all must pass.
        self.config.colorLimits = {"test": ColorLimit(primary="flux", secondary="other_flux",
                                                      minimum=minimum),
                                   "other": ColorLimit(primary="flux", secondary="other_flux",
                                                       maximum=maximum)}
        # Use a unittest assertion rather than a bare `assert`, which `python -O` strips.
        self.assertGreater(maximum, minimum, "Limits must overlap (not be mutually exclusive)")
        self.check(((color > minimum) & (color < maximum)).tolist())

        # Multiple mutually-exclusive limits: color < -0.1 and color > 0.1 select nothing.
        self.config.colorLimits["test"] = ColorLimit(primary="flux", secondary="other_flux", maximum=-0.1)
        self.config.colorLimits["other"] = ColorLimit(primary="flux", secondary="other_flux", minimum=0.1)
        self.check([False]*num)

    def testUnresolved(self):
        num = 5
        for _ in range(num):
            self.catalog.addNew()
        self.catalog["flux"] = 1.0
        starGalaxy = np.linspace(0.0, 1.0, num, False)
        self.catalog["starGalaxy"] = starGalaxy
        self.config.doUnresolved = True
        self.config.unresolved.name = "starGalaxy"
        minimum, maximum = 0.3, 0.7
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = maximum
        self.check(((starGalaxy > minimum) & (starGalaxy < maximum)).tolist())

        # Works with no minimum set?
        self.config.unresolved.minimum = None
        self.check((starGalaxy < maximum).tolist())

        # Works with no maximum set?
        self.config.unresolved.minimum = minimum
        self.config.unresolved.maximum = None
        self.check((starGalaxy > minimum).tolist())

307 

308 

class TrivialSourceSelector(lsst.meas.algorithms.BaseSourceSelectorTask):
    """Source selector that accepts every input source; intended only for tests."""

    def selectSources(self, sourceCat, matches=None, exposure=None):
        """Mark every row of ``sourceCat`` as selected.

        ``matches`` and ``exposure`` are accepted to match the base-class
        signature and are ignored.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``selected`` attribute is a boolean array of
            ``len(sourceCat)`` `True` values.
        """
        everything = np.full(len(sourceCat), True, dtype=bool)
        return lsst.pipe.base.Struct(selected=everything)

313 

314 

class TestBaseSourceSelector(lsst.utils.tests.TestCase):
    """Test the API of the Abstract Base Class with a trivial example."""
    def setUp(self):
        schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()
        self.selectedKeyName = "is_selected"
        schema.addField(self.selectedKeyName, type="Flag")
        self.catalog = lsst.afw.table.SourceCatalog(schema)
        # Four blank records; the loop index itself is not needed.
        for _ in range(4):
            self.catalog.addNew()

        self.sourceSelector = TrivialSourceSelector()

    def testRun(self):
        """Test that run() returns a catalog and boolean selected array."""
        result = self.sourceSelector.run(self.catalog)
        # The docstring promises a *boolean* mask, so check the dtype, not just truthiness.
        self.assertEqual(result.selected.dtype, bool)
        # The trivial selector keeps everything, so no rows may be lost.
        self.assertEqual(len(result.sourceCat), len(self.catalog))
        for i, sourceId in enumerate(self.catalog['id']):
            self.assertIn(sourceId, result.sourceCat['id'])
            self.assertTrue(result.selected[i])

    def testRunSourceSelectedField(self):
        """Test that the selected flag is set in the original catalog."""
        self.sourceSelector.run(self.catalog, sourceSelectedField=self.selectedKeyName)
        np.testing.assert_array_equal(self.catalog[self.selectedKeyName], True)

    def testRunNonContiguousRaises(self):
        """Cannot do source selection on non-contiguous catalogs."""
        del self.catalog[1]  # take one out of the middle to make it non-contiguous.
        self.assertFalse(self.catalog.isContiguous(), "Catalog is contiguous: the test won't work.")

        with self.assertRaises(RuntimeError):
            self.sourceSelector.run(self.catalog)

346 

347 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase` for this module.

    NOTE(review): presumably these are leak checks; see ``MemoryTestCase`` to confirm.
    """

350 

351 

def setup_module(module):
    """Perform one-time setup for this test module.

    Uses the xunit-style ``setup_module(module)`` hook name; the file's
    ``__main__`` guard also calls it when the file is run directly.
    ``module`` is accepted for that calling convention but not used; all work
    is delegated to ``lsst.utils.tests.init``.
    """
    lsst.utils.tests.init()

354 

355 

if __name__ == "__main__":
    # Run directly: do the same module setup a test runner would, then hand
    # control to unittest's command-line runner.
    import sys
    thisModule = sys.modules[__name__]
    setup_module(thisModule)
    unittest.main()