Coverage for tests/test_dipole.py: 13%

272 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-26 03:40 -0700

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25 

26import lsst.utils.tests 

27import lsst.daf.base as dafBase 

28import lsst.afw.image as afwImage 

29import lsst.afw.table as afwTable 

30import lsst.afw.math as afwMath 

31import lsst.geom as geom 

32import lsst.meas.algorithms as measAlg 

33import lsst.ip.diffim as ipDiffim 

34 

# Set ``display = True`` to view intermediate images with lsst.afw.display.
# ``display`` is always defined here, so a ``try: display / except NameError``
# probe could never take its ``except`` branch; the display module is simply
# imported unconditionally (it is only used under ``if display:`` checks).
display = False
import lsst.afw.display as afwDisplay  # noqa: E402
afwDisplay.setDefaultMaskTransparency(75)

43 

# Conversion factor from Gaussian sigma to FWHM: 2*sqrt(2*ln(2)).
sigma2fwhm = 2.0*np.sqrt(2.0*np.log(2.0))

45 

46 

def makePluginAndCat(alg, name, control, metadata=False, centroid=None):
    """Construct a measurement plugin together with a catalog for it.

    Parameters
    ----------
    alg : callable
        Plugin class to instantiate.
    name : `str`
        Name passed to the plugin constructor.
    control : object
        Control object passed as the first constructor argument.
    metadata : `bool`, optional
        If `True`, also pass an empty `lsst.daf.base.PropertySet` to the
        plugin constructor.
    centroid : `str`, optional
        If given, add ``<centroid>_x``, ``<centroid>_y`` and
        ``<centroid>_flag`` fields to the schema and alias ``slot_Centroid``
        to this prefix.

    Returns
    -------
    plugin
        The constructed plugin.
    cat : `lsst.afw.table.SourceCatalog`
        Empty catalog built from the schema after the plugin was constructed.
    """
    schema = afwTable.SourceTable.makeMinimalSchema()
    if centroid:
        for suffix, fieldType in (("_x", float), ("_y", float), ("_flag", 'Flag')):
            schema.addField(centroid + suffix, type=fieldType)
        schema.getAliasMap().set("slot_Centroid", centroid)
    ctorArgs = (control, name, schema, dafBase.PropertySet()) if metadata else (control, name, schema)
    plugin = alg(*ctorArgs)
    return plugin, afwTable.SourceCatalog(schema)

60 

61 

