Coverage for examples/plotDeblendFamilies.py: 0%

311 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-22 10:13 +0000

1import numpy as np 

2import os 

3import pylab as plt 

4 

5import matplotlib 

6matplotlib.use('Agg') 

7 

8import lsst.daf.persistence as dafPersist # noqa E402 

9import lsst.afw.detection as afwDet # noqa E402 

10import lsst.afw.image as afwImage # noqa E402 

11import lsst.afw.table as afwTable # noqa E402 

12from lsst.meas.deblender.baseline import deblend # noqa E402 

13 

14from astrometry.util.plotutils import PlotSequence # noqa E402 

15 

16import lsstDebug # noqa E402 

17lsstDebug.Info('lsst.meas.deblender.baseline').psf = True 

18 

19 

20def foot_to_img(foot, img=None): 

21 fimg = afwImage.ImageF(foot.getBBox()) 

22 fimg.getArray()[:, :] = np.nan 

23 if foot.isHeavy(): 

24 foot.insert(fimg) 

25 heavy = True 

26 else: 

27 if img is None: 

28 return None, False 

29 afwDet.copyWithinFootprintImage(foot, img, fimg) 

30 # ia = img.getArray() 

31 # fa = fimg.getArray() 

32 # fbb = fimg.getBBox() 

33 # fx0,fy0 = fbb.getMinX(), fbb.getMinY() 

34 # ibb = img.getBBox() 

35 # ix0,iy0 = ibb.getMinX(), ibb.getMinY() 

36 # for span in foot.getSpans(): 

37 # y,x0,x1 = span.getY(), span.getX0(), span.getX1() 

38 # # print 'Span', y, x0, x1 

39 # # print 'img', ix0, iy0 

40 # # print 'shape', ia[y - iy0, x0 - ix0: x1+1 - ix0].shape 

41 # # print 'fimg', fx0, fy0, 

42 # # print 'shape', fa[y - fy0, x0 - fx0: x1+1 - fx0].shape 

43 # fa[y - fy0, x0 - fx0: x1+1 - fx0] = ia[y - iy0, x0 - ix0: x1+1 - ix0] 

44 heavy = False 

45 return fimg, heavy 

46 

47 

48def img_to_rgb(im, mn, mx): 

49 rgbim = np.clip((im-mn)/(mx-mn), 0., 1.)[:, :, np.newaxis].repeat(3, axis=2) 

50 imNans = np.isnan(im) 

51 for i in range(3): 

52 rgbim[:, :, i][imNans] = (0.8, 0.8, 0.3)[i] 

53 imZeros = (im == 0) 

54 for i in range(3): 

55 rgbim[:, :, i][imZeros] = (0.5, 0.5, 0.8)[i] 

56 return rgbim 

57 

58 

59def bb_to_ext(bb): 

60 y0, y1, x0, x1 = bb.getMinY(), bb.getMaxY(), bb.getMinX(), bb.getMaxX() 

61 return [x0-0.5, x1+0.5, y0-0.5, y1+0.5] 

62 

63 

64def bb_to_xy(bb, margin=0): 

65 y0, y1, x0, x1 = bb.getMinY(), bb.getMaxY(), bb.getMinX(), bb.getMaxX() 

66 x0, x1, y0, y1 = x0-margin, x1+margin, y0-margin, y1+margin 

67 return [x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0] 

68 

69 

70def makeplots(butler, dataId, ps, sources=None, pids=None, minsize=0, 

71 maxpeaks=10): 

72 calexp = butler.get("calexp", **dataId) 

73 if sources is None: 

74 ss = butler.get('src', **dataId) 

75 else: 

76 ss = sources 

77 

78 # print('Sources', ss) 

79 # print('Calexp', calexp) 

80 # print(dir(ss)) 

81 

82 srcs = {} 

83 families = {} 

84 for src in ss: 

85 sid = src.getId() 

86 srcs[sid] = src 

87 parent = src.getParent() 

88 if parent == 0: 

