Coverage for examples/plotDeblendFamilies.py : 0%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import matplotlib
2matplotlib.use('Agg') # noqa E402
3import pylab as plt
5import os
6import numpy as np
8import lsst.daf.persistence as dafPersist
9import lsst.afw.detection as afwDet
10import lsst.afw.image as afwImage
11import lsst.afw.table as afwTable
12from lsst.meas.deblender.baseline import deblend
14from astrometry.util.plotutils import PlotSequence
import lsstDebug

# Turn on the 'psf' lsstDebug switch for lsst.meas.deblender.baseline at import
# time, so it is already set when makeplots() re-runs deblend().
# NOTE(review): presumably this is what makes the deblender fill in the
# psfFitDebug* attributes plotted in makeplots(); confirm in baseline.py.
lsstDebug.Info('lsst.meas.deblender.baseline').psf = True
def foot_to_img(foot, img=None):
    """Render a footprint's pixels into a new image covering its bounding box.

    Pixels of the bounding box that the footprint does not cover are left as
    NaN, so they can be told apart from genuine zero-valued pixels.

    Parameters
    ----------
    foot : footprint
        If ``foot.isHeavy()``, the footprint's own pixel values are written
        with ``foot.insert``.  Otherwise the values are copied from ``img``
        within the footprint.
    img : image, optional
        Image that supplies pixel values when ``foot`` is not heavy.

    Returns
    -------
    fimg : `lsst.afw.image.ImageF` or None
        Image on ``foot.getBBox()``.  None when ``foot`` is not heavy and no
        ``img`` was given.
    heavy : bool
        True if the values came from the (heavy) footprint itself.
    """
    heavy = foot.isHeavy()
    if not heavy and img is None:
        # Nothing to take pixel values from; don't allocate an unused image.
        return None, False

    fimg = afwImage.ImageF(foot.getBBox())
    # Start as NaN so pixels outside the footprint stay distinguishable.
    fimg.getArray()[:, :] = np.nan
    if heavy:
        foot.insert(fimg)
    else:
        afwDet.copyWithinFootprintImage(foot, img, fimg)
    return fimg, heavy
def img_to_rgb(im, mn, mx):
    """Convert a 2-D image into a grayscale RGB array for display.

    Values are stretched linearly so that ``mn`` maps to 0 and ``mx`` to 1,
    clipped to [0, 1], and copied into all three channels.  Two kinds of
    pixels are painted in flat colours so they stand out from the data:

    * NaN pixels become (0.8, 0.8, 0.3);
    * pixels exactly equal to zero become (0.5, 0.5, 0.8).

    Returns an array of shape ``im.shape + (3,)``.
    """
    stretched = np.clip((im - mn) / (mx - mn), 0., 1.)
    rgb = np.repeat(stretched[:, :, np.newaxis], 3, axis=2)
    # A (H, W) boolean mask selects (k, 3) pixel rows; each gets the colour.
    rgb[np.isnan(im)] = (0.8, 0.8, 0.3)
    rgb[im == 0] = (0.5, 0.5, 0.8)
    return rgb
def bb_to_ext(bb):
    """Return ``[left, right, bottom, top]`` for ``plt.imshow(extent=...)``.

    The integer min/max coordinates of ``bb`` are widened by half a pixel on
    each side, so the image's outer edges line up with the box.
    """
    half = 0.5
    return [bb.getMinX() - half, bb.getMaxX() + half,
            bb.getMinY() - half, bb.getMaxY() + half]
def bb_to_xy(bb, margin=0):
    """Return ``(xs, ys)`` lists tracing the closed outline of a bounding box.

    The path goes (minX, minY) -> (minX, maxY) -> (maxX, maxY) ->
    (maxX, minY) -> back to the start.  The box is grown by ``margin`` on
    every side; a negative margin shrinks it.  Unlike bb_to_ext, no
    half-pixel padding is added.
    """
    left = bb.getMinX() - margin
    right = bb.getMaxX() + margin
    bottom = bb.getMinY() - margin
    top = bb.getMaxY() + margin
    xs = [left, left, right, right, left]
    ys = [bottom, top, top, bottom, bottom]
    return xs, ys
def _group_families(catalog):
    """Index sources by ID and group every child source under its parent ID.

    Returns ``(srcs, families)``: ``srcs`` maps ID -> source, and
    ``families`` maps parent ID -> list of children.  Sources whose
    ``getParent()`` is 0 are not placed in any family.
    """
    srcs = {}
    families = {}
    for src in catalog:
        srcs[src.getId()] = src
        parent = src.getParent()
        if parent == 0:
            continue
        families.setdefault(parent, []).append(src)
    return srcs, families