def createDipole(w, h, xc, yc, scaling=100.0, fracOffset=1.2):
    """Create a noisy exposure containing a synthetic dipole and detect it.

    A positive and a negative copy of a double-Gaussian Psf image are added
    to a unit-variance Gaussian noise image. The two copies are shifted
    diagonally from ``(xc, yc)`` in opposite directions. Positive and
    negative detection is then run, and the resulting footprints are merged
    into a single source.

    Parameters
    ----------
    w, h : `int`
        Width and height of the image in pixels.
    xc, yc : `int`
        Pixel position of the dipole center.
    scaling : `float`, optional
        Value the Psf image peak (as reported by ``computePeak``) is scaled to.
    fracOffset : `float`, optional
        Lobe offset in units of the Psf FWHM. The applied shift along each
        axis is ``(fracOffset * fwhm) // 2`` pixels.

    Returns
    -------
    psf : `lsst.meas.algorithms.DoubleGaussianPsf`
        Psf used to create the dipole; also attached to ``exp``.
    psfSum : `float`
        Sum of the scaled Psf image, i.e. the flux injected in each lobe.
    exp : `lsst.afw.image.Exposure`
        Exposure containing the noise and the dipole.
    s : `lsst.afw.table.SourceRecord`
        Merged source whose footprint contains both dipole peaks.
    """
    # Make random noise image. Image arrays are indexed [y, x], so the noise
    # array must have shape (h, w) to match non-square images.
    image = afwImage.MaskedImageF(w, h)
    image.set(0)
    array = image.image.array
    array[:, :] = np.random.randn(h, w)
    # Set variance to 1.0
    var = image.getVariance()
    var.set(1.0)

    if display:
        afwDisplay.Display(frame=1).mtv(image, title="Original image")
        afwDisplay.Display(frame=2).mtv(image.getVariance(), title="Original variance")

    # Create Psf for dipole creation and measurement
    psfSize = 17
    psf = measAlg.DoubleGaussianPsf(psfSize, psfSize, 2.0, 3.5, 0.1)
    pos = psf.getAveragePosition()
    psfFwhmPix = sigma2fwhm*psf.computeShape(pos).getDeterminantRadius()
    psfim = psf.computeImage(pos).convertF()
    psfim *= scaling/psf.computePeak(pos)
    psfw, psfh = psfim.getDimensions()
    psfSum = np.sum(psfim.array)

    # Create the dipole, offset by fracOffset of the Psf FWHM (pixels).
    # Note the precedence: this is (fracOffset*psfFwhmPix)//2.
    offset = fracOffset*psfFwhmPix//2
    xp = int(xc - psfw//2 + offset)
    yp = int(yc - psfh//2 + offset)
    array[yp:yp + psfh, xp:xp + psfw] += psfim.array

    xn = int(xc - psfw//2 - offset)
    yn = int(yc - psfh//2 - offset)
    array[yn:yn + psfh, xn:xn + psfw] -= psfim.array

    if display:
        afwDisplay.Display(frame=3).mtv(image, title="With dipole")

    # Create an exposure, detect positive and negative peaks separately
    exp = afwImage.makeExposure(image)
    exp.setPsf(psf)
    config = measAlg.SourceDetectionConfig()
    config.thresholdPolarity = "both"
    config.reEstimateBackground = False
    schema = afwTable.SourceTable.makeMinimalSchema()
    task = measAlg.SourceDetectionTask(schema, config=config)
    table = afwTable.SourceTable.make(schema)
    results = task.run(table, exp)
    if display:
        afwDisplay.Display(frame=4).mtv(image, title="Detection plane")

    # Merge them together: one positive and one negative detection are
    # expected, combining into a single source with two peaks.
    assert len(results.sources) == 2
    fpSet = results.positive
    fpSet.merge(results.negative, 0, 0, False)
    sources = afwTable.SourceCatalog(table)
    fpSet.makeSources(sources)
    assert len(sources) == 1
    s = sources[0]
    assert len(s.getFootprint().getPeaks()) == 2

    return psf, psfSum, exp, s

124 

125 

class DipoleAlgorithmTest(lsst.utils.tests.TestCase):
    """A test case for dipole algorithms."""

    def setUp(self):
        np.random.seed(666)
        self.w, self.h = 100, 100  # size of image
        self.xc, self.yc = 50, 50  # location of center of dipole

    def _runPlugin(self, alg, control):
        """Measure a freshly created dipole with a single plugin named "test".

        Parameters
        ----------
        alg : callable
            Plugin class, instantiated through `makePluginAndCat`.
        control : object
            Control object for the plugin.

        Returns
        -------
        source : `lsst.afw.table.SourceRecord`
            The measured record. Its centroid is set to the dipole center and
            the detected dipole footprint is attached.
        """
        _, _, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        plugin, cat = makePluginAndCat(alg, "test", control, centroid="centroid")
        source = cat.addNew()
        source.set("centroid_x", self.xc)
        source.set("centroid_y", self.yc)
        source.setFootprint(s.getFootprint())
        plugin.measure(source, exposure)
        return source

    def _assertFieldsReadable(self, source, prefix, suffixes):
        """Fail the test unless every ``prefix + suffix`` field can be read.

        Any exception from ``source.get`` is turned into a test failure that
        names the offending field and the error.

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            Record to read from.
        prefix : `str`
            Common field-name prefix (the plugin name).
        suffixes : iterable of `str`
            Field-name suffixes to check.
        """
        for suffix in suffixes:
            field = prefix + suffix
            try:
                source.get(field)
            except Exception as e:
                self.fail(f"Could not read field {field!r}: {e}")

    # Remove this test on DM-44030
    def testNaiveDipoleCentroid(self):
        source = self._runPlugin(ipDiffim.NaiveDipoleCentroid, ipDiffim.DipoleCentroidControl())
        self._assertFieldsReadable(source, "test",
                                   ("_pos_x", "_pos_y", "_pos_xErr", "_pos_yErr", "_pos_flag",
                                    "_neg_x", "_neg_y", "_neg_xErr", "_neg_yErr", "_neg_flag"))

    # Remove this test on DM-44030
    def testNaiveDipoleFluxControl(self):
        source = self._runPlugin(ipDiffim.NaiveDipoleFlux, ipDiffim.DipoleFluxControl())
        self._assertFieldsReadable(source, "test",
                                   ("_pos_instFlux", "_pos_instFluxErr", "_pos_flag", "_npos",
                                    "_neg_instFlux", "_neg_instFluxErr", "_neg_flag", "_nneg"))

    def testPsfDipoleFluxControl(self):
        source = self._runPlugin(ipDiffim.PsfDipoleFlux, ipDiffim.PsfDipoleFluxControl())
        self._assertFieldsReadable(source, "test",
                                   ("_pos_instFlux", "_pos_instFluxErr", "_pos_flag",
                                    "_neg_instFlux", "_neg_instFluxErr", "_neg_flag"))

    def testAll(self):
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        self.measureDipole(s, exposure)

    def _makeModel(self, exposure, psf, fp, negCenter, posCenter):
        """Fit a two-component Psf model to the dipole in python.

        Peak-normalized Psf images centered at ``negCenter`` and ``posCenter``
        form the two columns of a linear design matrix over the footprint
        bounding box. Their amplitudes are solved for with both
        `numpy.linalg.lstsq` and `lsst.afw.math.LeastSquares`, and the two
        solutions are asserted to agree.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure containing the dipole.
        psf : Psf
            Psf used to build the model (``computeImage``/``computePeak``).
        fp : `lsst.afw.detection.Footprint`
            Dipole footprint; its bounding box defines the fit region.
        negCenter, posCenter : `lsst.geom.Point2D`
            Centers of the negative and positive lobes.

        Returns
        -------
        fneg : `float`
            Fitted amplitude of the negative lobe, in units of the Psf peak.
        negPsfSum : `float`
            Sum of the peak-normalized negative Psf image.
        fpos : `float`
            Fitted amplitude of the positive lobe, in units of the Psf peak.
        posPsfSum : `float`
            Sum of the peak-normalized positive Psf image.
        fitted : `lsst.afw.image.ImageF`
            Per-pixel ``(model - data)**2 / variance`` over the bounding box.
        """
        negPsf = psf.computeImage(negCenter).convertF()
        posPsf = psf.computeImage(posCenter).convertF()
        negPeak = psf.computePeak(negCenter)
        posPeak = psf.computePeak(posCenter)
        negPsf /= negPeak
        posPsf /= posPeak

        model = afwImage.ImageF(fp.getBBox())
        negModel = afwImage.ImageF(fp.getBBox())
        posModel = afwImage.ImageF(fp.getBBox())

        # The center of the Psf should be at negCenter, posCenter
        negPsfBBox = negPsf.getBBox()
        posPsfBBox = posPsf.getBBox()
        modelBBox = model.getBBox()

        # Portion of the negative Psf that overlaps the montage
        negOverlapBBox = geom.Box2I(negPsfBBox)
        negOverlapBBox.clip(modelBBox)
        self.assertFalse(negOverlapBBox.isEmpty())

        # Portion of the positive Psf that overlaps the montage
        posOverlapBBox = geom.Box2I(posPsfBBox)
        posOverlapBBox.clip(modelBBox)
        self.assertFalse(posOverlapBBox.isEmpty())

        negPsfSubim = type(negPsf)(negPsf, negOverlapBBox)
        modelSubim = type(model)(model, negOverlapBBox)
        negModelSubim = type(negModel)(negModel, negOverlapBBox)
        modelSubim += negPsfSubim  # just for debugging
        negModelSubim += negPsfSubim  # for fitting

        posPsfSubim = type(posPsf)(posPsf, posOverlapBBox)
        modelSubim = type(model)(model, posOverlapBBox)
        posModelSubim = type(posModel)(posModel, posOverlapBBox)
        modelSubim += posPsfSubim
        posModelSubim += posPsfSubim

        data = afwImage.ImageF(exposure.image, fp.getBBox())
        var = afwImage.ImageF(exposure.variance, fp.getBBox())
        # Uniform weighting by the median noise level over the bounding box.
        matrixNorm = 1./np.sqrt(np.median(var.array))

        if display:
            afwDisplay.Display(frame=5).mtv(model, title="Unfitted model")
            afwDisplay.Display(frame=6).mtv(data, title="Data")

        posPsfSum = np.sum(posPsf.array)
        negPsfSum = np.sum(negPsf.array)

        M = np.array((np.ravel(negModel.array), np.ravel(posModel.array))).T.astype(np.float64)
        B = np.array((np.ravel(data.array))).astype(np.float64)
        M *= matrixNorm
        B *= matrixNorm

        # Numpy solution
        fneg0, fpos0 = np.linalg.lstsq(M, B, rcond=-1)[0]

        # Afw solution
        lsq = afwMath.LeastSquares.fromDesignMatrix(M, B, afwMath.LeastSquares.DIRECT_SVD)
        fneg, fpos = lsq.getSolution()

        # Should be exactly the same as each other
        self.assertAlmostEqual(1e-2*fneg0, 1e-2*fneg)
        self.assertAlmostEqual(1e-2*fpos0, 1e-2*fpos)

        # Recreate model
        fitted = afwImage.ImageF(fp.getBBox())
        negFit = type(negPsf)(negPsf, negOverlapBBox, afwImage.PARENT, True)
        negFit *= float(fneg)
        posFit = type(posPsf)(posPsf, posOverlapBBox, afwImage.PARENT, True)
        posFit *= float(fpos)

        fitSubim = type(fitted)(fitted, negOverlapBBox)
        fitSubim += negFit
        fitSubim = type(fitted)(fitted, posOverlapBBox)
        fitSubim += posFit
        if display:
            afwDisplay.Display(frame=7).mtv(fitted, title="Fitted model")

        fitted -= data

        if display:
            afwDisplay.Display(frame=8).mtv(fitted, title="Residuals")

        fitted *= fitted
        fitted /= var

        if display:
            afwDisplay.Display(frame=9).mtv(fitted, title="Chi2")

        return fneg, negPsfSum, fpos, posPsfSum, fitted

    def testPsfDipoleFit(self, scaling=100.):
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc, scaling=scaling)
        source = self.measureDipole(s, exposure)
        # Recreate the simultaneous joint Psf fit in python
        fp = source.getFootprint()
        # Order the peaks by value: the lowest is the negative lobe and the
        # highest the positive one. Sorting by key avoids comparing peak
        # records to each other when two peak values are equal.
        sortedPeaks = sorted(fp.getPeaks(), key=lambda peak: peak.getPeakValue())
        negPeak, posPeak = sortedPeaks[0], sortedPeaks[-1]

        negCenter = geom.Point2D(negPeak.getFx(), negPeak.getFy())
        posCenter = geom.Point2D(posPeak.getFx(), posPeak.getFy())

        fneg, negPsfSum, fpos, posPsfSum, residIm = self._makeModel(exposure, psf, fp, negCenter, posCenter)

        # Should be close to the same as the inputs; as fracOffset
        # gets smaller this will be worse. This works for scaling =
        # 100.
        self.assertAlmostEqual(1e-2*scaling, -1e-2*fneg, 2)
        self.assertAlmostEqual(1e-2*scaling, 1e-2*fpos, 2)

        # Now compare the LeastSquares results fitted here to the C++
        # implementation: Since total flux is returned, and this is of
        # order 1e4 for this default test, scale back down so that
        # assertAlmostEqual behaves reasonably (the comparison to 2
        # places means to 0.01). Also note that PsfDipoleFlux returns
        # the total flux, while here we are just fitting for the
        # scaling of the Psf. Therefore the comparison is
        # fneg*negPsfSum to flux.dipole.psf.neg.
        self.assertAlmostEqual(1e-4*fneg*negPsfSum,
                               1e-4*source.get("ip_diffim_PsfDipoleFlux_neg_instFlux"),
                               2)
        self.assertAlmostEqual(1e-4*fpos*posPsfSum,
                               1e-4*source.get("ip_diffim_PsfDipoleFlux_pos_instFlux"),
                               2)

        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_pos_instFluxErr"), 0.0)
        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_neg_instFluxErr"), 0.0)
        self.assertFalse(source.get("ip_diffim_PsfDipoleFlux_neg_flag"))
        self.assertFalse(source.get("ip_diffim_PsfDipoleFlux_pos_flag"))

        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_x"), self.xc, 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_y"), self.yc, 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_neg_centroid_x"), negCenter[0], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_neg_centroid_y"), negCenter[1], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_pos_centroid_x"), posCenter[0], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_pos_centroid_y"), posCenter[1], 1)

        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_chi2dof"), 0.0)

    def measureDipole(self, s, exp):
        """Run `DipoleMeasurementTask` on a single source.

        Parameters
        ----------
        s : `lsst.afw.table.SourceRecord`
            Source whose footprint is copied onto the measured record.
        exp : `lsst.afw.image.Exposure`
            Exposure to measure.

        Returns
        -------
        source : `lsst.afw.table.SourceRecord`
            The measured record; its centroid slot is set to the dipole
            center.
        """
        msConfig = ipDiffim.DipoleMeasurementConfig()
        schema = afwTable.SourceTable.makeMinimalSchema()
        schema.addField("centroid_x", type=float)
        schema.addField("centroid_y", type=float)
        schema.addField("centroid_flag", type='Flag')
        task = ipDiffim.DipoleMeasurementTask(schema, config=msConfig)
        measCat = afwTable.SourceCatalog(schema)
        measCat.defineCentroid("centroid")
        source = measCat.addNew()
        source.set("centroid_x", self.xc)
        source.set("centroid_y", self.yc)
        source.setFootprint(s.getFootprint())
        # Then run the default SFM task. Results not checked
        task.run(measCat, exp)
        return measCat[0]

    def testDipoleAnalysis(self):
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        source = self.measureDipole(s, exposure)
        dpAnalysis = ipDiffim.DipoleAnalysis()
        dpAnalysis(source)

    def testDipoleDeblender(self):
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        source = self.measureDipole(s, exposure)
        dpDeblender = ipDiffim.DipoleDeblender()
        dpDeblender(source, exposure)