89 continue 

90 if parent not in families: 

91 families[parent] = [] 

92 families[parent].append(src) 

93 # print 'Source', src 

94 # print ' ', dir(src) 

95 # print ' parent', src.getParent() 

96 # print ' footprint', src.getFootprint() 

97 

98 print() 

99 lsstimg = calexp.getMaskedImage().getImage() 

100 img = lsstimg.getArray() 

101 schema = ss.getSchema() 

102 psfkey = schema.find("deblend_deblendedAsPsf").key 

103 nchildkey = schema.find("deblend_nChild").key 

104 toomanykey = schema.find("deblend_tooManyPeaks").key 

105 failedkey = schema.find("deblend_failed").key 

106 

107 def getFlagString(src): 

108 ss = ['Nchild: %i' % src.get(nchildkey)] 

109 for key, s in [(psfkey, 'PSF'), 

110 (toomanykey, 'TooMany'), 

111 (failedkey, 'Failed')]: 

112 if src.get(key): 

113 ss.append(s) 

114 return ', '.join(ss) 

115 

116 plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9, 

117 hspace=0.2, wspace=0.3) 

118 

119 sig1 = np.sqrt(np.median(calexp.getMaskedImage().getVariance().getArray().ravel())) 

120 pp = (img / np.sqrt(calexp.getMaskedImage().getVariance().getArray())).ravel() 

121 plt.clf() 

122 lo, hi = -4, 4 

123 n, b, p = plt.hist(img.ravel() / sig1, 100, range=(lo, hi), histtype='step', color='b') 

124 plt.hist(pp, 100, range=(lo, hi), histtype='step', color='g') 

125 xx = np.linspace(lo, hi, 200) 

126 yy = 1./(np.sqrt(2.*np.pi)) * np.exp(-0.5 * xx**2) 

127 yy *= sum(n) * (b[1]-b[0]) 

128 plt.plot(xx, yy, 'k-', alpha=0.5) 

129 plt.xlim(lo, hi) 

130 plt.title('image-wide sig1: %.1f' % sig1) 

131 ps.savefig() 

132 

133 for ifam, (p, kids) in enumerate(families.items()): 

134 

135 parent = srcs[p] 

136 pid = parent.getId() & 0xffff 

137 if len(pids) and pid not in pids: 

138 # print('Skipping pid', pid) 

139 continue 

140 

141 if len(kids) < minsize: 

142 print('Skipping parent', pid, ': n kids', len(kids)) 

143 continue 

144 

145 # if len(kids) < 5: 

146 # print 'Skipping family with', len(kids) 

147 # continue 

148 # print 'ifam', ifam 

149 # if ifam != 18: 

150 # print 'skipping' 

151 # continue 

152 

153 print('Parent', parent) 

154 print('Kids', kids) 

155 

156 print('Parent', parent.getId()) 

157 print('Kids', [k.getId() for k in kids]) 

158 

159 pfoot = parent.getFootprint() 

160 bb = pfoot.getBBox() 

161 

162 y0, y1, x0, x1 = bb.getMinY(), bb.getMaxY(), bb.getMinX(), bb.getMaxX() 

163 slc = slice(y0, y1+1), slice(x0, x1+1) 

164 

165 ima = dict(interpolation='nearest', origin='lower', cmap='gray', 

166 vmin=-10, vmax=40) 

167 mn, mx = ima['vmin'], ima['vmax'] 

168 

169 if False: 

170 plt.clf() 

171 plt.imshow(img[slc], extent=bb_to_ext(bb), **ima) 

172 plt.title('Parent %i, %s' % (parent.getId(), getFlagString(parent))) 

173 ax = plt.axis() 

174 x, y = bb_to_xy(bb) 

175 plt.plot(x, y, 'r-', lw=2) 

176 for i, kid in enumerate(kids): 

177 kfoot = kid.getFootprint() 

178 kbb = kfoot.getBBox() 

179 kx, ky = bb_to_xy(kbb, margin=0.4) 