def _make_flag_string(schema):
    """Return a function that summarizes a source's deblend_* fields as text.

    The summary is ``'Nchild: N'`` followed by 'PSF', 'TooMany' and/or
    'Failed' for each of the corresponding flags that is set.
    """
    psfkey = schema.find("deblend_deblendedAsPsf").key
    nchildkey = schema.find("deblend_nChild").key
    toomanykey = schema.find("deblend_tooManyPeaks").key
    failedkey = schema.find("deblend_failed").key

    def getFlagString(src):
        words = ['Nchild: %i' % src.get(nchildkey)]
        for key, label in [(psfkey, 'PSF'),
                           (toomanykey, 'TooMany'),
                           (failedkey, 'Failed')]:
            if src.get(key):
                words.append(label)
        return ', '.join(words)

    return getFlagString


def _deb_flag_string(kid):
    """Return a comma-separated list of the deblender flags set on peak ``kid``.

    Entries are attribute names read with getattr.  A ``(attribute, label)``
    tuple reports ``label`` instead of the attribute name.
    """
    flags = ['skip', 'outOfBounds', 'tinyFootprint', 'noValidPixels',
             ('deblendedAsPsf', 'PSF'), 'psfFitFailed', 'psfFitBadDof',
             'psfFitBigDecenter', 'psfFitWithDecenter',
             'failedSymmetricTemplate', 'hasRampedTemplate', 'patched']
    labels = []
    for flag in flags:
        # isinstance, not len(flag) == 2, so 2-character names can't be
        # mistaken for (attribute, label) pairs.
        if isinstance(flag, tuple):
            attr, label = flag
        else:
            attr = label = flag
        if getattr(kid, attr):
            labels.append(label)
    return ', '.join(labels)


def _plot_noise_histogram(calexp, ps):
    """Plot image-wide noise histograms against a unit Gaussian.

    * Blue: pixels / sig1, where sig1 = sqrt(median variance).
    * Green: each pixel divided by its own sqrt(variance).
    * Black: a unit Gaussian scaled to the blue histogram.
    """
    mi = calexp.getMaskedImage()
    img = mi.getImage().getArray()
    var = mi.getVariance().getArray()
    sig1 = np.sqrt(np.median(var.ravel()))
    pp = (img / np.sqrt(var)).ravel()

    plt.clf()
    lo, hi = -4, 4
    n, b, _ = plt.hist(img.ravel() / sig1, 100, range=(lo, hi),
                       histtype='step', color='b')
    plt.hist(pp, 100, range=(lo, hi), histtype='step', color='g')
    xx = np.linspace(lo, hi, 200)
    yy = 1. / np.sqrt(2. * np.pi) * np.exp(-0.5 * xx**2)
    # Scale the unit-area Gaussian to total count times bin width.
    yy *= n.sum() * (b[1] - b[0])
    plt.plot(xx, yy, 'k-', alpha=0.5)
    plt.xlim(lo, hi)
    plt.title('image-wide sig1: %.1f' % sig1)
    ps.savefig()


def _plot_parent(parent, kids, lsstimg, ima, flag_string, ps):
    """Plot the parent footprint's pixels with child boxes and parent peaks.

    Returns the plot's axis limits.  The caller reuses them for the child
    panels so that every panel shows the same region.
    """
    mn, mx = ima['vmin'], ima['vmax']
    pfoot = parent.getFootprint()
    bb = pfoot.getBBox()
    print('parent footprint:', pfoot)
    print('heavy?', pfoot.isHeavy())

    plt.clf()
    pimg, heavy = foot_to_img(pfoot, lsstimg)
    plt.imshow(img_to_rgb(pimg.getArray(), mn, mx), extent=bb_to_ext(bb),
               **ima)
    title = 'Parent %i' % parent.getId()
    if not heavy:
        title += ', no HFoot'
    title += ', ' + flag_string(parent)
    plt.title(title)
    ax = plt.axis()
    px, py = bb_to_xy(bb)
    plt.plot(px, py, 'r-', lw=2)
    for kid in kids:
        # Shrink child boxes slightly so they don't sit on the parent outline.
        kx, ky = bb_to_xy(kid.getFootprint().getBBox(), margin=-0.1)
        plt.plot(kx, ky, 'm-', lw=1.5)
    for pk in pfoot.getPeaks():
        plt.plot(pk.getIx(), pk.getIy(), 'r+', ms=10, mew=3)
    plt.axis(ax)
    ps.savefig()
    return ax


