Coverage for examples/plotDeblendFamilies.py: 0%
311 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-30 10:48 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-30 10:48 +0000
"""Plot deblend families (parents, children, re-deblended templates, PSF fits)."""
import os

import numpy as np
import matplotlib
# Choose the non-interactive backend *before* pyplot is imported; older
# matplotlib releases warn about and ignore a later matplotlib.use() call.
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa E402

import lsst.daf.persistence as dafPersist  # noqa E402
import lsst.afw.detection as afwDet  # noqa E402
import lsst.afw.image as afwImage  # noqa E402
import lsst.afw.table as afwTable  # noqa E402
from lsst.meas.deblender.baseline import deblend  # noqa E402

from astrometry.util.plotutils import PlotSequence  # noqa E402

import lsstDebug  # noqa E402
# Turn on the 'psf' debug switch for the baseline deblender.
# NOTE(review): presumably this is what makes deblend() populate the
# psfFitDebug* attributes plotted in makeplots -- confirm.
lsstDebug.Info('lsst.meas.deblender.baseline').psf = True
def foot_to_img(foot, img=None):
    """Render a footprint's pixel values into a new image over its bounding box.

    The new image is pre-filled with NaN before the footprint's pixels are
    written. Assuming insert/copyWithinFootprintImage only touch footprint
    pixels, bounding-box pixels outside the footprint remain NaN.

    Parameters
    ----------
    foot : footprint
        Footprint to render. If ``foot.isHeavy()`` its own stored pixel
        values are inserted.
    img : image, optional
        Image to copy pixels from when ``foot`` is not heavy.

    Returns
    -------
    fimg : `afwImage.ImageF` or None
        Image covering ``foot.getBBox()``, or None when ``foot`` is not heavy
        and no ``img`` was given.
    heavy : bool
        True if the pixel values came from the footprint itself (heavy).
    """
    heavy = foot.isHeavy()
    if not heavy and img is None:
        # Nothing to take pixel values from; skip allocating an image.
        return None, False

    fimg = afwImage.ImageF(foot.getBBox())
    fimg.getArray()[:, :] = np.nan
    if heavy:
        foot.insert(fimg)
    else:
        afwDet.copyWithinFootprintImage(foot, img, fimg)
    return fimg, heavy
def img_to_rgb(im, mn, mx, *, nan_color=(0.8, 0.8, 0.3),
               zero_color=(0.5, 0.5, 0.8)):
    """Convert a 2-D image to an RGB array suitable for ``plt.imshow``.

    Values are mapped linearly from ``[mn, mx]`` onto grey levels in
    ``[0, 1]`` and clipped at both ends. NaN pixels are painted
    ``nan_color`` and exactly-zero pixels are painted ``zero_color`` so that
    both stand out against the grey scale.

    Parameters
    ----------
    im : array_like, 2-D
        Input image.
    mn, mx : float
        Values mapped to black and white respectively.
    nan_color, zero_color : sequence of 3 floats, optional
        RGB colours for NaN and zero pixels (defaults match the original
        hard-coded colours).

    Returns
    -------
    rgbim : ndarray of shape ``im.shape + (3,)``
    """
    im = np.asarray(im)
    grey = np.clip((im - mn) / (mx - mn), 0., 1.)
    rgbim = np.repeat(grey[:, :, np.newaxis], 3, axis=2)
    # A 2-D boolean mask selects (k, 3) rows; the 3-tuple broadcasts across them.
    rgbim[np.isnan(im)] = nan_color
    rgbim[im == 0] = zero_color
    return rgbim
def bb_to_ext(bb):
    """Return the ``imshow`` extent ``[left, right, bottom, top]`` for ``bb``.

    Each edge of the inclusive integer box is pushed outward by half a pixel,
    treating integer coordinates as pixel centres.
    """
    pad = 0.5
    left, right = bb.getMinX() - pad, bb.getMaxX() + pad
    bottom, top = bb.getMinY() - pad, bb.getMaxY() + pad
    return [left, right, bottom, top]
def bb_to_xy(bb, margin=0):
    """Return ``(xs, ys)`` tracing the closed outline of ``bb`` for ``plt.plot``.

    The outline is grown by ``margin`` on every side (a negative margin
    shrinks it). The first corner is repeated at the end to close the loop.
    """
    left = bb.getMinX() - margin
    right = bb.getMaxX() + margin
    bottom = bb.getMinY() - margin
    top = bb.getMaxY() + margin
    xs = [left, left, right, right, left]
    ys = [bottom, top, top, bottom, bottom]
    return xs, ys