180 plt.plot(kx, ky, 'm-') 

181 for pk in pfoot.getPeaks(): 

182 plt.plot(pk.getIx(), pk.getIy(), 'r+', ms=10, mew=3) 

183 plt.axis(ax) 

184 ps.savefig() 

185 

186 print('parent footprint:', pfoot) 

187 print('heavy?', pfoot.isHeavy()) 

188 plt.clf() 

189 pimg, h = foot_to_img(pfoot, lsstimg) 

190 

191 plt.imshow(img_to_rgb(pimg.getArray(), mn, mx), extent=bb_to_ext(bb), **ima) 

192 tt = 'Parent %i' % parent.getId() 

193 if not h: 

194 tt += ', no HFoot' 

195 tt += ', ' + getFlagString(parent) 

196 plt.title(tt) 

197 ax = plt.axis() 

198 plt.plot([x0, x0, x1, x1, x0], [y0, y1, y1, y0, y0], 'r-', lw=2) 

199 for i, kid in enumerate(kids): 

200 kfoot = kid.getFootprint() 

201 kbb = kfoot.getBBox() 

202 kx, ky = bb_to_xy(kbb, margin=-0.1) 

203 plt.plot(kx, ky, 'm-', lw=1.5) 

204 for pk in pfoot.getPeaks(): 

205 plt.plot(pk.getIx(), pk.getIy(), 'r+', ms=10, mew=3) 

206 plt.axis(ax) 

207 ps.savefig() 

208 

209 cols = int(np.ceil(np.sqrt(len(kids)))) 

210 rows = int(np.ceil(len(kids) / float(cols))) 

211 

212 if False: 

213 plt.clf() 

214 for i, kid in enumerate(kids): 

215 plt.subplot(rows, cols, 1+i) 

216 kfoot = kid.getFootprint() 

217 print('kfoot:', kfoot) 

218 print('heavy?', kfoot.isHeavy()) 

219 # print(dir(kid)) 

220 kbb = kfoot.getBBox() 

221 ky0, ky1, kx0, kx1 = kbb.getMinY(), kbb.getMaxY(), kbb.getMinX(), kbb.getMaxX() 

222 kslc = slice(ky0, ky1+1), slice(kx0, kx1+1) 

223 plt.imshow(img[kslc], extent=bb_to_ext(kbb), **ima) 

224 plt.title('Child %i' % kid.getId()) 

225 plt.axis(ax) 

226 ps.savefig() 

227 

228 plt.clf() 

229 for i, kid in enumerate(kids): 

230 plt.subplot(rows, cols, 1+i) 

231 kfoot = kid.getFootprint() 

232 kbb = kfoot.getBBox() 

233 kimg, h = foot_to_img(kfoot, lsstimg) 

234 tt = getFlagString(kid) 

235 if not h: 

236 tt += ', no HFoot' 

237 plt.title('%s' % tt) 

238 if kimg is None: 

239 plt.axis(ax) 

240 continue 

241 plt.imshow(img_to_rgb(kimg.getArray(), mn, mx), extent=bb_to_ext(kbb), **ima) 

242 for pk in kfoot.getPeaks(): 

243 plt.plot(pk.getIx(), pk.getIy(), 'g+', ms=10, mew=3) 

244 plt.axis(ax) 

245 plt.suptitle('Child HeavyFootprints') 

246 ps.savefig() 

247 

248 print() 

249 print('Re-running deblender...') 

250 psf = calexp.getPsf() 

251 psf_fwhm = psf.computeShape(psf.getAveragePosition()).getDeterminantRadius() * 2.35 

252 deb = deblend(pfoot, calexp.getMaskedImage(), psf, psf_fwhm, verbose=True, 

253 maxNumberOfPeaks=maxpeaks, 

254 rampFluxAtEdge=True, 

255 clipStrayFluxFraction=0.01, 

256 ) 

257 print('Got', deb) 

258 

259 def getDebFlagString(kid): 

