Coverage for tests/test_strayFlux.py: 9%

242 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-29 11:19 +0000

1# 

2# LSST Data Management System 

3# 

4# Copyright 2008-2016 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23import unittest 

24import numpy as np 

25from functools import reduce 

26 

27import lsst.utils.tests 

28import lsst.afw.detection as afwDet 

29import lsst.geom as geom 

30import lsst.afw.image as afwImage 

31from lsst.log import Log 

32from lsst.meas.deblender.baseline import deblend 

33import lsst.meas.algorithms as measAlg 

34 

# Set to True to write diagnostic PNG plots next to this file.
doPlot = False
if doPlot:
    import matplotlib
    # Select a non-interactive backend before pyplot is imported so the
    # plots can be written without a display.
    matplotlib.use('Agg')
    # Use the pyplot interface; the pylab namespace is discouraged by
    # matplotlib.
    import matplotlib.pyplot as plt
    import os.path
    plotpat = os.path.join(os.path.dirname(__file__), 'stray%i.png')
    print('Writing plots to', plotpat)
else:
    print('"doPlot" not set -- not making plots. To enable plots, edit', __file__)

# Lower the level to Log.DEBUG to see debug messages
Log.getLogger('lsst.meas.deblender.symmetrizeFootprint').setLevel(Log.INFO)

48 

49 

def imExt(img):
    """Return ``[minX, maxX, minY, maxY]`` of ``img``'s bounding box.

    The bounds are the inclusive min/max pixel indices reported by
    ``img.getBBox()``. Callers pass the result as a matplotlib ``extent``
    or ``plt.axis`` argument.
    """
    box = img.getBBox()
    xRange = [box.getMinX(), box.getMaxX()]
    yRange = [box.getMinY(), box.getMaxY()]
    return xRange + yRange

54 

55 

def doubleGaussianPsf(W, H, fwhm1, fwhm2, a2):
    """Build a ``measAlg.DoubleGaussianPsf`` with a ``W`` x ``H`` kernel.

    All five arguments are forwarded positionally, unchanged.
    NOTE(review): the parameters are named as FWHMs, but they go straight
    into the DoubleGaussianPsf constructor. Confirm whether that constructor
    expects FWHM or sigma values.
    """
    psf = measAlg.DoubleGaussianPsf(W, H, fwhm1, fwhm2, a2)
    return psf

58 

59 

def gaussianPsf(W, H, fwhm):
    """Build a PSF via ``measAlg.DoubleGaussianPsf`` with one width argument.

    Only the kernel size and the first width are supplied. The remaining
    constructor arguments keep their defaults, which presumably disable the
    second Gaussian (verify against the meas_algorithms API).
    """
    psf = measAlg.DoubleGaussianPsf(W, H, fwhm)
    return psf

62 

63 