def _deb_flag_string(kid):
    """Return a comma-separated list of the deblender flags set on ``kid``.

    ``kid`` is one entry of ``deb.peaks`` from ``deblend``; each flag is
    read as a boolean attribute via ``getattr``.
    """
    labels = []
    for flag in ['skip', 'outOfBounds', 'tinyFootprint', 'noValidPixels',
                 ('deblendedAsPsf', 'PSF'), 'psfFitFailed', 'psfFitBadDof',
                 'psfFitBigDecenter', 'psfFitWithDecenter',
                 'failedSymmetricTemplate', 'hasRampedTemplate', 'patched']:
        # An entry is either an attribute name (also used as its label) or an
        # (attribute, label) pair. Test the type explicitly: a length check
        # would mis-handle any two-character attribute name.
        if isinstance(flag, tuple):
            attr, label = flag
        else:
            attr = label = flag
        if getattr(kid, attr):
            labels.append(label)
    return ', '.join(labels)


def _grid_shape(n):
    """Return ``(rows, cols)`` for a near-square subplot grid holding ``n`` panels.

    ``cols`` is at least 1, so ``n == 0`` does not divide by zero.
    """
    cols = max(1, int(np.ceil(np.sqrt(n))))
    rows = int(np.ceil(n / float(cols)))
    return rows, cols


def _deblend_child_footprint(kid, plotnum):
    """Return the HeavyFootprint of deblended peak ``kid`` for panel set ``plotnum``.

    0: flux portion without stray flux; 1: flux portion with stray flux;
    2: template; 3: PSF template (after ``normalize()`` and
    ``clipToNonzero()``) if the peak was deblended as a PSF, otherwise the
    regular template.
    """
    if plotnum == 0:
        return kid.getFluxPortion(strayFlux=False)
    if plotnum == 1:
        return kid.getFluxPortion(strayFlux=True)
    if plotnum == 3 and kid.deblendedAsPsf:
        kfoot = afwDet.makeHeavyFootprint(kid.psfFootprint, kid.psfTemplate)
        kfoot.normalize()
        kfoot.clipToNonzero(kid.psfTemplate.getImage())
        return kfoot
    return afwDet.makeHeavyFootprint(kid.templateFootprint, kid.templateImage)


def _plot_noise_histogram(calexp, ps, lo=-4, hi=4):
    """Plot pixel S/N histograms for ``calexp`` against a unit Gaussian.

    Blue: pixels divided by one image-wide sigma (sqrt of the median
    variance). Green: pixels divided by their own per-pixel sigma.
    """
    maskedImage = calexp.getMaskedImage()
    img = maskedImage.getImage().getArray()
    var = maskedImage.getVariance().getArray()
    sig1 = np.sqrt(np.median(var.ravel()))

    plt.clf()
    n, b, _ = plt.hist(img.ravel() / sig1, 100, range=(lo, hi),
                       histtype='step', color='b')
    plt.hist((img / np.sqrt(var)).ravel(), 100, range=(lo, hi),
             histtype='step', color='g')
    xx = np.linspace(lo, hi, 200)
    # Unit Gaussian scaled to the blue histogram's total count and bin width.
    yy = 1. / (np.sqrt(2. * np.pi)) * np.exp(-0.5 * xx**2)
    yy *= sum(n) * (b[1] - b[0])
    plt.plot(xx, yy, 'k-', alpha=0.5)
    plt.xlim(lo, hi)
    plt.title('image-wide sig1: %.1f' % sig1)
    ps.savefig()


