Coverage for tests/test_strayFlux.py: 9%
242 statements
« prev ^ index » next coverage.py v7.2.6, created at 2023-05-27 09:50 +0000
1#
2# LSST Data Management System
3#
4# Copyright 2008-2016 AURA/LSST.
5#
6# This product includes software developed by the
7# LSST Project (http://www.lsst.org/).
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the LSST License Statement and
20# the GNU General Public License along with this program. If not,
21# see <https://www.lsstcorp.org/LegalNotices/>.
22#
23import unittest
24import numpy as np
25from functools import reduce
27import lsst.utils.tests
28import lsst.afw.detection as afwDet
29import lsst.geom as geom
30import lsst.afw.image as afwImage
31from lsst.log import Log
32from lsst.meas.deblender.baseline import deblend
33import lsst.meas.algorithms as measAlg
# Set to True to write diagnostic PNG plots next to this file.
doPlot = False
if doPlot:
    import os.path

    import matplotlib
    # Select a non-interactive backend before pyplot is imported.
    matplotlib.use('Agg')
    # matplotlib.pyplot exposes every plotting call used in this file;
    # the pylab convenience module is discouraged by matplotlib.
    import matplotlib.pyplot as plt
    # Output filename pattern; '%i' is filled in with a plot index.
    plotpat = os.path.join(os.path.dirname(__file__), 'stray%i.png')
    print('Writing plots to', plotpat)
else:
    print('"doPlot" not set -- not making plots. To enable plots, edit', __file__)

# Lower the level to Log.DEBUG to see debug messages
Log.getLogger('lsst.meas.deblender.symmetrizeFootprint').setLevel(Log.INFO)
def imExt(img):
    """Return ``img``'s bounding box as ``[xmin, xmax, ymin, ymax]``.

    This ordering is what the plotting code passes to ``imshow(extent=...)``
    and ``plt.axis(...)``. The values are the min/max pixel coordinates
    reported by ``img.getBBox()``.
    """
    box = img.getBBox()
    getters = (box.getMinX, box.getMaxX, box.getMinY, box.getMaxY)
    return [get() for get in getters]
def doubleGaussianPsf(W, H, fwhm1, fwhm2, a2):
    """Build a ``measAlg.DoubleGaussianPsf`` with a ``W`` x ``H`` kernel.

    ``fwhm1``, ``fwhm2`` and ``a2`` are forwarded positionally and
    unchanged to the constructor.

    NOTE(review): despite the "fwhm" names, these values go straight into
    the constructor's positional slots, which are presumably Gaussian
    widths (sigma) plus the second component's amplitude. Confirm against
    the ``DoubleGaussianPsf`` API before relying on the FWHM naming.
    """
    psf = measAlg.DoubleGaussianPsf(W, H, fwhm1, fwhm2, a2)
    return psf
def gaussianPsf(W, H, fwhm):
    """Build a ``measAlg.DoubleGaussianPsf`` using only its first width.

    Only ``W``, ``H`` and ``fwhm`` are passed, so the second component is
    left at the constructor's defaults. This is presumably a single
    Gaussian, assuming those defaults disable the second component.

    NOTE(review): ``fwhm`` is forwarded as-is into the first width slot,
    which is presumably a sigma rather than a FWHM. Confirm against the
    ``DoubleGaussianPsf`` API.
    """
    single = measAlg.DoubleGaussianPsf(W, H, fwhm)
    return single