260 ss = [] 

261 for k in ['skip', 'outOfBounds', 'tinyFootprint', 'noValidPixels', 

262 ('deblendedAsPsf', 'PSF'), 'psfFitFailed', 'psfFitBadDof', 

263 'psfFitBigDecenter', 'psfFitWithDecenter', 

264 'failedSymmetricTemplate', 'hasRampedTemplate', 'patched']: 

265 if len(k) == 2: 

266 k, s = k 

267 else: 

268 s = k 

269 if getattr(kid, k): 

270 ss.append(s) 

271 return ', '.join(ss) 

272 

273 N = len(deb.peaks) 

274 cols = int(np.ceil(np.sqrt(N))) 

275 rows = int(np.ceil(N / float(cols))) 

276 

277 for plotnum in range(4): 

278 plt.clf() 

279 for i, kid in enumerate(deb.peaks): 

280 # print 'child', kid 

281 # print ' flags:', getDebFlagString(kid) 

282 

283 kfoot = None 

284 if plotnum == 0: 

285 kfoot = kid.getFluxPortion(strayFlux=False) 

286 supt = 'flux portion' 

287 elif plotnum == 1: 

288 kfoot = kid.getFluxPortion(strayFlux=True) 

289 supt = 'flux portion + stray' 

290 elif plotnum == 2: 

291 kfoot = afwDet.makeHeavyFootprint(kid.templateFootprint, 

292 kid.templateImage) 

293 supt = 'template' 

294 elif plotnum == 3: 

295 if kid.deblendedAsPsf: 

296 kfoot = afwDet.makeHeavyFootprint(kid.psfFootprint, 

297 kid.psfTemplate) 

298 kfoot.normalize() 

299 kfoot.clipToNonzero(kid.psfTemplate.getImage()) 

300 # print 'kfoot BB:', kfoot.getBBox() 

301 # print 'Img bb:', kid.psfTemplate.getImage().getBBox() 

302 # for sp in kfoot.getSpans(): 

303 # print ' span', sp 

304 else: 

305 kfoot = afwDet.makeHeavyFootprint(kid.templateFootprint, 

306 kid.templateImage) 

307 supt = 'psf template' 

308 

309 kimg, h = foot_to_img(kfoot, None) 

310 tt = 'kid %i: %s' % (i, getDebFlagString(kid)) 

311 if not h: 

312 tt += ', no HFoot' 

313 plt.subplot(rows, cols, 1+i) 

314 plt.title('%s' % tt, fontsize=8) 

315 if kimg is None: 

316 plt.axis(ax) 

317 continue 

318 kbb = kfoot.getBBox() 

319 

320 plt.imshow(img_to_rgb(kimg.getArray(), mn, mx), extent=bb_to_ext(kbb), **ima) 

321 

322 # plt.imshow(kimg.getArray(), extent=bb_to_ext(kbb), **ima) 

323 

324 plt.axis(ax) 

325 

326 plt.suptitle(supt) 

327 ps.savefig() 

328 

329 for i, kid in enumerate(deb.peaks): 

330 if not kid.deblendedAsPsf: 

331 continue 

332 plt.clf() 

333 

334 ima = dict(interpolation='nearest', origin='lower', cmap='gray') 

335 # vmin=0, vmax=kid.psfFitFlux) 

336 

337 plt.subplot(2, 4, 1) 

338 # plt.title('fit psf 0') 

339 # plt.imshow(kid.psfFitDebugPsf0Img.getArray(), **ima) 

340 # plt.colorbar() 

341 # plt.title('valid pixels') 

342 # plt.imshow(kid.psfFitDebugValidPix, vmin=0, vmax=1, **ima) 

343 plt.title('weights') 

344 plt.imshow(kid.psfFitDebugWeight, vmin=0, **ima) 

345 plt.xticks([]) 

346 plt.yticks([]) 

347 plt.colorbar() 

348 

349 plt.subplot(2, 4, 7) 