def _plot_psf_fit_debug(kid, i, ps, nsim=20000):
    """Plot the PSF-fit debug products attached to one deblended peak.

    ``kid`` is a ``deb.peaks`` entry with ``deblendedAsPsf`` set; ``i`` is
    its index and is used only in the figure title. Also prints a few
    summary numbers about the ramp weights and the fit residual.
    """
    plt.clf()
    ima = dict(interpolation='nearest', origin='lower', cmap='gray')

    stamp = kid.psfFitDebugStamp.getArray()
    model = kid.psfFitDebugPsfModel.getArray()
    sig = np.sqrt(kid.psfFitDebugVar.getArray())
    valid = kid.psfFitDebugValidPix
    rw = kid.psfFitDebugRampWeight

    def show(panel, title, image, **kwargs):
        """Draw ``image`` in panel ``panel`` of the 2x4 grid, without ticks, with a colorbar."""
        plt.subplot(2, 4, panel)
        plt.title(title)
        plt.imshow(image, **kwargs)
        plt.xticks([])
        plt.yticks([])
        plt.colorbar()

    show(1, 'weights', kid.psfFitDebugWeight, vmin=0, **ima)
    show(7, 'valid pixels', valid, vmin=0, vmax=1, **ima)

    # Panel 2: S/N histograms of the data (magenta) and model (red) over
    # valid pixels, and of the residual chi over all pixels (green) and
    # valid pixels only (blue), with a unit Gaussian for reference.
    plt.subplot(2, 4, 2)
    chi = (stamp - model) / sig
    plt.hist(np.clip((stamp / sig)[valid], -5, 5), 20, range=(-5, 5),
             histtype='step', color='m')
    plt.hist(np.clip((model / sig)[valid], -5, 5), 20, range=(-5, 5),
             histtype='step', color='r')
    plt.hist(np.clip(chi.ravel(), -5, 5), 20, range=(-5, 5),
             histtype='step', color='g')
    n, b, _ = plt.hist(np.clip(chi[valid], -5, 5), 20, range=(-5, 5),
                       histtype='step', color='b')
    xx = np.linspace(-5, 5, 200)
    yy = 1. / (np.sqrt(2. * np.pi)) * np.exp(-0.5 * xx**2)
    yy *= sum(n) * (b[1] - b[0])
    plt.plot(xx, yy, 'k-', alpha=0.5)
    plt.xlim(-5, 5)

    print('Sum of ramp weights:', np.sum(rw))
    print('Quadrature sum of ramp weights:', np.sqrt(np.sum(rw**2)))
    print('Number of valid pix:', np.sum(valid))
    print('rw[valid]', np.sum(rw[valid]))
    print('rw range', rw.min(), rw.max())
    myresid = np.sum(valid * rw * ((stamp - model) / sig)**2)
    print('myresid:', myresid)

    # Panel 8: distribution of the ramp-weighted chi-square that pure unit
    # Gaussian noise would give over the valid pixels (nsim draws), with the
    # fit's best chi-square marked in red.
    plt.subplot(2, 4, 8)
    rwv = rw[valid]
    print('rwv', rwv)
    draws = np.random.normal(size=(nsim, len(rwv)))
    plt.hist(np.sum(rwv * draws**2, axis=1), 25)
    best_chi, _ = kid.psfFitBest
    plt.axvline(best_chi, color='r')

    # Panels 3-5 share a stretch set by the model peak so they compare directly.
    mx = model.max()
    show(3, 'model+noise',
         (model + sig * np.random.normal(size=sig.shape)) * valid,
         vmin=0, vmax=mx, **ima)
    show(4, 'fit psf model', model, vmin=0, vmax=mx, **ima)
    show(5, 'fit psf image', stamp, vmin=0, vmax=mx, **ima)
    show(6, 'fit psf chi', -(valid * (stamp - model) / sig), vmin=-3, vmax=3,
         interpolation='nearest', origin='lower', cmap='RdBu')

    flux, sky, _, _ = kid.psfFitParams[:4]
    print('Model sum:', model.sum())
    print('- sky', model.sum() - np.sum(valid) * sky)

    sig1 = np.median(sig)
    plt.suptitle('PSF kid %i: flux %.1f, sky %.1f, sig1 %.1f' %
                 (i, flux, sky, sig1))
    ps.savefig()


