Coverage for tests/test_dipole.py: 13%

271 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-01 21:04 +0000

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25 

26import lsst.utils.tests 

27import lsst.daf.base as dafBase 

28import lsst.afw.image as afwImage 

29import lsst.afw.table as afwTable 

30import lsst.afw.math as afwMath 

31import lsst.geom as geom 

32import lsst.meas.algorithms as measAlg 

33import lsst.ip.diffim as ipDiffim 

34 

# The display package is imported unconditionally; it is only used inside
# ``if display:`` branches, so enabling debugging output only requires
# flipping the flag below.
import lsst.afw.display as afwDisplay

afwDisplay.setDefaultMaskTransparency(75)

# Set to True to show intermediate images with afwDisplay while debugging.
display = False

43 

# Factor converting a Gaussian sigma into a FWHM: sqrt(8 ln 2) == 2*sqrt(2 ln 2) ~= 2.3548.
sigma2fwhm = np.sqrt(8.0 * np.log(2.0))

45 

46 

def makePluginAndCat(alg, name, control, metadata=False, centroid=None):
    """Build a measurement plugin and an empty catalog sharing its schema.

    Parameters
    ----------
    alg : callable
        Plugin class; called as ``alg(control, name, schema)`` or, when
        ``metadata`` is true, as ``alg(control, name, schema, metadata)``.
    name : `str`
        Name handed to the plugin.
    control
        Control object forwarded to ``alg``.
    metadata : `bool`, optional
        If true, also pass a freshly created `lsst.daf.base.PropertySet`.
    centroid : `str`, optional
        If given, add ``<centroid>_x``, ``<centroid>_y`` and
        ``<centroid>_flag`` fields and alias ``slot_Centroid`` to them.

    Returns
    -------
    plugin
        The constructed plugin.
    cat : `lsst.afw.table.SourceCatalog`
        Empty catalog built on the schema the plugin was constructed with.
    """
    schema = afwTable.SourceTable.makeMinimalSchema()
    if centroid:
        for suffix, fieldType in (("_x", float), ("_y", float), ("_flag", 'Flag')):
            schema.addField(centroid + suffix, type=fieldType)
        schema.getAliasMap().set("slot_Centroid", centroid)

    ctorArgs = [control, name, schema]
    if metadata:
        ctorArgs.append(dafBase.PropertySet())
    plugin = alg(*ctorArgs)

    return plugin, afwTable.SourceCatalog(schema)

60 

61 