350 plt.title('valid pixels') 

351 plt.imshow(kid.psfFitDebugValidPix, vmin=0, vmax=1, **ima) 

352 plt.xticks([]) 

353 plt.yticks([]) 

354 plt.colorbar() 

355 

356 plt.subplot(2, 4, 2) 

357 # plt.title('ramp weights') 

358 # plt.imshow(kid.psfFitDebugRampWeight, vmin=0, vmax=1, **ima) 

359 # plt.colorbar() 

360 sig = np.sqrt(kid.psfFitDebugVar.getArray()) 

361 data = kid.psfFitDebugStamp.getArray() 

362 model = kid.psfFitDebugPsfModel.getArray() 

363 chi = ((data - model) / sig) 

364 valid = kid.psfFitDebugValidPix 

365 

366 plt.hist(np.clip((data/sig)[valid], -5, 5), 20, range=(-5, 5), 

367 histtype='step', color='m') 

368 plt.hist(np.clip((model/sig)[valid], -5, 5), 20, range=(-5, 5), 

369 histtype='step', color='r') 

370 plt.hist(np.clip(chi.ravel(), -5, 5), 20, range=(-5, 5), 

371 histtype='step', color='g') 

372 n, b, p = plt.hist(np.clip(chi[valid], -5, 5), 20, range=(-5, 5), 

373 histtype='step', color='b') 

374 

375 xx = np.linspace(-5, 5, 200) 

376 yy = 1./(np.sqrt(2.*np.pi)) * np.exp(-0.5 * xx**2) 

377 yy *= sum(n) * (b[1]-b[0]) 

378 plt.plot(xx, yy, 'k-', alpha=0.5) 

379 

380 plt.xlim(-5, 5) 

381 

382 print('Sum of ramp weights:', np.sum(kid.psfFitDebugRampWeight)) 

383 print('Quadrature sum of ramp weights:', np.sqrt(np.sum(kid.psfFitDebugRampWeight**2))) 

384 print('Number of valid pix:', np.sum(kid.psfFitDebugValidPix)) 

385 rw = kid.psfFitDebugRampWeight 

386 valid = kid.psfFitDebugValidPix 

387 # print 'valid values:', np.unique(valid) 

388 print('rw[valid]', np.sum(rw[valid])) 

389 print('rw range', rw.min(), rw.max()) 

390 # print 'rw', rw.shape, rw.dtype 

391 # print 'valid', valid.shape, valid.dtype 

392 # print 'rw[valid]:', rw[valid] 

393 

394 myresid = np.sum(kid.psfFitDebugValidPix 

395 * kid.psfFitDebugRampWeight 

396 * ((kid.psfFitDebugStamp.getArray() 

397 - kid.psfFitDebugPsfModel.getArray()) 

398 / np.sqrt(kid.psfFitDebugVar.getArray()))**2) 

399 print('myresid:', myresid) 

400 

401 plt.subplot(2, 4, 8) 

402 N = 20000 

403 rwv = rw[valid] 

404 print('rwv', rwv) 

405 x = np.random.normal(size=(N, len(rwv))) 

406 ss = np.sum(rwv * x**2, axis=1) 

407 plt.hist(ss, 25) 

408 chi, dof = kid.psfFitBest 

409 plt.axvline(chi, color='r') 

410 

411 mx = kid.psfFitDebugPsfModel.getArray().max() 

412 

413 plt.subplot(2, 4, 3) 

414 # plt.title('fit psf') 

415 # plt.imshow(kid.psfFitDebugPsfImg.getArray(), **ima) 

416 # plt.colorbar() 

417 # plt.title('variance') 

418 # plt.imshow(kid.psfFitDebugVar.getArray(), vmin=0, **ima) 

419 # plt.colorbar() 

420 plt.title('model+noise') 

421 plt.imshow((kid.psfFitDebugPsfModel.getArray() 

422 + sig * np.random.normal(size=sig.shape))*valid, 

423 vmin=0, vmax=mx, **ima) 