def makeplots(butler, dataId, ps, sources=None, pids=None, minsize=0,
              maxpeaks=10):
    """Plot every deblend family in one exposure and re-run the deblender on it.

    Figures written to ``ps`` (via ``ps.savefig()``):

    1. Image-wide pixel S/N histograms.

    Then, for each selected parent:

    2. The parent's pixels with its bbox (red), child bboxes (magenta) and
       peaks (red +).
    3. A grid of each child's footprint pixels with its peaks (green +).
    4-7. After re-running ``deblend`` on the parent footprint, grids of each
       peak's flux portion, flux portion + stray flux, template and PSF
       template.
    8+. One PSF-fit diagnostic figure per peak deblended as a PSF.

    Parameters
    ----------
    butler : `lsst.daf.persistence.Butler`
        Used to read ``calexp`` (and ``src`` when ``sources`` is None).
    dataId : dict
        Keyword arguments for ``butler.get``.
    ps : `PlotSequence`
        Figure sink; ``ps.savefig()`` is called for each figure.
    sources : source catalog, optional
        Catalog to use instead of the butler's ``src``.
    pids : sequence of int, optional
        If non-empty, only parents whose ``getId() & 0xffff`` is listed are
        plotted. None or empty means all parents.
    minsize : int, optional
        Skip families with fewer than this many children.
    maxpeaks : int, optional
        ``maxNumberOfPeaks`` passed to ``deblend``.
    """
    calexp = butler.get("calexp", **dataId)
    ss = butler.get('src', **dataId) if sources is None else sources

    # Index all sources by ID and group deblend children by parent ID
    # (getParent() == 0 marks a source that is not a child).
    srcs = {}
    families = {}
    for src in ss:
        srcs[src.getId()] = src
        parent = src.getParent()
        if parent != 0:
            families.setdefault(parent, []).append(src)

    print()
    maskedImage = calexp.getMaskedImage()
    lsstimg = maskedImage.getImage()
    schema = ss.getSchema()
    psfkey = schema.find("deblend_deblendedAsPsf").key
    nchildkey = schema.find("deblend_nChild").key
    toomanykey = schema.find("deblend_tooManyPeaks").key
    failedkey = schema.find("deblend_failed").key

    def getFlagString(src):
        """Summarize a catalog source's deblend flags, e.g. 'Nchild: 3, PSF'."""
        parts = ['Nchild: %i' % src.get(nchildkey)]
        for key, label in [(psfkey, 'PSF'),
                           (toomanykey, 'TooMany'),
                           (failedkey, 'Failed')]:
            if src.get(key):
                parts.append(label)
        return ', '.join(parts)

    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9,
                        hspace=0.2, wspace=0.3)

    _plot_noise_histogram(calexp, ps)

    # Suptitles for the four per-peak grids drawn after re-running deblend,
    # indexed by the plotnum understood by _deblend_child_footprint.
    titles = ('flux portion', 'flux portion + stray', 'template', 'psf template')

    for p, kids in families.items():
        parent = srcs[p]
        pid = parent.getId() & 0xffff
        # `pids` may be None (the default) or empty: both mean "all parents".
        if pids and pid not in pids:
            continue
        if len(kids) < minsize:
            print('Skipping parent', pid, ': n kids', len(kids))
            continue

        print('Parent', parent)
        print('Kids', kids)
        print('Parent', parent.getId())
        print('Kids', [k.getId() for k in kids])

        pfoot = parent.getFootprint()
        bb = pfoot.getBBox()

        ima = dict(interpolation='nearest', origin='lower', cmap='gray',
                   vmin=-10, vmax=40)
        mn, mx = ima['vmin'], ima['vmax']

        # --- Parent footprint with child bboxes and peaks.
        print('parent footprint:', pfoot)
        print('heavy?', pfoot.isHeavy())
        plt.clf()
        pimg, h = foot_to_img(pfoot, lsstimg)
        plt.imshow(img_to_rgb(pimg.getArray(), mn, mx), extent=bb_to_ext(bb), **ima)
        tt = 'Parent %i' % parent.getId()
        if not h:
            tt += ', no HFoot'
        tt += ', ' + getFlagString(parent)
        plt.title(tt)
        # Remember the parent's axis limits; every later panel for this
        # family is drawn in the same coordinates.
        ax = plt.axis()
        x, y = bb_to_xy(bb)
        plt.plot(x, y, 'r-', lw=2)
        for kid in kids:
            kx, ky = bb_to_xy(kid.getFootprint().getBBox(), margin=-0.1)
            plt.plot(kx, ky, 'm-', lw=1.5)
        for pk in pfoot.getPeaks():
            plt.plot(pk.getIx(), pk.getIy(), 'r+', ms=10, mew=3)
        plt.axis(ax)
        ps.savefig()

        # --- Children as stored in the catalog.
        rows, cols = _grid_shape(len(kids))
        plt.clf()
        for i, kid in enumerate(kids):
            plt.subplot(rows, cols, 1 + i)
            kfoot = kid.getFootprint()
            kimg, h = foot_to_img(kfoot, lsstimg)
            tt = getFlagString(kid)
            if not h:
                tt += ', no HFoot'
            plt.title(tt)
            if kimg is not None:
                plt.imshow(img_to_rgb(kimg.getArray(), mn, mx),
                           extent=bb_to_ext(kfoot.getBBox()), **ima)
                for pk in kfoot.getPeaks():
                    plt.plot(pk.getIx(), pk.getIy(), 'g+', ms=10, mew=3)
            plt.axis(ax)
        plt.suptitle('Child HeavyFootprints')
        ps.savefig()

        # --- Re-run the deblender on the parent footprint.
        print()
        print('Re-running deblender...')
        psf = calexp.getPsf()
        # 2.35 ~= 2*sqrt(2 ln 2): Gaussian sigma -> FWHM, treating the PSF's
        # determinant radius as a sigma.
        psf_fwhm = psf.computeShape().getDeterminantRadius() * 2.35
        deb = deblend(pfoot, maskedImage, psf, psf_fwhm, verbose=True,
                      maxNumberOfPeaks=maxpeaks,
                      rampFluxAtEdge=True,
                      clipStrayFluxFraction=0.01,
                      )
        print('Got', deb)

        rows, cols = _grid_shape(len(deb.peaks))
        for plotnum, supt in enumerate(titles):
            plt.clf()
            for i, kid in enumerate(deb.peaks):
                kfoot = _deblend_child_footprint(kid, plotnum)
                kimg, h = foot_to_img(kfoot, None)
                tt = 'kid %i: %s' % (i, _deb_flag_string(kid))
                if not h:
                    tt += ', no HFoot'
                plt.subplot(rows, cols, 1 + i)
                plt.title(tt, fontsize=8)
                if kimg is not None:
                    plt.imshow(img_to_rgb(kimg.getArray(), mn, mx),
                               extent=bb_to_ext(kfoot.getBBox()), **ima)
                plt.axis(ax)
            # Title comes from the fixed table, so it is defined even when
            # deb.peaks is empty.
            plt.suptitle(supt)
            ps.savefig()

        for i, kid in enumerate(deb.peaks):
            if kid.deblendedAsPsf:
                _plot_psf_fit_debug(kid, i, ps)