class StrayFluxTestCase(lsst.utils.tests.TestCase):
    """Tests for the "stray flux" assignment done by the baseline deblender."""

    def test1(self):
        """A simple example: three overlapping blobs (detected as 1
        footprint with three peaks). We artificially omit one of the
        peaks, meaning that its flux is "stray". Assert that the
        stray flux assigned to the other two peaks accounts for all
        the flux in the parent.
        """
        H, W = 100, 100

        fpbb = geom.Box2I(geom.Point2I(0, 0),
                          geom.Point2I(W-1, H-1))

        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 10.
        blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 3.*blob_fwhm, 0.03)

        fakepsf_fwhm = 3.
        fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

        # Render each blob on its own (kept only for plotting) and add it
        # into the image. The PSF stamp is clipped to the image bounds.
        blobimgs = []
        x = 75.
        XY = [(x, 35.), (x, 65.), (50., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(geom.Point2D(x, y))
            bbb = bim.getBBox()
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            bim2 = bim.getArray()

            blobimg = np.zeros_like(img)
            blobimg[bbb.getMinY():bbb.getMaxY()+1,
                    bbb.getMinX():bbb.getMaxX()+1] += flux * bim2
            blobimgs.append(blobimg)

            img[bbb.getMinY():bbb.getMaxY()+1,
                bbb.getMinX():bbb.getMaxX()+1] += flux * bim2

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        print('found', len(fps), 'footprints')
        # The blobs overlap, so detection should yield a single parent
        # footprint. Check this explicitly instead of failing on fps[0].
        self.assertEqual(len(fps), 1)
        for fp in fps:
            print('peaks:', len(fp.getPeaks()))
            for pk in fp.getPeaks():
                print('  ', pk.getIx(), pk.getIy())

        # The first peak in this list is the one we want to omit.
        fp0 = fps[0]
        fakefp = afwDet.Footprint(fp0.getSpans(), fp0.getBBox())
        for pk in fp0.getPeaks()[1:]:
            fakefp.getPeaks().append(pk)

        ima = dict(interpolation='nearest', origin='lower', cmap='gray',
                   vmin=0, vmax=1e3)

        if doPlot:
            plt.figure(figsize=(12, 6))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 input')
            plt.subplot(2, 2, 1)
            plt.title('Image')
            plt.imshow(img, **ima)
            ax = plt.axis()
            plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
            plt.axis(ax)
            for i, (b, (x, y)) in enumerate(zip(blobimgs, XY)):
                plt.subplot(2, 2, 2+i)
                plt.title('Blob %i' % i)
                plt.imshow(b, **ima)
                ax = plt.axis()
                plt.plot(x, y, 'r.')
                plt.axis(ax)
            plt.savefig(plotpat % 1)

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fakefp, afwimg, fakepsf, fakepsf_fwhm, verbose=True)
        parent_img = afwImage.ImageF(fpbb)
        fakefp.spans.copyImage(afwimg.getImage(), parent_img)
        # Children of the single deblended parent. Both the plots and the
        # flux-conservation check below use this same list.
        children = deb.deblendedParents[0].peaks

        if doPlot:
            def myimshow(*args, **kwargs):
                plt.imshow(*args, **kwargs)
                plt.xticks([])
                plt.yticks([])
                plt.axis(imExt(afwimg))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 results')
            R, C = 3, 4
            plt.subplot(R, C, (2*C) + 1)
            plt.title('Image')
            myimshow(img, **ima)
            ax = plt.axis()
            plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
            plt.axis(ax)

            plt.subplot(R, C, (2*C) + 2)
            plt.title('Parent footprint')
            myimshow(parent_img.getArray(), **ima)
            ax = plt.axis()
            plt.plot([pk.getIx() for pk in fakefp.getPeaks()],
                     [pk.getIy() for pk in fakefp.getPeaks()], 'r.')
            plt.axis(ax)

            # NOTE(review): this path only runs with doPlot = True, so it is
            # not exercised by the normal test run.
            plotsum = np.zeros_like(parent_img.getArray())
            for i, dpk in enumerate(children):
                plt.subplot(R, C, i*C + 1)
                plt.title('ch%i symm' % i)
                symm = dpk.templateImage
                myimshow(symm.getArray(), extent=imExt(symm), **ima)

                plt.subplot(R, C, i*C + 2)
                plt.title('ch%i portion' % i)
                port = dpk.fluxPortion.getImage()
                myimshow(port.getArray(), extent=imExt(port), **ima)

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)

                plt.subplot(R, C, i*C + 3)
                plt.title('ch%i stray' % i)
                myimshow(simg.getArray(), **ima)
                ax = plt.axis()
                plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
                plt.axis(ax)

                himg2 = afwImage.ImageF(fpbb)
                heavy = dpk.getFluxPortion(strayFlux=True)
                heavy.insert(himg2)
                plotsum += himg2.getArray()

                plt.subplot(R, C, i*C + 4)
                myimshow(himg2.getArray(), **ima)
                plt.title('ch%i total' % i)
                ax = plt.axis()
                plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
                plt.axis(ax)

            plt.subplot(R, C, (2*C) + C)
            myimshow(plotsum, **ima)
            ax = plt.axis()
            plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
            plt.axis(ax)
            plt.title('Sum of deblends')

            plt.savefig(plotpat % 2)

        # Compute the sum-of-children image. Starting from zeros means that
        # an empty child list produces a clean assertion failure below
        # rather than a TypeError on None.
        sumimg = np.zeros_like(parent_img.getArray())
        for dpk in children:
            himg2 = afwImage.ImageF(fpbb)
            dpk.getFluxPortion().insert(himg2)
            sumimg += himg2.getArray()

        # Sum of children ~= Original image inside footprint (parent_img)
        absdiff = np.max(np.abs(sumimg - parent_img.getArray()))
        print('Max abs diff:', absdiff)
        imgmax = parent_img.getArray().max()
        print('Img max:', imgmax)
        self.assertLess(absdiff, imgmax*1e-6)

    def test2(self):
        """A 1-d example, to test the stray-flux assignment.
        """
        H, W = 1, 100

        fpbb = geom.Box2I(geom.Point2I(0, 0),
                          geom.Point2I(W-1, H-1))
        afwimg = afwImage.MaskedImageF(fpbb)
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        # A flat plateau of 10 with a brighter pixel (20) just inside each
        # end. Those two pixels are meant to be the two peaks.
        y = 0
        img[y, 1:-1] = 10.

        img[y, 1] = 20.
        img[y, -2] = 20.

        fakepsf_fwhm = 1.
        fakepsf = gaussianPsf(1, 1, fakepsf_fwhm)

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        self.assertEqual(len(fps), 1)
        fp = fps[0]

        # WORKAROUND: the detection alg produces ONE peak, at (1,0),
        # rather than two.
        self.assertEqual(len(fp.getPeaks()), 1)
        fp.addPeak(W-2, y, float("NaN"))

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fp, afwimg, fakepsf, fakepsf_fwhm, verbose=True,
                      fitPsfs=False)
        # Children of the single deblended parent, shared by the plots and
        # the checks below.
        children = deb.deblendedParents[0].peaks

        if doPlot:
            XX = np.arange(W+1).repeat(2)[1:-1]

            plt.clf()
            p1 = plt.plot(XX, img[y, :].repeat(2), 'g-', lw=3, alpha=0.3)

            for dpk in children:
                print(dpk)
                port = dpk.fluxPortion.getImage()
                bb = port.getBBox()
                YY = np.zeros(XX.shape)
                YY[bb.getMinX()*2: (bb.getMaxX()+1)*2] = port.getArray()[0, :].repeat(2)
                p2 = plt.plot(XX, YY, 'r-')

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)
                p3 = plt.plot(XX, simg.getArray()[y, :].repeat(2), 'b-')

            plt.legend((p1[0], p2[0], p3[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 3)

        strays = []
        for dpk in children:
            simg = afwImage.ImageF(fpbb)
            dpk.strayFlux.insert(simg)
            strays.append(simg.getArray())
        # The per-child checks below index strays[0] and strays[1].
        self.assertEqual(len(strays), 2)

        ssum = reduce(np.add, strays)

        starget = np.zeros(W)
        starget[2:-2] = 10.

        self.assertFloatsEqual(ssum, starget)

        # Expected split of the stray flux. Each child's weight falls off as
        # 1 / (1 + dx^2) from its peak. A weight below strayclip times the
        # summed weight is zeroed. Pixels from each child's peak out to its
        # image edge receive no stray flux.
        X = np.arange(W)
        dx1 = X - 1.
        dx2 = X - (W-2)
        f1 = (1. / (1. + dx1**2))
        f2 = (1. / (1. + dx2**2))
        strayclip = 0.001
        fsum = f1 + f2
        f1[f1 < strayclip * fsum] = 0.
        f2[f2 < strayclip * fsum] = 0.

        s1 = f1 / (f1+f2) * 10.
        s2 = f2 / (f1+f2) * 10.

        s1[:2] = 0.
        s2[-2:] = 0.

        if doPlot:
            p4 = plt.plot(XX, s1.repeat(2), 'm-')
            plt.plot(XX, s2.repeat(2), 'm-')

            plt.legend((p1[0], p2[0], p3[0], p4[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux',
                        'Expected stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 4)

        # test abs diff
        d = np.max(np.abs(s1 - strays[0]))
        self.assertLess(d, 1e-6)
        d = np.max(np.abs(s2 - strays[1]))
        self.assertLess(d, 1e-6)

        # test relative diff
        self.assertLess(np.max(np.abs(s1 - strays[0])/np.maximum(1e-3, s1)), 1e-6)
        self.assertLess(np.max(np.abs(s2 - strays[1])/np.maximum(1e-3, s2)), 1e-6)

371 

372 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks from ``lsst.utils.tests.MemoryTestCase`` unchanged.

    No additional tests are defined here.
    """

375 

376 

def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities. The ``module`` argument is
    accepted for the hook signature but is not used.
    """
    lsst.utils.tests.init()

379 

380 

if __name__ == "__main__":
    # Initialize the LSST test utilities before handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()