def _plot_children(kids, lsstimg, ima, ax, flag_string, ps):
    """Plot each child's footprint pixels and peaks in a grid of subplots."""
    mn, mx = ima['vmin'], ima['vmax']
    cols = int(np.ceil(np.sqrt(len(kids))))
    rows = int(np.ceil(len(kids) / float(cols)))

    plt.clf()
    for i, kid in enumerate(kids):
        plt.subplot(rows, cols, 1 + i)
        kfoot = kid.getFootprint()
        kimg, heavy = foot_to_img(kfoot, lsstimg)
        title = flag_string(kid)
        if not heavy:
            title += ', no HFoot'
        plt.title(title)
        if kimg is not None:
            plt.imshow(img_to_rgb(kimg.getArray(), mn, mx),
                       extent=bb_to_ext(kfoot.getBBox()), **ima)
            for pk in kfoot.getPeaks():
                plt.plot(pk.getIx(), pk.getIy(), 'g+', ms=10, mew=3)
        plt.axis(ax)
    plt.suptitle('Child HeavyFootprints')
    ps.savefig()


def _peak_footprint(kid, view):
    """Return the HeavyFootprint displayed for deblended peak ``kid``.

    ``view`` selects what is shown:

    * 0 -- flux portion without stray flux;
    * 1 -- flux portion with stray flux;
    * 2 -- template;
    * 3 -- normalized PSF template, or the template for peaks not deblended
      as a PSF.
    """
    if view == 0:
        return kid.getFluxPortion(strayFlux=False)
    if view == 1:
        return kid.getFluxPortion(strayFlux=True)
    if view == 3 and kid.deblendedAsPsf:
        kfoot = afwDet.makeHeavyFootprint(kid.psfFootprint, kid.psfTemplate)
        kfoot.normalize()
        kfoot.clipToNonzero(kid.psfTemplate.getImage())
        return kfoot
    return afwDet.makeHeavyFootprint(kid.templateFootprint, kid.templateImage)


def _plot_deblend_peaks(deb, ima, ax, ps):
    """Save one grid figure per _peak_footprint view, with one panel per peak."""
    npeaks = len(deb.peaks)
    if npeaks == 0:
        # No peaks to plot; also avoids a zero column count below.
        print('Deblender returned no peaks')
        return
    mn, mx = ima['vmin'], ima['vmax']
    cols = int(np.ceil(np.sqrt(npeaks)))
    rows = int(np.ceil(npeaks / float(cols)))

    views = ['flux portion', 'flux portion + stray', 'template', 'psf template']
    for view, supt in enumerate(views):
        plt.clf()
        for i, kid in enumerate(deb.peaks):
            kfoot = _peak_footprint(kid, view)
            kimg, heavy = foot_to_img(kfoot, None)
            title = 'kid %i: %s' % (i, _deb_flag_string(kid))
            if not heavy:
                title += ', no HFoot'
            plt.subplot(rows, cols, 1 + i)
            plt.title(title, fontsize=8)
            if kimg is not None:
                plt.imshow(img_to_rgb(kimg.getArray(), mn, mx),
                           extent=bb_to_ext(kfoot.getBBox()), **ima)
            plt.axis(ax)
        plt.suptitle(supt)
        ps.savefig()


def _image_panel(position, title, data, **imshow_kwargs):
    """Draw ``data`` in cell ``position`` of a 2x4 grid, with no ticks and a colorbar."""
    plt.subplot(2, 4, position)
    plt.title(title)
    plt.imshow(data, **imshow_kwargs)
    plt.xticks([])
    plt.yticks([])
    plt.colorbar()