if __name__ == '__main__':
    import optparse
    import sys

    parser = optparse.OptionParser()
    parser.add_option('--data', help='Data dir, default $SUPRIME_DATA_DIR/rerun/RERUN')
    parser.add_option('--rerun', help='Rerun name, default %default', default='dstn/deb')
    parser.add_option('--visit', help='Visit number, default %default', default=905516, type=int)
    parser.add_option('--ccd', help='CCD number, default %default', default=22, type=int)
    parser.add_option('--sources', help='Read sources file', type=str)
    parser.add_option('--hdu', help='With --sources, HDU to read; default %default',
                      type=int, default=2)
    parser.add_option('--pid', '-p', action='append', default=[], type=int,
                      help='Deblend a specific parent ID')
    # makeplots skips families with fewer than `minsize` children, i.e. it
    # keeps families with at least this many.
    parser.add_option('--big', dest='minsize', default=0,
                      help='Only show results for deblend families with at least this many children',
                      type=int)
    parser.add_option('--maxpeaks', default=10, help='maxNumberOfPeaks', type=int)

    opt, args = parser.parse_args()

    # The script takes no positional arguments.
    if args:
        parser.print_help()
        sys.exit(-1)

    if not opt.data:
        suprime_dir = os.environ.get('SUPRIME_DATA_DIR')
        if suprime_dir is None:
            # Report a usage error rather than a bare KeyError traceback.
            parser.error('--data not given and $SUPRIME_DATA_DIR is not set')
        opt.data = os.path.join(suprime_dir, 'rerun', opt.rerun)

    print('Data directory:', opt.data)
    butler = dafPersist.Butler(opt.data)
    dataId = dict(visit=opt.visit, ccd=opt.ccd)
    ps = PlotSequence('deb')

    sources = None
    if opt.sources:
        flags = 0
        sources = afwTable.SourceCatalog.readFits(opt.sources, opt.hdu, flags)
        print('Read sources from', opt.sources, ':', sources)

    makeplots(butler, dataId, ps, sources=sources, pids=opt.pid, minsize=opt.minsize,
              maxpeaks=opt.maxpeaks)