def createDipole(w, h, xc, yc, scaling=100.0, fracOffset=1.2):
    """Make a noisy exposure containing one synthetic dipole and detect it.

    A positive and a negative copy of a `DoubleGaussianPsf` image are added
    on opposite sides of ``(xc, yc)`` on top of unit-variance Gaussian noise.
    The image is then run through `SourceDetectionTask` with both
    polarities, and the two detections are merged into a single footprint.

    Parameters
    ----------
    w, h : `int`
        Width and height of the image in pixels.
    xc, yc : `int`
        Pixel position of the dipole center.
    scaling : `float`, optional
        Peak value of each PSF lobe.
    fracOffset : `float`, optional
        Each lobe is shifted by ``fracOffset*FWHM//2`` pixels (floor
        division) along both x and y from the center. The per-axis lobe
        separation is therefore roughly ``fracOffset`` PSF FWHMs.

    Returns
    -------
    psf : `lsst.meas.algorithms.DoubleGaussianPsf`
        PSF used to build the lobes; also attached to the exposure.
    psfSum : `float`
        Sum of the scaled PSF stamp, i.e. the flux of one lobe.
    exp : `lsst.afw.image.Exposure`
        Exposure holding the dipole image.
    s : `lsst.afw.table.SourceRecord`
        Merged source whose footprint holds exactly two peaks.

    Raises
    ------
    ValueError
        If either PSF stamp would not fit entirely inside the image.
    AssertionError
        If detection does not yield exactly two sources, or the merged
        footprint is not a single source with two peaks.
    """
    # Noise image: unit-variance Gaussian in the image plane. afw arrays are
    # indexed [y, x], so the array has shape (h, w).
    image = afwImage.MaskedImageF(w, h)
    image.set(0)
    array = image.getImage().getArray()
    array[:, :] = np.random.randn(h, w)
    image.getVariance().set(1.0)

    if display:
        afwDisplay.Display(frame=1).mtv(image, title="Original image")
        afwDisplay.Display(frame=2).mtv(image.getVariance(), title="Original variance")

    # Psf used both to create and to measure the dipole.
    psfSize = 17
    psf = measAlg.DoubleGaussianPsf(psfSize, psfSize, 2.0, 3.5, 0.1)
    psfFwhmPix = sigma2fwhm*psf.computeShape().getDeterminantRadius()
    psfim = psf.computeImage().convertF()
    psfim *= scaling/psf.computePeak()
    psfw, psfh = psfim.getDimensions()
    psfSum = np.sum(psfim.getArray())

    # Lower-left corners of the positive and negative lobe stamps.
    offset = fracOffset*psfFwhmPix//2
    xp = int(xc - psfw//2 + offset)
    yp = int(yc - psfh//2 + offset)
    xn = int(xc - psfw//2 - offset)
    yn = int(yc - psfh//2 - offset)

    # Both stamps must lie fully inside the image. A negative start index
    # would otherwise be read relative to the far edge of the array, which
    # surfaces only as an obscure broadcasting error.
    for x0, y0 in ((xp, yp), (xn, yn)):
        if x0 < 0 or y0 < 0 or x0 + psfw > w or y0 + psfh > h:
            raise ValueError("Dipole lobe stamp at (%d, %d) of size %dx%d does not fit in a %dx%d image"
                             % (x0, y0, psfw, psfh, w, h))

    array[yp:yp + psfh, xp:xp + psfw] += psfim.getArray()
    array[yn:yn + psfh, xn:xn + psfw] -= psfim.getArray()

    if display:
        afwDisplay.Display(frame=3).mtv(image, title="With dipole")

    # Create an exposure, detect positive and negative peaks separately
    exp = afwImage.makeExposure(image)
    exp.setPsf(psf)
    config = measAlg.SourceDetectionConfig()
    config.thresholdPolarity = "both"
    config.reEstimateBackground = False
    schema = afwTable.SourceTable.makeMinimalSchema()
    task = measAlg.SourceDetectionTask(schema, config=config)
    table = afwTable.SourceTable.make(schema)
    results = task.run(table, exp)
    if display:
        afwDisplay.Display(frame=4).mtv(image, title="Detection plane")

    # Merge the positive and negative footprints into one dipole source.
    assert len(results.sources) == 2, "expected 2 detections, got %d" % len(results.sources)
    fpSet = results.fpSets.positive
    fpSet.merge(results.fpSets.negative, 0, 0, False)
    sources = afwTable.SourceCatalog(table)
    fpSet.makeSources(sources)
    assert len(sources) == 1, "expected 1 merged source, got %d" % len(sources)
    s = sources[0]
    nPeaks = len(s.getFootprint().getPeaks())
    assert nPeaks == 2, "expected 2 peaks in merged footprint, got %d" % nPeaks

    return psf, psfSum, exp, s

123 

124 

class DipoleAlgorithmTest(lsst.utils.tests.TestCase):
    """Tests for the individual dipole measurement algorithms."""

    def setUp(self):
        np.random.seed(666)
        self.w, self.h = 100, 100  # size of image
        self.xc, self.yc = 50, 50  # location of center of dipole

    def _checkPluginFields(self, alg, control, suffixes):
        """Measure a fresh dipole with ``alg`` and check its output fields exist.

        The plugin is built with `makePluginAndCat` under the name ``"test"``,
        with a ``centroid`` slot set to the dipole center.

        Parameters
        ----------
        alg : callable
            Plugin class to construct.
        control
            Control object passed to ``alg``.
        suffixes : iterable of `str`
            Suffixes of the fields (``"test" + suffix``) that must be
            readable from the measured record.
        """
        _, _, exposure, dipole = createDipole(self.w, self.h, self.xc, self.yc)
        plugin, cat = makePluginAndCat(alg, "test", control, centroid="centroid")
        source = cat.addNew()
        source.set("centroid_x", self.xc)
        source.set("centroid_y", self.yc)
        source.setFootprint(dipole.getFootprint())
        plugin.measure(source, exposure)
        for key in suffixes:
            field = "test" + key
            try:
                source.get(field)
            except Exception as err:
                self.fail("Could not read field %r after measurement: %s" % (field, err))

    def testNaiveDipoleCentroid(self):
        self._checkPluginFields(
            ipDiffim.NaiveDipoleCentroid, ipDiffim.DipoleCentroidControl(),
            ("_pos_x", "_pos_y", "_pos_xErr", "_pos_yErr", "_pos_flag",
             "_neg_x", "_neg_y", "_neg_xErr", "_neg_yErr", "_neg_flag"))

    def testNaiveDipoleFluxControl(self):
        self._checkPluginFields(
            ipDiffim.NaiveDipoleFlux, ipDiffim.DipoleFluxControl(),
            ("_pos_instFlux", "_pos_instFluxErr", "_pos_flag", "_npos",
             "_neg_instFlux", "_neg_instFluxErr", "_neg_flag", "_nneg"))

    def testPsfDipoleFluxControl(self):
        self._checkPluginFields(
            ipDiffim.PsfDipoleFlux, ipDiffim.PsfDipoleFluxControl(),
            ("_pos_instFlux", "_pos_instFluxErr", "_pos_flag",
             "_neg_instFlux", "_neg_instFluxErr", "_neg_flag"))

    def testAll(self):
        _, _, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        self.measureDipole(s, exposure)

    def _makeModel(self, exposure, psf, fp, negCenter, posCenter):
        """Fit the amplitudes of a two-PSF dipole model by linear least squares.

        Peak-normalized PSF images at ``negCenter`` and ``posCenter`` are
        placed in the bounding box of footprint ``fp``. Their amplitudes are
        fit to the exposure's image with both `numpy.linalg.lstsq` and
        `lsst.afw.math.LeastSquares`, and the two solutions are asserted to
        agree.

        Returns
        -------
        fneg : `float`
            Fitted amplitude of the negative PSF component.
        negPsfSum : `float`
            Sum of the peak-normalized negative PSF image.
        fpos : `float`
            Fitted amplitude of the positive PSF component.
        posPsfSum : `float`
            Sum of the peak-normalized positive PSF image.
        fitted : `lsst.afw.image.ImageF`
            Per-pixel chi^2, ``(model - data)**2 / variance``, over ``fp``'s bbox.
        """
        # Peak-normalized PSF images at each lobe position.
        negPsf = psf.computeImage(negCenter).convertF()
        posPsf = psf.computeImage(posCenter).convertF()
        negPeak = psf.computePeak(negCenter)
        posPeak = psf.computePeak(posCenter)
        negPsf /= negPeak
        posPsf /= posPeak

        model = afwImage.ImageF(fp.getBBox())
        negModel = afwImage.ImageF(fp.getBBox())
        posModel = afwImage.ImageF(fp.getBBox())

        # The center of the Psf should be at negCenter, posCenter
        negPsfBBox = negPsf.getBBox()
        posPsfBBox = posPsf.getBBox()
        modelBBox = model.getBBox()

        # Portion of the negative Psf that overlaps the montage
        negOverlapBBox = geom.Box2I(negPsfBBox)
        negOverlapBBox.clip(modelBBox)
        self.assertFalse(negOverlapBBox.isEmpty())

        # Portion of the positive Psf that overlaps the montage
        posOverlapBBox = geom.Box2I(posPsfBBox)
        posOverlapBBox.clip(modelBBox)
        self.assertFalse(posOverlapBBox.isEmpty())

        negPsfSubim = type(negPsf)(negPsf, negOverlapBBox)
        modelSubim = type(model)(model, negOverlapBBox)
        negModelSubim = type(negModel)(negModel, negOverlapBBox)
        modelSubim += negPsfSubim  # just for debugging
        negModelSubim += negPsfSubim  # for fitting

        posPsfSubim = type(posPsf)(posPsf, posOverlapBBox)
        modelSubim = type(model)(model, posOverlapBBox)
        posModelSubim = type(posModel)(posModel, posOverlapBBox)
        modelSubim += posPsfSubim
        posModelSubim += posPsfSubim

        data = afwImage.ImageF(exposure.getMaskedImage().getImage(), fp.getBBox())
        var = afwImage.ImageF(exposure.getMaskedImage().getVariance(), fp.getBBox())
        matrixNorm = 1./np.sqrt(np.median(var.getArray()))

        if display:
            afwDisplay.Display(frame=5).mtv(model, title="Unfitted model")
            afwDisplay.Display(frame=6).mtv(data, title="Data")

        posPsfSum = np.sum(posPsf.getArray())
        negPsfSum = np.sum(negPsf.getArray())

        # Design matrix with one column per lobe model. Both sides are scaled
        # by the same scalar, which leaves the least-squares solution unchanged.
        M = np.array((np.ravel(negModel.getArray()), np.ravel(posModel.getArray()))).T.astype(np.float64)
        B = np.array((np.ravel(data.getArray()))).astype(np.float64)
        M *= matrixNorm
        B *= matrixNorm

        # Numpy solution
        fneg0, fpos0 = np.linalg.lstsq(M, B, rcond=-1)[0]

        # Afw solution
        lsq = afwMath.LeastSquares.fromDesignMatrix(M, B, afwMath.LeastSquares.DIRECT_SVD)
        fneg, fpos = lsq.getSolution()

        # The two solvers should agree to within numerical precision.
        self.assertAlmostEqual(1e-2*fneg0, 1e-2*fneg)
        self.assertAlmostEqual(1e-2*fpos0, 1e-2*fpos)

        # Recreate model
        fitted = afwImage.ImageF(fp.getBBox())
        negFit = type(negPsf)(negPsf, negOverlapBBox, afwImage.PARENT, True)
        negFit *= float(fneg)
        posFit = type(posPsf)(posPsf, posOverlapBBox, afwImage.PARENT, True)
        posFit *= float(fpos)

        fitSubim = type(fitted)(fitted, negOverlapBBox)
        fitSubim += negFit
        fitSubim = type(fitted)(fitted, posOverlapBBox)
        fitSubim += posFit
        if display:
            afwDisplay.Display(frame=7).mtv(fitted, title="Fitted model")

        fitted -= data

        if display:
            afwDisplay.Display(frame=8).mtv(fitted, title="Residuals")

        fitted *= fitted
        fitted /= var

        if display:
            afwDisplay.Display(frame=9).mtv(fitted, title="Chi2")

        return fneg, negPsfSum, fpos, posPsfSum, fitted

    def testPsfDipoleFit(self, scaling=100.):
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc, scaling=scaling)
        source = self.measureDipole(s, exposure)

        # Recreate the simultaneous joint Psf fit in python. The lowest-valued
        # peak is the negative lobe and the highest-valued peak is the positive
        # lobe. Sorting on the value alone never falls back to comparing the
        # peak records themselves when two values tie.
        fp = source.getFootprint()
        peaks = sorted(fp.getPeaks(), key=lambda peak: peak.getPeakValue())
        negCenter = geom.Point2D(peaks[0].getFx(), peaks[0].getFy())
        posCenter = geom.Point2D(peaks[-1].getFx(), peaks[-1].getFy())

        fneg, negPsfSum, fpos, posPsfSum, residIm = self._makeModel(exposure, psf, fp, negCenter, posCenter)

        # Should be close to the same as the inputs; as fracOffset
        # gets smaller this will be worse. This works for scaling =
        # 100.
        self.assertAlmostEqual(1e-2*scaling, -1e-2*fneg, 2)
        self.assertAlmostEqual(1e-2*scaling, 1e-2*fpos, 2)

        # Now compare the LeastSquares results fitted here to the C++
        # implementation. Since total flux is returned, and this is of
        # order 1e4 for this default test, scale back down so that
        # assertAlmostEqual behaves reasonably (the comparison to 2
        # places means to 0.01). Also note that PsfDipoleFlux returns
        # the total flux, while here we are just fitting for the
        # scaling of the Psf. Therefore the comparison is
        # fneg*negPsfSum to ip_diffim_PsfDipoleFlux_neg_instFlux.
        self.assertAlmostEqual(1e-4*fneg*negPsfSum,
                               1e-4*source.get("ip_diffim_PsfDipoleFlux_neg_instFlux"),
                               2)
        self.assertAlmostEqual(1e-4*fpos*posPsfSum,
                               1e-4*source.get("ip_diffim_PsfDipoleFlux_pos_instFlux"),
                               2)

        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_pos_instFluxErr"), 0.0)
        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_neg_instFluxErr"), 0.0)
        self.assertFalse(source.get("ip_diffim_PsfDipoleFlux_neg_flag"))
        self.assertFalse(source.get("ip_diffim_PsfDipoleFlux_pos_flag"))

        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_centroid_x"), float(self.xc), 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_centroid_y"), float(self.yc), 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_neg_centroid_x"), negCenter[0], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_neg_centroid_y"), negCenter[1], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_pos_centroid_x"), posCenter[0], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_pos_centroid_y"), posCenter[1], 1)

        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_chi2dof"), 0.0)

    def measureDipole(self, s, exp):
        """Run `DipoleMeasurementTask` on the footprint of ``s``.

        Parameters
        ----------
        s : `lsst.afw.table.SourceRecord`
            Record whose footprint (holding both dipole lobes) is measured.
        exp : `lsst.afw.image.Exposure`
            Exposure to measure on.

        Returns
        -------
        source : `lsst.afw.table.SourceRecord`
            The measured record, with its centroid slot set to
            ``(self.xc, self.yc)``.
        """
        msConfig = ipDiffim.DipoleMeasurementConfig()
        schema = afwTable.SourceTable.makeMinimalSchema()
        schema.addField("centroid_x", type=float)
        schema.addField("centroid_y", type=float)
        schema.addField("centroid_flag", type='Flag')
        task = ipDiffim.DipoleMeasurementTask(schema, config=msConfig)
        measCat = afwTable.SourceCatalog(schema)
        measCat.defineCentroid("centroid")
        source = measCat.addNew()
        source.set("centroid_x", self.xc)
        source.set("centroid_y", self.yc)
        source.setFootprint(s.getFootprint())
        # Then run the default SFM task. Results not checked
        task.run(measCat, exp)
        return measCat[0]

    def testDipoleAnalysis(self):
        # Smoke test: only checks that the analysis runs without raising.
        _, _, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        source = self.measureDipole(s, exposure)
        dpAnalysis = ipDiffim.DipoleAnalysis()
        dpAnalysis(source)

    def testDipoleDeblender(self):
        # Smoke test: only checks that the deblender runs without raising.
        _, _, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        source = self.measureDipole(s, exposure)
        dpDeblender = ipDiffim.DipoleDeblender()
        dpDeblender(source, exposure)

