Coverage for tests / test_dipole.py: 14%

243 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 10:37 +0000

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25 

26import lsst.utils.tests 

27import lsst.daf.base as dafBase 

28import lsst.afw.image as afwImage 

29import lsst.afw.table as afwTable 

30import lsst.afw.math as afwMath 

31import lsst.geom as geom 

32import lsst.meas.algorithms as measAlg 

33import lsst.ip.diffim as ipDiffim 

34 

# afwDisplay is imported unconditionally (as before) so that the
# ``if display:`` branches in this module can use it as soon as ``display``
# is flipped to True.  A former ``try: display / except NameError`` guard
# was dead code, since ``display`` is always assigned just below.
import lsst.afw.display as afwDisplay

afwDisplay.setDefaultMaskTransparency(75)

# Set to True to show intermediate images while debugging.
display = False

# Conversion factor from a Gaussian sigma to its full width at half maximum:
# FWHM = 2*sqrt(2*ln 2)*sigma.
sigma2fwhm = 2.*np.sqrt(2.*np.log(2.))

45 

46 

def makePluginAndCat(alg, name, control, metadata=False, centroid=None):
    """Construct a measurement plugin together with an empty catalog.

    Parameters
    ----------
    alg : callable
        Plugin class; called as ``alg(control, name, schema)`` or, when
        ``metadata`` is true, ``alg(control, name, schema, metadata)``.
    name : `str`
        Name passed to the plugin.
    control
        Control object passed to the plugin.
    metadata : `bool`, optional
        If true, also pass a new, empty `lsst.daf.base.PropertySet`.
    centroid : `str`, optional
        If given, add ``<centroid>_x``, ``<centroid>_y`` and
        ``<centroid>_flag`` fields and alias ``slot_Centroid`` to it.

    Returns
    -------
    plugin
        The constructed plugin.
    cat : `lsst.afw.table.SourceCatalog`
        An empty catalog built on the same schema object the plugin was
        constructed with.
    """
    schema = afwTable.SourceTable.makeMinimalSchema()
    if centroid:
        for suffix, fieldType in (("_x", float), ("_y", float), ("_flag", 'Flag')):
            schema.addField(centroid + suffix, type=fieldType)
        schema.getAliasMap().set("slot_Centroid", centroid)

    pluginArgs = [control, name, schema]
    if metadata:
        pluginArgs.append(dafBase.PropertySet())
    plugin = alg(*pluginArgs)

    return plugin, afwTable.SourceCatalog(schema)

60 

61 