def _plot_psf_fit(i, kid, ps):
    """Plot the deblender's PSF-fit debug products for peak ``kid`` (index ``i``).

    Reads the ``psfFitDebug*``, ``psfFitBest`` and ``psfFitParams``
    attributes of ``kid``.
    """
    ima = dict(interpolation='nearest', origin='lower', cmap='gray')
    stamp = kid.psfFitDebugStamp.getArray()
    model = kid.psfFitDebugPsfModel.getArray()
    sig = np.sqrt(kid.psfFitDebugVar.getArray())
    # NOTE(review): used as a boolean mask below; assumed to be bool-typed.
    valid = kid.psfFitDebugValidPix
    rw = kid.psfFitDebugRampWeight
    chi = (stamp - model) / sig

    plt.clf()
    _image_panel(1, 'weights', kid.psfFitDebugWeight, vmin=0, **ima)
    _image_panel(7, 'valid pixels', valid, vmin=0, vmax=1, **ima)

    # Histograms of data/sigma (magenta), model/sigma (red), chi over all
    # pixels (green) and over valid pixels (blue), plus a unit Gaussian.
    plt.subplot(2, 4, 2)
    plt.hist(np.clip((stamp / sig)[valid], -5, 5), 20, range=(-5, 5),
             histtype='step', color='m')
    plt.hist(np.clip((model / sig)[valid], -5, 5), 20, range=(-5, 5),
             histtype='step', color='r')
    plt.hist(np.clip(chi.ravel(), -5, 5), 20, range=(-5, 5),
             histtype='step', color='g')
    n, b, _ = plt.hist(np.clip(chi[valid], -5, 5), 20, range=(-5, 5),
                       histtype='step', color='b')
    xx = np.linspace(-5, 5, 200)
    yy = 1. / np.sqrt(2. * np.pi) * np.exp(-0.5 * xx**2)
    yy *= n.sum() * (b[1] - b[0])
    plt.plot(xx, yy, 'k-', alpha=0.5)
    plt.xlim(-5, 5)

    print('Sum of ramp weights:', np.sum(rw))
    print('Quadrature sum of ramp weights:', np.sqrt(np.sum(rw**2)))
    print('Number of valid pix:', np.sum(valid))
    print('rw[valid]', np.sum(rw[valid]))
    print('rw range', rw.min(), rw.max())
    myresid = np.sum(valid * rw * chi**2)
    print('myresid:', myresid)

    # Monte Carlo distribution of sum(rw * x**2) for unit-Gaussian x,
    # compared with the first element of kid.psfFitBest (red line).
    # NOTE(review): psfFitBest is presumably (chi-squared, dof); confirm.
    plt.subplot(2, 4, 8)
    rwv = rw[valid]
    print('rwv', rwv)
    nsamples = 20000
    x = np.random.normal(size=(nsamples, len(rwv)))
    plt.hist(np.sum(rwv * x**2, axis=1), 25)
    best_chisq, _dof = kid.psfFitBest
    plt.axvline(best_chisq, color='r')

    model_max = model.max()
    _image_panel(3, 'model+noise',
                 (model + sig * np.random.normal(size=sig.shape)) * valid,
                 vmin=0, vmax=model_max, **ima)
    _image_panel(4, 'fit psf model', model, vmin=0, vmax=model_max, **ima)
    _image_panel(5, 'fit psf image', stamp, vmin=0, vmax=model_max, **ima)
    _image_panel(6, 'fit psf chi', -(valid * (stamp - model) / sig),
                 vmin=-3, vmax=3, interpolation='nearest', origin='lower',
                 cmap='RdBu')

    flux, sky, _skyx, _skyy = kid.psfFitParams[:4]
    print('Model sum:', model.sum())
    print('- sky', model.sum() - np.sum(valid) * sky)
    sig1 = np.median(sig)
    plt.suptitle('PSF kid %i: flux %.1f, sky %.1f, sig1 %.1f' %
                 (i, flux, sky, sig1))
    ps.savefig()