362 

363 

class DipoleMeasurementTaskTest(lsst.utils.tests.TestCase):
    """Tests for `DipoleMeasurementTask`.

    The individual dipole algorithms are exercised in `DipoleAlgorithmTest`,
    so this only checks the dipole classification output of the task.
    """

    def setUp(self):
        np.random.seed(666)
        self.config = ipDiffim.DipoleMeasurementConfig()

    def tearDown(self):
        del self.config

    def testMeasure(self):
        schema = afwTable.SourceTable.makeMinimalSchema()
        task = ipDiffim.DipoleMeasurementTask(schema, config=self.config)
        sources = afwTable.SourceCatalog(afwTable.SourceTable.make(schema))
        source = sources.addNew()

        # Synthetic 100x100 image with a dipole centered at (50, 50).
        _, _, exposure, detected = createDipole(100, 100, 50, 50)

        # Measure the record on the detected two-lobe footprint.
        source.setFootprint(detected.getFootprint())
        task.run(sources, exposure)
        self.assertEqual(source.get("ip_diffim_ClassificationDipole_value"), 1.0)

389 

390 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the checks provided by `lsst.utils.tests.MemoryTestCase` for this module."""

393 

394 

def setup_module(module):
    """Pytest module-level hook: initialize the LSST test utilities before any test runs."""
    lsst.utils.tests.init()

397 

398 

if __name__ == "__main__":
    # Allow running this test file directly as a script.
    lsst.utils.tests.init()
    unittest.main()