def createDipole(w, h, xc, yc, scaling=100.0, fracOffset=1.2):
    """Make a noisy exposure containing a single synthetic dipole.

    A positive and a negative copy of a scaled
    `~lsst.meas.algorithms.DoubleGaussianPsf` image are added to a
    unit-variance Gaussian noise image. They are displaced diagonally on
    either side of ``(xc, yc)``. Detection is then run for both polarities
    and the resulting footprints are merged into a single source.

    Parameters
    ----------
    w, h : `int`
        Width and height of the image in pixels.
    xc, yc : `int` or `float`
        Nominal center of the dipole, in pixels.
    scaling : `float`, optional
        Each lobe is the PSF image divided by ``psf.computePeak`` and
        multiplied by ``scaling``.
    fracOffset : `float`, optional
        Controls the lobe separation in units of the PSF FWHM. Each lobe is
        shifted by ``floor(fracOffset*FWHM/2)`` pixels in both x and y.

    Returns
    -------
    psf : `lsst.meas.algorithms.DoubleGaussianPsf`
        PSF used to build the dipole; also attached to ``exp``.
    psfSum : `float`
        Sum of the scaled PSF image that was added and subtracted.
    exp : `lsst.afw.image.Exposure`
        Exposure containing the noise plus the dipole.
    s : `lsst.afw.table.SourceRecord`
        The merged source; its footprint has exactly two peaks.

    Raises
    ------
    AssertionError
        If detection does not yield exactly two sources, or if they do not
        merge into a single footprint with two peaks.
    """
    # Noise image: zero every plane, fill the image plane with standard-normal
    # noise and set the variance to 1.  afw image arrays are indexed [y, x],
    # so the noise must have shape (h, w); drawing (w, h) fails when w != h.
    image = afwImage.MaskedImageF(w, h)
    image.set(0)
    array = image.image.array
    array[:, :] = np.random.randn(h, w)
    image.getVariance().set(1.0)

    if display:
        afwDisplay.Display(frame=1).mtv(image, title="Original image")
        afwDisplay.Display(frame=2).mtv(image.getVariance(), title="Original variance")

    # Create Psf for dipole creation and measurement
    psfSize = 17
    psf = measAlg.DoubleGaussianPsf(psfSize, psfSize, 2.0, 3.5, 0.1)
    pos = psf.getAveragePosition()
    psfFwhmPix = sigma2fwhm*psf.computeShape(pos).getDeterminantRadius()
    psfim = psf.computeImage(pos).convertF()
    psfim *= scaling/psf.computePeak(pos)
    psfw, psfh = psfim.getDimensions()
    psfSum = np.sum(psfim.array)

    # Shift each lobe diagonally away from (xc, yc).  "*" and "//" bind left
    # to right, so this is floor(fracOffset*FWHM/2) pixels.
    offset = fracOffset*psfFwhmPix//2
    xp = int(xc - psfw//2 + offset)
    yp = int(yc - psfh//2 + offset)
    array[yp:yp + psfh, xp:xp + psfw] += psfim.array

    xn = int(xc - psfw//2 - offset)
    yn = int(yc - psfh//2 - offset)
    array[yn:yn + psfh, xn:xn + psfw] -= psfim.array

    if display:
        afwDisplay.Display(frame=3).mtv(image, title="With dipole")

    # Create an exposure, detect positive and negative peaks separately
    exp = afwImage.makeExposure(image)
    exp.setPsf(psf)
    config = measAlg.SourceDetectionConfig()
    config.thresholdPolarity = "both"
    config.reEstimateBackground = False
    schema = afwTable.SourceTable.makeMinimalSchema()
    task = measAlg.SourceDetectionTask(schema, config=config)
    table = afwTable.SourceTable.make(schema)
    results = task.run(table, exp)
    if display:
        afwDisplay.Display(frame=4).mtv(image, title="Detection plane")

    # Merge them together.  Explicit raises (rather than ``assert``) keep
    # these fixture sanity checks active under ``python -O``.
    if len(results.sources) != 2:
        raise AssertionError(f"Expected 2 detected sources, got {len(results.sources)}")
    fpSet = results.positive
    fpSet.merge(results.negative, 0, 0, False)
    sources = afwTable.SourceCatalog(table)
    fpSet.makeSources(sources)
    if len(sources) != 1:
        raise AssertionError(f"Expected 1 merged source, got {len(sources)}")
    s = sources[0]
    nPeaks = len(s.getFootprint().getPeaks())
    if nPeaks != 2:
        raise AssertionError(f"Expected 2 peaks in merged footprint, got {nPeaks}")

    return psf, psfSum, exp, s

124 

125 

class DipoleAlgorithmTest(lsst.utils.tests.TestCase):
    """A test case for the individual dipole algorithms."""

    def setUp(self):
        np.random.seed(666)
        self.w, self.h = 100, 100  # size of image
        self.xc, self.yc = 50, 50  # location of center of dipole

    def testPsfDipoleFluxControl(self):
        """Check that PsfDipoleFlux output fields can be read after measuring."""
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        control = ipDiffim.PsfDipoleFluxControl()
        plugin, cat = makePluginAndCat(ipDiffim.PsfDipoleFlux, "test", control, centroid="centroid")
        source = cat.addNew()
        source.set("centroid_x", self.xc)
        source.set("centroid_y", self.yc)
        source.setFootprint(s.getFootprint())
        plugin.measure(source, exposure)
        for key in ("_pos_instFlux", "_pos_instFluxErr", "_pos_flag",
                    "_neg_instFlux", "_neg_instFluxErr", "_neg_flag"):
            # Report an unreadable field as a test failure that names the
            # field, rather than as a bare failure with no context.
            try:
                source.get("test" + key)
            except Exception as e:
                self.fail(f"Could not read field 'test{key}': {e}")

    def testAll(self):
        """Smoke test: run DipoleMeasurementTask on a synthetic dipole."""
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        self.measureDipole(s, exposure)

    def _makeModel(self, exposure, psf, fp, negCenter, posCenter):
        """Fit a two-PSF dipole model to the data by linear least squares.

        Peak-normalized PSF images evaluated at ``negCenter`` and
        ``posCenter`` are placed in the bounding box of ``fp``. Their
        amplitudes are solved for with both `numpy.linalg.lstsq` and
        `lsst.afw.math.LeastSquares`, and the two solutions are asserted to
        agree.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure whose image and variance planes are fitted.
        psf
            PSF used to build the two model components.
        fp : `lsst.afw.detection.Footprint`
            Footprint whose bounding box defines the fitted region.
        negCenter, posCenter : `lsst.geom.Point2D`
            Centers of the negative and positive model components.

        Returns
        -------
        fneg, fpos : `float`
            Fitted amplitudes of the negative and positive components.
        negPsfSum, posPsfSum : `float`
            Sums of the peak-normalized PSF images; ``amplitude * sum`` is
            the corresponding total flux.
        fitted : `lsst.afw.image.ImageF`
            Per-pixel chi^2 image, ``(model - data)**2 / variance``.
        """
        negPsf = psf.computeImage(negCenter).convertF()
        posPsf = psf.computeImage(posCenter).convertF()
        negPeak = psf.computePeak(negCenter)
        posPeak = psf.computePeak(posCenter)
        negPsf /= negPeak
        posPsf /= posPeak

        model = afwImage.ImageF(fp.getBBox())
        negModel = afwImage.ImageF(fp.getBBox())
        posModel = afwImage.ImageF(fp.getBBox())

        # The center of the Psf should be at negCenter, posCenter
        negPsfBBox = negPsf.getBBox()
        posPsfBBox = posPsf.getBBox()
        modelBBox = model.getBBox()

        # Portion of the negative Psf that overlaps the montage
        negOverlapBBox = geom.Box2I(negPsfBBox)
        negOverlapBBox.clip(modelBBox)
        self.assertFalse(negOverlapBBox.isEmpty())

        # Portion of the positive Psf that overlaps the montage
        posOverlapBBox = geom.Box2I(posPsfBBox)
        posOverlapBBox.clip(modelBBox)
        self.assertFalse(posOverlapBBox.isEmpty())

        negPsfSubim = type(negPsf)(negPsf, negOverlapBBox)
        modelSubim = type(model)(model, negOverlapBBox)
        negModelSubim = type(negModel)(negModel, negOverlapBBox)
        modelSubim += negPsfSubim  # just for debugging
        negModelSubim += negPsfSubim  # for fitting

        posPsfSubim = type(posPsf)(posPsf, posOverlapBBox)
        modelSubim = type(model)(model, posOverlapBBox)
        posModelSubim = type(posModel)(posModel, posOverlapBBox)
        modelSubim += posPsfSubim
        posModelSubim += posPsfSubim

        data = afwImage.ImageF(exposure.image, fp.getBBox())
        var = afwImage.ImageF(exposure.variance, fp.getBBox())
        # Scale design matrix and data by 1/sigma (median variance).
        matrixNorm = 1./np.sqrt(np.median(var.array))

        if display:
            afwDisplay.Display(frame=5).mtv(model, title="Unfitted model")
            afwDisplay.Display(frame=6).mtv(data, title="Data")

        posPsfSum = np.sum(posPsf.array)
        negPsfSum = np.sum(negPsf.array)

        # Columns of M are the flattened negative and positive model images.
        M = np.array((np.ravel(negModel.array), np.ravel(posModel.array))).T.astype(np.float64)
        B = np.array((np.ravel(data.array))).astype(np.float64)
        M *= matrixNorm
        B *= matrixNorm

        # Numpy solution
        fneg0, fpos0 = np.linalg.lstsq(M, B, rcond=-1)[0]

        # Afw solution
        lsq = afwMath.LeastSquares.fromDesignMatrix(M, B, afwMath.LeastSquares.DIRECT_SVD)
        fneg, fpos = lsq.getSolution()

        # The two solvers should agree essentially exactly; values are scaled
        # by 1e-2 before the (7-place) comparison.
        self.assertAlmostEqual(1e-2*fneg0, 1e-2*fneg)
        self.assertAlmostEqual(1e-2*fpos0, 1e-2*fpos)

        # Recreate model
        fitted = afwImage.ImageF(fp.getBBox())
        negFit = type(negPsf)(negPsf, negOverlapBBox, afwImage.PARENT, True)
        negFit *= float(fneg)
        posFit = type(posPsf)(posPsf, posOverlapBBox, afwImage.PARENT, True)
        posFit *= float(fpos)

        fitSubim = type(fitted)(fitted, negOverlapBBox)
        fitSubim += negFit
        fitSubim = type(fitted)(fitted, posOverlapBBox)
        fitSubim += posFit
        if display:
            afwDisplay.Display(frame=7).mtv(fitted, title="Fitted model")

        fitted -= data

        if display:
            afwDisplay.Display(frame=8).mtv(fitted, title="Residuals")

        fitted *= fitted
        fitted /= var

        if display:
            afwDisplay.Display(frame=9).mtv(fitted, title="Chi2")

        return fneg, negPsfSum, fpos, posPsfSum, fitted

    def testPsfDipoleFit(self, scaling=100.):
        """Compare PsfDipoleFlux outputs with an independent least-squares fit."""
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc, scaling=scaling)
        source = self.measureDipole(s, exposure)
        # Recreate the simultaneous joint Psf fit in python
        fp = source.getFootprint()
        # Lowest-valued peak is the negative lobe, highest the positive one.
        # Sorting on an explicit key never compares PeakRecord objects, which
        # the former sort of (value, peak) tuples would do on a tie.
        peaks = sorted(fp.getPeaks(), key=lambda p: p.getPeakValue())
        negPeak, posPeak = peaks[0], peaks[-1]

        negCenter = geom.Point2D(negPeak.getFx(), negPeak.getFy())
        posCenter = geom.Point2D(posPeak.getFx(), posPeak.getFy())

        fneg, negPsfSum, fpos, posPsfSum, residIm = self._makeModel(exposure, psf, fp, negCenter, posCenter)

        # Should be close to the same as the inputs; as fracOffset
        # gets smaller this will be worse. This works for scaling =
        # 100.
        self.assertAlmostEqual(1e-2*scaling, -1e-2*fneg, 2)
        self.assertAlmostEqual(1e-2*scaling, 1e-2*fpos, 2)

        # Now compare the LeastSquares results fitted here to the C++
        # implementation: Since total flux is returned, and this is of
        # order 1e4 for this default test, scale back down so that
        # assertAlmostEqual behaves reasonably (the comparison to 2
        # places means to 0.01). Also note that PsfDipoleFlux returns
        # the total flux, while here we are just fitting for the
        # scaling of the Psf. Therefore the comparison is
        # fneg*negPsfSum to flux.dipole.psf.neg.
        self.assertAlmostEqual(1e-4*fneg*negPsfSum,
                               1e-4*source.get("ip_diffim_PsfDipoleFlux_neg_instFlux"),
                               2)
        self.assertAlmostEqual(1e-4*fpos*posPsfSum,
                               1e-4*source.get("ip_diffim_PsfDipoleFlux_pos_instFlux"),
                               2)

        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_pos_instFluxErr"), 0.0)
        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_neg_instFluxErr"), 0.0)
        self.assertFalse(source.get("ip_diffim_PsfDipoleFlux_neg_flag"))
        self.assertFalse(source.get("ip_diffim_PsfDipoleFlux_pos_flag"))

        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_x"), 50.0, 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_y"), 50.0, 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_neg_centroid_x"), negCenter[0], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_neg_centroid_y"), negCenter[1], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_pos_centroid_x"), posCenter[0], 1)
        self.assertAlmostEqual(source.get("ip_diffim_PsfDipoleFlux_pos_centroid_y"), posCenter[1], 1)

        self.assertGreater(source.get("ip_diffim_PsfDipoleFlux_chi2dof"), 0.0)

    def measureDipole(self, s, exp):
        """Run DipoleMeasurementTask on a new record carrying ``s``'s footprint.

        Parameters
        ----------
        s : `lsst.afw.table.SourceRecord`
            Source whose footprint is measured.
        exp : `lsst.afw.image.Exposure`
            Exposure to measure on.

        Returns
        -------
        source : `lsst.afw.table.SourceRecord`
            The measured record; its centroid slot is set to
            ``(self.xc, self.yc)`` before measurement.
        """
        msConfig = ipDiffim.DipoleMeasurementConfig()
        schema = afwTable.SourceTable.makeMinimalSchema()
        schema.addField("centroid_x", type=float)
        schema.addField("centroid_y", type=float)
        schema.addField("centroid_flag", type='Flag')
        task = ipDiffim.DipoleMeasurementTask(schema, config=msConfig)
        measCat = afwTable.SourceCatalog(schema)
        measCat.defineCentroid("centroid")
        source = measCat.addNew()
        source.set("centroid_x", self.xc)
        source.set("centroid_y", self.yc)
        source.setFootprint(s.getFootprint())
        # Run the measurement task; callers inspect the returned record.
        task.run(measCat, exp)
        return measCat[0]

    def testDipoleAnalysis(self):
        """Smoke test: DipoleAnalysis runs on a measured dipole."""
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        source = self.measureDipole(s, exposure)
        dpAnalysis = ipDiffim.DipoleAnalysis()
        dpAnalysis(source)

    def testDipoleDeblender(self):
        """Smoke test: DipoleDeblender runs on a measured dipole."""
        psf, psfSum, exposure, s = createDipole(self.w, self.h, self.xc, self.yc)
        source = self.measureDipole(s, exposure)
        dpDeblender = ipDiffim.DipoleDeblender()
        dpDeblender(source, exposure)

330 

331 

class DipoleMeasurementTaskTest(lsst.utils.tests.TestCase):
    """A test case for DipoleMeasurementTask as a whole.

    Only the dipole classification flag is checked here, because the
    individual algorithms are tested in `DipoleAlgorithmTest`.
    """

    def setUp(self):
        np.random.seed(666)
        self.config = ipDiffim.DipoleMeasurementConfig()

    def tearDown(self):
        del self.config

    def testMeasure(self):
        """A synthetic dipole should be classified as a dipole (value 1.0)."""
        schema = afwTable.SourceTable.makeMinimalSchema()
        task = ipDiffim.DipoleMeasurementTask(schema, config=self.config)
        sources = afwTable.SourceCatalog(afwTable.SourceTable.make(schema))
        record = sources.addNew()

        # Synthetic 100x100 image with a dipole centered at (50, 50).
        _, _, exposure, dipoleSource = createDipole(100, 100, 50, 50)

        # Measure the new record using the detected dipole's footprint.
        record.setFootprint(dipoleSource.getFootprint())
        task.run(sources, exposure)
        self.assertEqual(record.get("ip_diffim_ClassificationDipole_value"), 1.0)

357 

358 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the shared checks from ``lsst.utils.tests.MemoryTestCase``.

    Presumably these are memory-leak checks; no tests are added here.
    """

361 

362 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    The ``module`` argument is required by the hook signature but unused.
    """
    lsst.utils.tests.init()

365 

366 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then run unittest.
    lsst.utils.tests.init()
    unittest.main()