365 

366 

class DipoleMeasurementTaskTest(lsst.utils.tests.TestCase):
    """Tests for the DipoleMeasurementTask as a whole.

    Only the dipole classification flag is checked here; the individual
    algorithms are exercised in `DipoleAlgorithmTest`.
    """

    def setUp(self):
        np.random.seed(666)
        self.config = ipDiffim.DipoleMeasurementConfig()

    def tearDown(self):
        del self.config

    def testMeasure(self):
        schema = afwTable.SourceTable.makeMinimalSchema()
        measureTask = ipDiffim.DipoleMeasurementTask(schema, config=self.config)
        catalog = afwTable.SourceCatalog(afwTable.SourceTable.make(schema))
        record = catalog.addNew()

        # Synthetic 100x100 exposure with a dipole centered at (50, 50).
        _, _, exposure, detected = createDipole(100, 100, 50, 50)

        # Measure the record using the footprint of the detected dipole.
        record.setFootprint(detected.getFootprint())
        measureTask.run(catalog, exposure)
        self.assertEqual(record.get("ip_diffim_ClassificationDipole_value"), 1.0)

392 

393 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Memory test case; all checks are inherited from `MemoryTestCase`."""

396 

397 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities."""
    lsst.utils.tests.init()

400 

401 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then let
    # unittest discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()