def makeplots(butler, dataId, ps, sources=None, pids=None, minsize=0,
              maxpeaks=10):
    """Plot deblend families of one exposure and re-run the deblender on each.

    First saves an image-wide noise histogram.  Then, for every parent with
    children ("family"), it saves these plots via ``ps.savefig()``:

    * the parent's pixels, with child bounding boxes and parent peaks;
    * each child's footprint pixels;
    * four views of every peak from a fresh ``deblend`` of the parent
      footprint;
    * PSF-fit diagnostics for every re-deblended peak with
      ``deblendedAsPsf`` set.

    Parameters
    ----------
    butler : `lsst.daf.persistence.Butler`
        Supplies the ``calexp``, and the ``src`` catalog when ``sources`` is
        None.
    dataId : dict
        Keyword arguments for ``butler.get``.
    ps : PlotSequence
        Receives one ``savefig()`` call per plot.
    sources : source catalog, optional
        Catalog to use instead of the butler's ``src`` dataset.
    pids : sequence of int, optional
        If non-empty, only parents with ``getId() & 0xffff`` in ``pids`` are
        plotted.  None or empty means all parents.
    minsize : int
        Skip families with fewer than this many children.
    maxpeaks : int
        Passed to ``deblend`` as ``maxNumberOfPeaks``.
    """
    calexp = butler.get("calexp", **dataId)
    catalog = butler.get('src', **dataId) if sources is None else sources
    srcs, families = _group_families(catalog)

    print()
    lsstimg = calexp.getMaskedImage().getImage()
    getFlagString = _make_flag_string(catalog.getSchema())

    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9,
                        hspace=0.2, wspace=0.3)
    _plot_noise_histogram(calexp, ps)

    # Display settings for footprint images.  vmin/vmax also set the
    # img_to_rgb stretch.
    ima = dict(interpolation='nearest', origin='lower', cmap='gray',
               vmin=-10, vmax=40)

    for parent_id, kids in families.items():
        parent = srcs[parent_id]
        pid = parent.getId() & 0xffff
        # `pids` defaults to None, so test truthiness rather than len().
        if pids and pid not in pids:
            continue
        if len(kids) < minsize:
            print('Skipping parent', pid, ': n kids', len(kids))
            continue

        print('Parent', parent)
        print('Kids', kids)
        print('Parent', parent.getId())
        print('Kids', [k.getId() for k in kids])

        ax = _plot_parent(parent, kids, lsstimg, ima, getFlagString, ps)
        _plot_children(kids, lsstimg, ima, ax, getFlagString, ps)

        print()
        print('Re-running deblender...')
        psf = calexp.getPsf()
        # 2.35 ~ 2*sqrt(2 ln 2), the Gaussian sigma-to-FWHM factor.  This
        # treats the PSF determinant radius as a Gaussian sigma.
        psf_fwhm = psf.computeShape().getDeterminantRadius() * 2.35
        deb = deblend(parent.getFootprint(), calexp.getMaskedImage(), psf,
                      psf_fwhm, verbose=True,
                      maxNumberOfPeaks=maxpeaks,
                      rampFluxAtEdge=True,
                      clipStrayFluxFraction=0.01,
                      )
        print('Got', deb)

        _plot_deblend_peaks(deb, ima, ax, ps)
        for i, kid in enumerate(deb.peaks):
            if kid.deblendedAsPsf:
                _plot_psf_fit(i, kid, ps)
def main():
    """Parse command-line options and run makeplots() on one visit/CCD.

    Exits via ``parser.error`` when unexpected positional arguments are
    given, or when the data directory can be determined neither from
    ``--data`` nor from ``$SUPRIME_DATA_DIR``.
    """
    import optparse

    parser = optparse.OptionParser()
    parser.add_option('--data', help='Data dir, default $SUPRIME_DATA_DIR/rerun/RERUN')
    parser.add_option('--rerun', help='Rerun name, default %default', default='dstn/deb')
    parser.add_option('--visit', help='Visit number, default %default', default=905516, type=int)
    parser.add_option('--ccd', help='CCD number, default %default', default=22, type=int)
    parser.add_option('--sources', help='Read sources file', type=str)
    parser.add_option('--hdu', help='With --sources, HDU to read; default %default',
                      type=int, default=2)
    parser.add_option('--pid', '-p', action='append', default=[], type=int,
                      help='Deblend a specific parent ID')
    # makeplots skips families with len(kids) < minsize, so families with
    # exactly this many children are still shown.
    parser.add_option('--big', dest='minsize', default=0,
                      help='Only show results for deblend families with at least this many children',
                      type=int)
    parser.add_option('--maxpeaks', default=10, help='maxNumberOfPeaks', type=int)

    opt, args = parser.parse_args()
    if args:
        parser.error('unexpected arguments: %s' % ' '.join(args))

    if not opt.data:
        data_root = os.environ.get('SUPRIME_DATA_DIR')
        if not data_root:
            parser.error('--data not given and $SUPRIME_DATA_DIR is not set')
        opt.data = os.path.join(data_root, 'rerun', opt.rerun)

    print('Data directory:', opt.data)
    butler = dafPersist.Butler(opt.data)
    dataId = dict(visit=opt.visit, ccd=opt.ccd)
    ps = PlotSequence('deb')

    sources = None
    if opt.sources:
        # NOTE(review): 0 is passed as readFits' flags argument; presumably
        # "no special read options" -- confirm against afw.table docs.
        flags = 0
        sources = afwTable.SourceCatalog.readFits(opt.sources, opt.hdu, flags)
        print('Read sources from', opt.sources, ':', sources)

    makeplots(butler, dataId, ps, sources=sources, pids=opt.pid,
              minsize=opt.minsize, maxpeaks=opt.maxpeaks)


if __name__ == '__main__':
    main()