class StrayFluxTestCase(lsst.utils.tests.TestCase):
    """Checks stray-flux assignment of the baseline ``deblend`` function.

    Each test builds a synthetic masked image and runs detection to get a
    footprint. It then deblends that footprint and compares the children's
    flux / stray flux against values computed directly from the inputs.
    """

    def test1(self):
        """A simple example: three overlapping blobs (detected as 1
        footprint with three peaks). We artificially omit one of the
        peaks, meaning that its flux is "stray". Assert that the
        stray flux assigned to the other two peaks accounts for all
        the flux in the parent.
        """
        H, W = 100, 100

        fpbb = geom.Box2I(geom.Point2I(0, 0),
                          geom.Point2I(W-1, H-1))

        afwimg = afwImage.MaskedImageF(fpbb)
        imgbb = afwimg.getBBox()
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        blob_fwhm = 10.
        blob_psf = doubleGaussianPsf(99, 99, blob_fwhm, 3.*blob_fwhm, 0.03)

        fakepsf_fwhm = 3.
        fakepsf = gaussianPsf(11, 11, fakepsf_fwhm)

        # Paint three PSF-shaped blobs of total scale `flux` into `img`.
        # Each blob is also kept on its own (in `blobimgs`) for plotting.
        blobimgs = []
        x = 75.
        XY = [(x, 35.), (x, 65.), (50., 50.)]
        flux = 1e6
        for x, y in XY:
            bim = blob_psf.computeImage(geom.Point2D(x, y))
            bbb = bim.getBBox()
            # Keep only the part of the PSF stamp that lies on the image.
            bbb.clip(imgbb)

            bim = bim.Factory(bim, bbb)
            bim2 = bim.getArray()

            blobimg = np.zeros_like(img)
            blobimg[bbb.getMinY():bbb.getMaxY()+1,
                    bbb.getMinX():bbb.getMaxX()+1] += flux * bim2
            blobimgs.append(blobimg)

            img[bbb.getMinY():bbb.getMaxY()+1,
                bbb.getMinX():bbb.getMaxX()+1] += flux * bim2

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        print('found', len(fps), 'footprints')
        for fp in fps:
            print('peaks:', len(fp.getPeaks()))
            for pk in fp.getPeaks():
                print(' ', pk.getIx(), pk.getIy())

        # The first peak in this list is the one we want to omit.
        # The copied footprint keeps the same spans and bbox but only the
        # remaining peaks, so the omitted peak's flux becomes "stray".
        fp0 = fps[0]
        fakefp = afwDet.Footprint(fp0.getSpans(), fp0.getBBox())
        for pk in fp0.getPeaks()[1:]:
            fakefp.getPeaks().append(pk)

        # Common imshow keyword arguments for the diagnostic plots.
        ima = dict(interpolation='nearest', origin='lower', cmap='gray',
                   vmin=0, vmax=1e3)

        if doPlot:
            plt.figure(figsize=(12, 6))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 input')
            plt.subplot(2, 2, 1)
            plt.title('Image')
            plt.imshow(img, **ima)
            ax = plt.axis()
            plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
            plt.axis(ax)
            for i, (b, (x, y)) in enumerate(zip(blobimgs, XY)):
                plt.subplot(2, 2, 2+i)
                plt.title('Blob %i' % i)
                plt.imshow(b, **ima)
                ax = plt.axis()
                plt.plot(x, y, 'r.')
                plt.axis(ax)
            plt.savefig(plotpat % 1)

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fakefp, afwimg, fakepsf, fakepsf_fwhm, verbose=True)
        # Parent image: the input image restricted to the footprint's spans.
        parent_img = afwImage.ImageF(fpbb)
        fakefp.spans.copyImage(afwimg.getImage(), parent_img)

        if doPlot:
            # NOTE(review): this diagnostic path reads `deb.peaks`,
            # `dpk.templateImage` and `dpk.fluxPortion`, while the
            # assertions below use `deb.deblendedParents[0].peaks`. It only
            # runs when doPlot is True, so confirm it still matches the
            # current deblender result API before enabling it.
            def myimshow(*args, **kwargs):
                plt.imshow(*args, **kwargs)
                plt.xticks([])
                plt.yticks([])
                plt.axis(imExt(afwimg))

            plt.clf()
            plt.suptitle('strayFlux.py: test1 results')
            R, C = 3, 4
            plt.subplot(R, C, (2*C) + 1)
            plt.title('Image')
            myimshow(img, **ima)
            ax = plt.axis()
            plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
            plt.axis(ax)

            plt.subplot(R, C, (2*C) + 2)
            plt.title('Parent footprint')
            myimshow(parent_img.getArray(), **ima)
            ax = plt.axis()
            plt.plot([pk.getIx() for pk in fakefp.getPeaks()],
                     [pk.getIy() for pk in fakefp.getPeaks()], 'r.')
            plt.axis(ax)

            sumimg = None
            for i, dpk in enumerate(deb.peaks):
                plt.subplot(R, C, i*C + 1)
                plt.title('ch%i symm' % i)
                symm = dpk.templateImage
                myimshow(symm.getArray(), extent=imExt(symm), **ima)

                plt.subplot(R, C, i*C + 2)
                plt.title('ch%i portion' % i)
                port = dpk.fluxPortion.getImage()
                myimshow(port.getArray(), extent=imExt(port), **ima)

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)

                plt.subplot(R, C, i*C + 3)
                plt.title('ch%i stray' % i)
                myimshow(simg.getArray(), **ima)
                ax = plt.axis()
                plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
                plt.axis(ax)

                himg2 = afwImage.ImageF(fpbb)
                heavy = dpk.getFluxPortion(strayFlux=True)
                heavy.insert(himg2)

                if sumimg is None:
                    sumimg = himg2.getArray().copy()
                else:
                    sumimg += himg2.getArray()

                plt.subplot(R, C, i*C + 4)
                myimshow(himg2.getArray(), **ima)
                plt.title('ch%i total' % i)
                ax = plt.axis()
                plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
                plt.axis(ax)

            plt.subplot(R, C, (2*C) + C)
            myimshow(sumimg, **ima)
            ax = plt.axis()
            plt.plot([x for x, y in XY], [y for x, y in XY], 'r.')
            plt.axis(ax)
            plt.title('Sum of deblends')

            plt.savefig(plotpat % 2)

        # Compute the sum-of-children image (getFluxPortion() with its
        # default arguments, inserted into a full-size image per child).
        sumimg = None
        for dpk in deb.deblendedParents[0].peaks:
            himg2 = afwImage.ImageF(fpbb)
            dpk.getFluxPortion().insert(himg2)
            if sumimg is None:
                sumimg = himg2.getArray().copy()
            else:
                sumimg += himg2.getArray()

        # Sum of children ~= Original image inside footprint (parent_img),
        # to a tolerance relative to the parent's peak pixel value.
        absdiff = np.max(np.abs(sumimg - parent_img.getArray()))
        print('Max abs diff:', absdiff)
        imgmax = parent_img.getArray().max()
        print('Img max:', imgmax)
        self.assertLess(absdiff, imgmax*1e-6)

    def test2(self):
        """A 1-d example, to test the stray-flux assignment.
        """
        H, W = 1, 100

        fpbb = geom.Box2I(geom.Point2I(0, 0),
                          geom.Point2I(W-1, H-1))
        afwimg = afwImage.MaskedImageF(fpbb)
        img = afwimg.getImage().getArray()

        var = afwimg.getVariance().getArray()
        var[:, :] = 1.

        # A flat plateau of 10 (excluding the end pixels), with two
        # brighter pixels of 20 at x=1 and x=W-2.
        y = 0
        img[y, 1:-1] = 10.

        img[0, 1] = 20.
        img[0, -2] = 20.

        fakepsf_fwhm = 1.
        fakepsf = gaussianPsf(1, 1, fakepsf_fwhm)

        # Run the detection code to get a ~ realistic footprint
        thresh = afwDet.createThreshold(5., 'value', True)
        fpSet = afwDet.FootprintSet(afwimg, thresh, 'DETECTED', 1)
        fps = fpSet.getFootprints()
        self.assertEqual(len(fps), 1)
        fp = fps[0]

        # WORKAROUND: the detection alg produces ONE peak, at (1,0),
        # rather than two.
        self.assertEqual(len(fp.getPeaks()), 1)
        fp.addPeak(W-2, y, float("NaN"))

        # Change verbose to False to quiet down the meas_deblender.baseline logger
        deb = deblend(fp, afwimg, fakepsf, fakepsf_fwhm, verbose=True,
                      fitPsfs=False)

        if doPlot:
            # NOTE(review): uses `deb.peaks` / `dpk.fluxPortion`, unlike the
            # assertions below; only runs when doPlot is True.
            XX = np.arange(W+1).repeat(2)[1:-1]

            plt.clf()
            p1 = plt.plot(XX, img[y, :].repeat(2), 'g-', lw=3, alpha=0.3)

            for dpk in deb.peaks:
                print(dpk)
                port = dpk.fluxPortion.getImage()
                bb = port.getBBox()
                YY = np.zeros(XX.shape)
                YY[bb.getMinX()*2: (bb.getMaxX()+1)*2] = port.getArray()[0, :].repeat(2)
                p2 = plt.plot(XX, YY, 'r-')

                simg = afwImage.ImageF(fpbb)
                dpk.strayFlux.insert(simg)
                p3 = plt.plot(XX, simg.getArray()[y, :].repeat(2), 'b-')

            plt.legend((p1[0], p2[0], p3[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 3)

        # Collect each child's stray flux as a full-size array.
        strays = []
        for dpk in deb.deblendedParents[0].peaks:
            simg = afwImage.ImageF(fpbb)
            dpk.strayFlux.insert(simg)
            strays.append(simg.getArray())

        ssum = reduce(np.add, strays)

        # Expected total stray flux: 10 everywhere except the two
        # outermost pixels at each end.
        starget = np.zeros(W)
        starget[2:-2] = 10.

        self.assertFloatsEqual(ssum, starget)

        # Expected per-child split: weights 1/(1 + dx^2) from each peak
        # (x=1 and x=W-2). A weight below `strayclip` times the summed
        # weight is zeroed, then the remaining weights are normalized and
        # scaled by the plateau value 10.
        X = np.arange(W)
        dx1 = X - 1.
        dx2 = X - (W-2)
        f1 = (1. / (1. + dx1**2))
        f2 = (1. / (1. + dx2**2))
        strayclip = 0.001
        fsum = f1 + f2
        f1[f1 < strayclip * fsum] = 0.
        f2[f2 < strayclip * fsum] = 0.

        s1 = f1 / (f1+f2) * 10.
        s2 = f2 / (f1+f2) * 10.

        # No stray flux is expected at or outside each child's own peak.
        s1[:2] = 0.
        s2[-2:] = 0.

        if doPlot:
            p4 = plt.plot(XX, s1.repeat(2), 'm-')
            plt.plot(XX, s2.repeat(2), 'm-')

            plt.legend((p1[0], p2[0], p3[0], p4[0]),
                       ('Parent Flux', 'Child portion', 'Child stray flux',
                        'Expected stray flux'))
            plt.ylim(-2, 22)
            plt.savefig(plotpat % 4)

        # test abs diff
        d = np.max(np.abs(s1 - strays[0]))
        self.assertLess(d, 1e-6)
        d = np.max(np.abs(s2 - strays[1]))
        self.assertLess(d, 1e-6)

        # test relative diff (denominator floored at 1e-3 so pixels with
        # zero expected flux do not divide by zero)
        self.assertLess(np.max(np.abs(s1 - strays[0])/np.maximum(1e-3, s1)), 1e-6)
        self.assertLess(np.max(np.abs(s2 - strays[1])/np.maximum(1e-3, s2)), 1e-6)
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the inherited ``lsst.utils.tests.MemoryTestCase`` checks for this
    module; no additional tests are defined here."""
def setup_module(module):
    """Module-level setup hook that initialises the LSST test utilities.

    The name follows pytest's xunit-style ``setup_module`` convention.
    ``module`` is accepted for that signature but is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Script entry point: setup_module() is not invoked outside pytest,
    # so initialise the LSST test utilities here before running unittest.
    lsst.utils.tests.init()
    unittest.main()