424 plt.xticks([]) 

425 plt.yticks([]) 

426 plt.colorbar() 

427 

428 plt.subplot(2, 4, 4) 

429 plt.title('fit psf model') 

430 plt.imshow(kid.psfFitDebugPsfModel.getArray(), vmin=0, vmax=mx, **ima) 

431 plt.xticks([]) 

432 plt.yticks([]) 

433 plt.colorbar() 

434 

435 plt.subplot(2, 4, 5) 

436 plt.title('fit psf image') 

437 plt.imshow(kid.psfFitDebugStamp.getArray(), vmin=0, vmax=mx, **ima) 

438 plt.xticks([]) 

439 plt.yticks([]) 

440 plt.colorbar() 

441 

442 chi = (kid.psfFitDebugValidPix 

443 * (kid.psfFitDebugStamp.getArray() 

444 - kid.psfFitDebugPsfModel.getArray()) 

445 / np.sqrt(kid.psfFitDebugVar.getArray())) 

446 

447 plt.subplot(2, 4, 6) 

448 plt.title('fit psf chi') 

449 plt.imshow(-chi, vmin=-3, vmax=3, interpolation='nearest', origin='lower', cmap='RdBu') 

450 plt.xticks([]) 

451 plt.yticks([]) 

452 plt.colorbar() 

453 

454 params = kid.psfFitParams 

455 (flux, sky, skyx, skyy) = params[:4] 

456 

457 print('Model sum:', model.sum()) 

458 print('- sky', model.sum() - np.sum(valid)*sky) 

459 

460 sig1 = np.median(sig) 

461 

462 chi, dof = kid.psfFitBest 

463 plt.suptitle('PSF kid %i: flux %.1f, sky %.1f, sig1 %.1f' % 

464 (i, flux, sky, sig1)) # : chisq %g, dof %i' % (i, chi, dof)) 

465 

466 ps.savefig() 

467 

468 # if ifam == 5: 

469 # break 

470 

471 

472if __name__ == '__main__': 

473 import optparse 

474 import sys 

475 

476 parser = optparse.OptionParser() 

477 parser.add_option('--data', help='Data dir, default $SUPRIME_DATA_DIR/rerun/RERUN') 

478 parser.add_option('--rerun', help='Rerun name, default %default', default='dstn/deb') 

479 parser.add_option('--visit', help='Visit number, default %default', default=905516, type=int) 

480 parser.add_option('--ccd', help='CCD number, default %default', default=22, type=int) 

481 parser.add_option('--sources', help='Read sources file', type=str) 

482 parser.add_option('--hdu', help='With --sources, HDU to read; default %default', 

483 type=int, default=2) 

484 parser.add_option('--pid', '-p', action='append', default=[], type=int, 

485 help='Deblend a specific parent ID') 

486 parser.add_option('--big', dest='minsize', default=0, 

487 help='Only show results for deblend families larger than this', type=int) 

488 parser.add_option('--maxpeaks', default=10, help='maxNumberOfPeaks', type=int) 

489 

490 opt, args = parser.parse_args() 

491 

492 if len(args): 

493 parser.print_help() 

494 sys.exit(-1) 

495 

496 if not opt.data: 

497 opt.data = os.path.join(os.environ['SUPRIME_DATA_DIR'], 

498 'rerun', opt.rerun) 

499 

500 print('Data directory:', opt.data) 

501 butler = dafPersist.Butler(opt.data) 

502 dataId = dict(visit=opt.visit, ccd=opt.ccd) 

503 ps = PlotSequence('deb') 

504 

505 sources = None 

506 if opt.sources: 

507 flags = 0 

508 sources = afwTable.SourceCatalog.readFits(opt.sources, opt.hdu, flags) 

509 print('Read sources from', opt.sources, ':', sources) 

510 

511 makeplots(butler, dataId, ps, sources=sources, pids=opt.pid, minsize=opt.minsize, 

512 maxpeaks=opt.maxpeaks)