Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Support utilities for Measuring sources""" 

23 

24import math 

25import numpy 

26 

27import lsst.log 

28import lsst.pex.exceptions as pexExcept 

29import lsst.daf.base as dafBase 

30import lsst.geom 

31import lsst.afw.geom as afwGeom 

32import lsst.afw.detection as afwDet 

33import lsst.afw.image as afwImage 

34import lsst.afw.math as afwMath 

35import lsst.afw.table as afwTable 

36import lsst.afw.display as afwDisplay 

37import lsst.afw.display.utils as displayUtils 

38import lsst.meas.base as measBase 

39from . import subtractPsf, fitKernelParamsToImage 

40 

# Set to True by plotPsfSpatialModel once it has registered its atexit hook
# that keeps the matplotlib windows open, so the hook is registered only once.
keptPlots: bool = False

# Applied once at import time to every subsequently created afwDisplay.
afwDisplay.setDefaultMaskTransparency(75)

44 

45 

def splitId(oid, asDict=True):
    """Extract the object index from a packed source identifier.

    The index is the low 16 bits of ``oid`` minus one (should be the value
    set by apps code).

    Parameters
    ----------
    oid : `int`
        Packed source identifier.
    asDict : `bool`, optional
        If True return ``{"objId": index}``, otherwise ``[index]``.

    Returns
    -------
    result : `dict` or `list`
        The object index wrapped as requested.
    """
    lowBits = oid & 0xFFFF
    objectIndex = int(lowBits - 1)
    return {"objId": objectIndex} if asDict else [objectIndex]

54 

55 

def showSourceSet(sSet, xy0=(0, 0), display=None, ctype=afwDisplay.GREEN, symb="+", size=2):
    """Mark the (XAstrom, YAstrom) position of every source in ``sSet``.

    Parameters
    ----------
    sSet : iterable
        Sources providing ``getXAstrom``, ``getYAstrom`` and ``getId``.
    xy0 : pair of `int`, optional
        XY0 of the displayed image; subtracted from each source position.
    display : `lsst.afw.display.Display`, optional
        Display to draw on; a new default display is created if not given.
    ctype : `str`, optional
        Colour of the markers.
    symb : `str`, optional
        Symbol passed to ``display.dot``; the special value ``"id"`` labels
        each source with its object index (see `splitId`) instead.
    size : `int`, optional
        Marker size.
    """
    display = display if display else afwDisplay.Display()
    x0, y0 = xy0[0], xy0[1]
    labelWithId = (symb == "id")
    with display.Buffering():
        for source in sSet:
            xc = source.getXAstrom() - x0
            yc = source.getYAstrom() - y0
            marker = str(splitId(source.getId(), True)["objId"]) if labelWithId else symb
            display.dot(marker, xc, yc, ctype=ctype, size=size)

69 

70# 

71# PSF display utilities 

72# 

73 

74 

def showPsfSpatialCells(exposure, psfCellSet, nMaxPerCell=-1, showChi2=False, showMoments=False,
                        symb=None, ctype=None, ctypeUnused=None, ctypeBad=None, size=2, display=None):
    """Show the SpatialCells.

    Draws the bounding box of every cell in ``psfCellSet`` and optionally
    marks and labels the PSF candidates in each cell.

    If symb is something that afwDisplay.Display.dot() understands (e.g. "o"),
    the top nMaxPerCell candidates will be indicated with that symbol, using
    ctype (or ctypeBad for bad candidates) and size.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure the cells refer to; its XY0 converts candidate positions to
        the displayed (local) pixel frame.
    psfCellSet : `lsst.afw.math.SpatialCellSet`
        Cells containing the PSF candidates.
    nMaxPerCell : `int`, optional
        Number of candidates per cell treated as "used".  Later candidates are
        drawn with ``ctypeUnused``, or skipped if it is not set.  A value <= 0
        treats every candidate as used.
    showChi2 : `bool`, optional
        Label each drawn candidate with its object index and chi^2.
    showMoments : `bool`, optional
        Label each drawn candidate with its source's Ixx, Ixy, Iyy.
    symb : `str`, optional
        Symbol used to mark candidates; no symbol is drawn if not set.
    ctype : `str`, optional
        Colour for used, good candidates.
    ctypeUnused : `str`, optional
        Colour for candidates beyond ``nMaxPerCell``.
    ctypeBad : `str`, optional
        Colour for bad candidates; if `None`, bad candidates are not drawn.
    size : `int`, optional
        Symbol size.
    display : `lsst.afw.display.Display`, optional
        Display to draw on; a new default display is created if not given.

    Returns
    -------
    display : `lsst.afw.display.Display`
        The display that was drawn on.
    """
    if not display:
        display = afwDisplay.Display()

    # A negative limit means "no limit".  The counter below only advances for
    # a positive limit, so with 0 every candidate counts as used.
    if nMaxPerCell < 0:
        nMaxPerCell = 0
    # Only iterate over bad candidates if we have a colour to draw them with.
    goodOnly = ctypeBad is None

    with display.Buffering():
        origin = [-exposure.getMaskedImage().getX0(), -exposure.getMaskedImage().getY0()]
        for cell in psfCellSet.getCellList():
            displayUtils.drawBBox(cell.getBBox(), origin=origin, display=display)

            i = 0
            for cand in cell.begin(goodOnly):
                if nMaxPerCell > 0:
                    i += 1
                unused = i > nMaxPerCell
                if unused and not ctypeUnused:
                    continue

                xc, yc = cand.getXCenter() + origin[0], cand.getYCenter() + origin[1]
                color = ctypeBad if cand.isBad() else ctype

                if symb:
                    # Use `color` (not `ctype`) so bad candidates are distinguishable from good ones
                    display.dot(symb, xc, yc, ctype=ctypeUnused if unused else color, size=size)

                source = cand.getSource()

                if showChi2:
                    rchi2 = cand.getChi2()
                    if rchi2 > 1e100:  # treat absurdly large values as "no chi^2"
                        rchi2 = numpy.nan
                    display.dot("%d %.1f" % (splitId(source.getId(), True)["objId"], rchi2),
                                xc - size, yc - size - 4, ctype=color, size=2)

                if showMoments:
                    display.dot("%.2f %.2f %.2f" % (source.getIxx(), source.getIxy(), source.getIyy()),
                                xc - size, yc + size + 4, ctype=color, size=size)
    return display

129 

130 

def _refineCentroid(im, psf, xc, yc):
    """Re-measure the centroid of the star in a candidate stamp with base_SdssCentroid.

    Parameters
    ----------
    im : `lsst.afw.image.MaskedImage`
        Postage stamp of a PSF candidate (not modified).
    psf : `lsst.afw.detection.Psf`
        PSF attached to the temporary exposure used for the measurement.
    xc, yc : `float`
        Initial centre estimate; if several objects are detected the one
        closest to this position is used.

    Returns
    -------
    centroid : `lsst.geom.Point2D`
        The refined centroid.

    Raises
    ------
    RuntimeError
        If no object is detected in the stamp.
    """
    config = measBase.SingleFrameMeasurementTask.ConfigClass()
    config.slots.centroid = "base_SdssCentroid"

    schema = afwTable.SourceTable.makeMinimalSchema()
    measureSources = measBase.SingleFrameMeasurementTask(schema, config=config)
    catalog = afwTable.SourceCatalog(schema)

    psfIm = im.getImage()
    extra = 10  # enough margin to run the sdss centroider
    miBig = im.Factory(im.getWidth() + 2*extra, im.getHeight() + 2*extra)
    miBig[extra:-extra, extra:-extra, afwImage.LOCAL] = im
    miBig.setXY0(im.getX0() - extra, im.getY0() - extra)

    exp = afwImage.makeExposure(miBig)
    exp.setPsf(psf)

    # Detect everything above half the stamp's peak pixel value
    footprintSet = afwDet.FootprintSet(miBig,
                                       afwDet.Threshold(0.5*numpy.max(psfIm.getArray())),
                                       "DETECTED")
    footprintSet.makeSources(catalog)

    if len(catalog) == 0:
        raise RuntimeError("Failed to detect any objects")

    measureSources.run(catalog, exp)
    # If there is more than one source, use the one closest to (xc, yc)
    source = min(catalog, key=lambda s: numpy.hypot(xc - s.getX(), yc - s.getY()))
    return source.getCentroid()


def _basisFitResidual(cand, psf):
    """Return the residual of fitting the PSF's basis kernels directly to a candidate.

    The spatial model is ignored: the coefficients of ``psf.getKernel()``'s
    basis are fitted to the candidate image with ``fitKernelParamsToImage`` and
    the resulting kernel image is subtracted from a copy of the candidate image.

    Returns
    -------
    resid : `lsst.afw.image.MaskedImage` or `None`
        The residual, or `None` if ``psf`` does not provide a kernel.
    """
    im = cand.getMaskedImage()
    im = im.Factory(im, True)
    im.setXY0(cand.getMaskedImage().getXY0())

    try:
        noSpatialKernel = psf.getKernel()
    except Exception:
        noSpatialKernel = None

    if not noSpatialKernel:
        return None

    candCenter = lsst.geom.PointD(cand.getXCenter(), cand.getYCenter())
    fit = fitKernelParamsToImage(noSpatialKernel, im, candCenter)
    params = fit[0]
    kernels = afwMath.KernelList(fit[1])
    outputKernel = afwMath.LinearCombinationKernel(kernels, params)

    outImage = afwImage.ImageD(outputKernel.getDimensions())
    outputKernel.computeImage(outImage, False)

    im -= outImage.convertF()
    return im


def showPsfCandidates(exposure, psfCellSet, psf=None, display=None, normalize=True, showBadCandidates=True,
                      fitBasisComponents=False, variance=None, chi=None):
    """Display the PSF candidates.

    Each candidate is added to a mosaic and labelled with its object index and
    either its chi^2 (if ``psf`` is given) or its PSF flux.  Its centre is
    then marked with a "+" (green for good candidates, red for bad ones).

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Not used by the current implementation; kept for compatibility.
    psfCellSet : `lsst.afw.math.SpatialCellSet`
        Cells containing the PSF candidates.
    psf : `lsst.afw.detection.Psf`, optional
        If provided, include the PSF-subtracted residuals (after re-measuring
        each candidate's centroid).
    display : `lsst.afw.display.Display`, optional
        Display to draw on; a new default display is created if not given.
    normalize : `bool`, optional
        Divide each mosaic entry (candidate and residuals) by its maximum.
    showBadCandidates : `bool`, optional
        Include candidates flagged as bad.
    fitBasisComponents : `bool`, optional
        Also show the residual after fitting the best linear combination of
        the PSF's basis components directly to the candidate (if they exist).
    variance : `bool`, optional
        Old name for ``chi``; only used if ``chi`` is `None`.
    chi : `bool`, optional
        If True, show residuals/sqrt(variance), i.e. chi, instead of the
        candidate image plus residuals.

    Returns
    -------
    mosaicImage
        The image returned by ``Mosaic.makeMosaic``.
    """
    if not display:
        display = afwDisplay.Display()

    if chi is None and variance is not None:  # old name for chi
        chi = variance

    mos = displayUtils.Mosaic()
    candidateCenters = []
    candidateCentersBad = []
    candidateIndex = 0

    for cell in psfCellSet.getCellList():
        for cand in cell.begin(False):  # include bad candidates
            rchi2 = cand.getChi2()
            if rchi2 > 1e100:
                rchi2 = numpy.nan

            if not showBadCandidates and cand.isBad():
                continue

            # Default centre; refined below if we have a PSF
            xc, yc = cand.getXCenter(), cand.getYCenter()

            if psf:
                im_resid = displayUtils.Mosaic(gutter=0, background=-5, mode="x")

                try:
                    im = cand.getMaskedImage()
                    im = im.Factory(im, True)  # deep copy, as we subtract the PSF from it
                    im.setXY0(cand.getMaskedImage().getXY0())
                except Exception:
                    continue

                if not chi:
                    im_resid.append(im.Factory(im, True))

                xc, yc = _refineCentroid(im, psf, xc, yc)

                # residuals using spatial model
                try:
                    subtractPsf(psf, im, xc, yc)
                except Exception:
                    continue

                resid = im
                if chi:
                    resid = resid.getImage()
                    var = im.getVariance()
                    var = var.Factory(var, True)
                    numpy.sqrt(var.getArray(), var.getArray())  # inplace sqrt
                    resid /= var

                im_resid.append(resid)

                # Fit the PSF components directly to the data (i.e. ignoring the spatial model)
                if fitBasisComponents:
                    resid = _basisFitResidual(cand, psf)
                    if resid is not None:
                        if chi:
                            resid = resid.getImage()
                            resid /= var
                        im_resid.append(resid)

                im = im_resid.makeMosaic()
            else:
                im = cand.getMaskedImage()
                # Copy so that normalizing below does not modify the candidate's own image
                im = im.Factory(im, True)

            if normalize:
                im /= afwMath.makeStatistics(im, afwMath.MAX).getValue()

            objId = splitId(cand.getSource().getId(), True)["objId"]
            if psf:
                lab = "%d chi^2 %.1f" % (objId, rchi2)
                ctype = afwDisplay.RED if cand.isBad() else afwDisplay.GREEN
            else:
                lab = "%d flux %8.3g" % (objId, cand.getSource().getPsfInstFlux())
                ctype = afwDisplay.GREEN

            mos.append(im, lab, ctype)

            # Record the centre relative to the candidate stamp's origin
            im = cand.getMaskedImage()
            center = (candidateIndex, xc - im.getX0(), yc - im.getY0())
            candidateIndex += 1
            if cand.isBad():
                candidateCentersBad.append(center)
            else:
                candidateCenters.append(center)

    title = "chi(Psf fit)" if chi else "Stars & residuals"
    mosaicImage = mos.makeMosaic(display=display, title=title)

    with display.Buffering():
        for centers, color in ((candidateCenters, afwDisplay.GREEN), (candidateCentersBad, afwDisplay.RED)):
            for cen in centers:
                bbox = mos.getBBox(cen[0])
                display.dot("+", cen[1] + bbox.getMinX(), cen[2] + bbox.getMinY(), ctype=color)

    return mosaicImage

334 

335 

def makeSubplots(fig, nx=2, ny=2, Nx=1, Ny=1, plottingArea=(0.1, 0.1, 0.85, 0.80),
                 pxgutter=0.05, pygutter=0.05, xgutter=0.04, ygutter=0.04,
                 headroom=0.0, panelBorderWeight=0, panelColor='black'):
    """Return a generator of a set of subplots, a set of Nx*Ny panels of nx*ny plots. Each panel is fully
    filled by row (starting in the bottom left) before the next panel is started. If panelBorderWeight is
    greater than zero a border is drawn around each panel, adjusted to enclose the axis labels.

    E.g.
    subplots = makeSubplots(fig, 2, 2, Nx=1, Ny=1, panelColor='k')
    ax = next(subplots); ax.text(0.3, 0.5, '[0, 0] (0,0)')
    ax = next(subplots); ax.text(0.3, 0.5, '[0, 0] (1,0)')
    ax = next(subplots); ax.text(0.3, 0.5, '[0, 0] (0,1)')
    ax = next(subplots); ax.text(0.3, 0.5, '[0, 0] (1,1)')
    fig.show()

    If matplotlib cannot be imported a warning is logged and the generator yields nothing.

    @param fig The matplotlib figure to draw
    @param nx The number of plots in each row of each panel
    @param ny The number of plots in each column of each panel
    @param Nx The number of panels in each row of the figure
    @param Ny The number of panels in each column of the figure
    @param plottingArea (x0, y0, width, height), in figure coordinates, of the part of the figure
                        containing all the panels
    @param pxgutter Spacing between columns of panels, in figure coordinates
    @param pygutter Spacing between rows of panels, in figure coordinates
    @param xgutter Spacing between columns of plots within a panel, in figure coordinates
    @param ygutter Spacing between rows of plots within a panel, in figure coordinates
    @param headroom Extra height added to each panel border (as a fraction of its height) for e.g. a title
    @param panelBorderWeight Width of border drawn around panels
    @param panelColor Colour of border around panels
    """

    log = lsst.log.Log.getLogger("utils.makeSubplots")
    try:
        import matplotlib.pyplot as plt
    except ImportError as e:
        log.warn("Unable to import matplotlib: %s", e)
        return

    # Make show() call canvas.draw() too so that we know how large the axis labels are. Sigh
    try:
        fig.__show
    except AttributeError:
        fig.__show = fig.show

        def myShow(fig):
            fig.__show()
            fig.canvas.draw()

        import types
        fig.show = types.MethodType(myShow, fig)
    #
    # We can't get the axis sizes until after draw()'s been called, so use a callback Sigh^2
    #
    axes = {}  # all axes in all the panels we're drawing: axes[panel][0] etc.

    def on_draw(event):
        """Callback to draw the panel borders when the plots are drawn to the canvas."""
        if panelBorderWeight <= 0:
            return False

        for p in axes.keys():
            bboxes = []
            for ax in axes[p]:
                bboxes.append(ax.bbox.union([label.get_window_extent() for label in
                                             ax.get_xticklabels() + ax.get_yticklabels()]))

            ax = axes[p][0]

            # this is the bbox that bounds all the bboxes, again in relative
            # figure coords
            bbox = ax.bbox.union(bboxes)

            xy0, xy1 = ax.transData.inverted().transform(bbox)
            x0, y0 = xy0
            x1, y1 = xy1
            w, h = x1 - x0, y1 - y0
            # allow a little space around BBox
            x0 -= 0.02*w
            w += 0.04*w
            y0 -= 0.02*h
            h += 0.04*h
            h += h*headroom
            # remove old borders; Axes.patches is read-only in matplotlib >= 3.5,
            # so remove each artist rather than assigning an empty list
            for patch in list(ax.patches):
                patch.remove()
            # draw BBox
            rec = ax.add_patch(plt.Rectangle((x0, y0), w, h, fill=False,
                                             lw=panelBorderWeight, edgecolor=panelColor))
            rec.set_clip_on(False)

        return False

    fig.canvas.mpl_connect('draw_event', on_draw)
    #
    # Choose the plotting areas for each subplot
    #
    x0, y0 = plottingArea[0:2]
    W, H = plottingArea[2:4]
    w = (W - (Nx - 1)*pxgutter - (nx*Nx - 1)*xgutter)/float(nx*Nx)
    h = (H - (Ny - 1)*pygutter - (ny*Ny - 1)*ygutter)/float(ny*Ny)
    #
    # OK! Time to create the subplots
    #
    for panel in range(Nx*Ny):
        axes[panel] = []
        px = panel%Nx
        py = panel//Nx
        for window in range(nx*ny):
            x = nx*px + window%nx
            y = ny*py + window//nx
            ax = fig.add_axes((x0 + xgutter + pxgutter + x*w + (px - 1)*pxgutter + (x - 1)*xgutter,
                               y0 + ygutter + pygutter + y*h + (py - 1)*pygutter + (y - 1)*ygutter,
                               w, h), frame_on=True, facecolor='w')
            axes[panel].append(ax)
            yield ax

452 

453 

def plotPsfSpatialModel(exposure, psf, psfCellSet, showBadCandidates=True, numSample=128,
                        matchKernelAmplitudes=False, keepPlots=True, residualRange=(-0.01, 0.01)):
    """Plot the PSF spatial model.

    Each candidate's basis-kernel coefficients are fitted directly with
    ``fitKernelParamsToImage`` and normalized by the fitted amplitude.  For
    each kernel component k, four plots are drawn:
    - the residuals (fit minus spatial model) against y;
    - an image of the spatial model over the exposure;
    - the residuals against x;
    - the normalized coefficient against candidate magnitude.
    Good candidates are plotted in blue and bad ones in red.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure the candidates come from.
    psf : `lsst.afw.detection.Psf`
        PSF whose ``getKernel()`` provides the spatial model.
    psfCellSet : `lsst.afw.math.SpatialCellSet`
        Cells containing the PSF candidates.
    showBadCandidates : `bool`, optional
        Include candidates flagged as bad.
    numSample : `int`, optional
        Number of samples along each axis when imaging the spatial model.
    matchKernelAmplitudes : `bool`, optional
        Use a colour range of [0, 1.1] for the image of the first component.
    keepPlots : `bool`, optional
        Register (once per process) an exit handler that keeps the plots open.
    residualRange : pair of `float` or `None`, optional
        y-axis limits of the residual plots.  If `None`, use the range of the
        residuals padded by 5% on each side.
    """

    log = lsst.log.Log.getLogger("utils.plotPsfSpatialModel")
    try:
        import matplotlib.pyplot as plt
        import matplotlib as mpl
    except ImportError as e:
        log.warn("Unable to import matplotlib: %s", e)
        return

    kernel = psf.getKernel()
    candPos, candFits, candAmps = [], [], []
    badPos, badFits, badAmps = [], [], []
    for cell in psfCellSet.getCellList():
        for cand in cell.begin(False):  # include bad candidates
            if not showBadCandidates and cand.isBad():
                continue
            candCenter = lsst.geom.PointD(cand.getXCenter(), cand.getYCenter())
            try:
                im = cand.getMaskedImage()
            except Exception:
                continue

            # Fit the basis coefficients to this candidate alone (ignoring the spatial model)
            fit = fitKernelParamsToImage(kernel, im, candCenter)
            params = fit[0]
            kernels = fit[1]
            amp = 0.0
            for p, k in zip(params, kernels):
                amp += p * k.getSum()
            if amp == 0:
                log.warn("Skipping candidate at (%.1f, %.1f): zero fitted amplitude",
                         cand.getXCenter(), cand.getYCenter())
                continue

            isBad = cand.isBad()
            (badFits if isBad else candFits).append([x / amp for x in params])
            (badPos if isBad else candPos).append(candCenter)
            (badAmps if isBad else candAmps).append(amp)

    if not candPos:
        log.warn("No good PSF candidates; not plotting the PSF spatial model")
        return

    x0, y0 = exposure.getX0(), exposure.getY0()
    xGood = numpy.array([pos.getX() for pos in candPos]) - x0
    yGood = numpy.array([pos.getY() for pos in candPos]) - y0
    zGood = numpy.array(candFits)

    xBad = numpy.array([pos.getX() for pos in badPos]) - x0
    yBad = numpy.array([pos.getY() for pos in badPos]) - y0
    zBad = numpy.array(badFits)
    numBad = len(badPos)

    xRange = numpy.linspace(0, exposure.getWidth(), num=numSample)
    yRange = numpy.linspace(0, exposure.getHeight(), num=numSample)

    nKernelComponents = kernel.getNKernelParameters()
    #
    # Figure out how many panels we'll need
    #
    nPanelX = max(1, int(math.sqrt(nKernelComponents)))
    nPanelY = nKernelComponents//nPanelX
    while nPanelY*nPanelX < nKernelComponents:
        nPanelX += 1

    # The calibration is the same for every component, so convert amplitudes once.
    photoCalib = exposure.getPhotoCalib()
    # If there is no calibration (factor), use 1.0.
    if photoCalib is None or photoCalib.getCalibrationMean() <= 0:
        photoCalib = afwImage.PhotoCalib(1.0)
    ampMag = [photoCalib.instFluxToMagnitude(candAmp) for candAmp in candAmps]
    badAmpMag = [photoCalib.instFluxToMagnitude(badAmp) for badAmp in badAmps]

    fig = plt.figure(1)
    fig.clf()
    try:
        fig.canvas._tkcanvas._root().lift()  # == Tk's raise, but raise is a python reserved word
    except Exception:  # protect against API changes
        pass
    #
    # Generator for axes arranged in panels
    #
    mpl.rcParams["figure.titlesize"] = "x-small"
    subplots = makeSubplots(fig, 2, 2, Nx=nPanelX, Ny=nPanelY, xgutter=0.06, ygutter=0.06, pygutter=0.04)

    for k in range(nKernelComponents):
        func = kernel.getSpatialFunction(k)
        # The spatial model is evaluated at the candidates' (parent-frame) positions
        dfGood = zGood[:, k] - numpy.array([func(pos.getX(), pos.getY()) for pos in candPos])
        dfBad = None
        if numBad > 0:
            dfBad = zBad[:, k] - numpy.array([func(pos.getX(), pos.getY()) for pos in badPos])

        if residualRange is not None:
            yMin, yMax = residualRange
        else:
            residuals = dfGood if dfBad is None else numpy.concatenate([dfGood, dfBad])
            yMin, yMax = residuals.min(), residuals.max()
            pad = 0.05*(yMax - yMin)
            yMin, yMax = yMin - pad, yMax + pad

        # Sample the model on a grid over the exposure, in the same parent frame
        # as used for the residuals above; rows are y, columns are x.
        fRange = numpy.empty((len(yRange), len(xRange)))
        for j, yVal in enumerate(yRange):
            for i, xVal in enumerate(xRange):
                fRange[j, i] = func(xVal + x0, yVal + y0)

        ax = next(subplots)

        ax.set_autoscale_on(False)
        ax.set_xbound(lower=0, upper=exposure.getHeight())
        ax.set_ybound(lower=yMin, upper=yMax)
        ax.plot(yGood, dfGood, 'b+')
        if numBad > 0:
            ax.plot(yBad, dfBad, 'r+')
        ax.axhline(0.0)
        ax.set_title('Residuals(y)')

        ax = next(subplots)

        if matchKernelAmplitudes and k == 0:
            vmin = 0.0
            vmax = 1.1
        else:
            vmin = fRange.min()
            vmax = fRange.max()

        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        im = ax.imshow(fRange, aspect='auto', origin="lower", norm=norm,
                       extent=[0, exposure.getWidth()-1, 0, exposure.getHeight()-1])
        ax.set_title('Spatial poly')
        plt.colorbar(im, orientation='horizontal', ticks=[vmin, vmax])

        ax = next(subplots)
        ax.set_autoscale_on(False)
        ax.set_xbound(lower=0, upper=exposure.getWidth())
        ax.set_ybound(lower=yMin, upper=yMax)
        ax.plot(xGood, dfGood, 'b+')
        if numBad > 0:
            ax.plot(xBad, dfBad, 'r+')
        ax.axhline(0.0)
        ax.set_title('K%d Residuals(x)' % k)

        ax = next(subplots)

        ax.plot(ampMag, zGood[:, k], 'b+')
        if numBad > 0:
            ax.plot(badAmpMag, zBad[:, k], 'r+')

        ax.set_title('Flux variation')

    fig.show()

    global keptPlots
    if keepPlots and not keptPlots:
        # Keep plots open when done
        def show():
            print("%s: Please close plots when done." % __name__)
            try:
                plt.show()
            except Exception:
                pass
            print("Plots closed, exiting...")
        import atexit
        atexit.register(show)
        keptPlots = True

618 

619 

def showPsf(psf, eigenValues=None, XY=None, normalize=True, display=None):
    """Display a PSF's eigen images (the basis images of ``psf.getKernel()``).

    If normalize is True, set the largest absolute value of each eigenimage to 1.0 (n.b. sum == 0.0 for i > 0).
    An all-zero basis image is left unscaled.

    Parameters
    ----------
    psf : `lsst.afw.detection.Psf`
        PSF with a linear-combination kernel.
    eigenValues : sequence of `float`, optional
        Coefficients used to label each image (as ratios to the first).
        Lists and numpy arrays are both accepted.
    XY : pair of `float`, optional
        If ``eigenValues`` is not given, label with the coefficients of the
        local kernel at this position.
    normalize : `bool`, optional
        Rescale each image as described above.
    display : `lsst.afw.display.Display`, optional
        Display to draw on; a new default display is created if not given.

    Returns
    -------
    mos : `lsst.afw.display.utils.Mosaic`
        The mosaic that was displayed.
    """
    # Use explicit length checks: truth-testing a numpy array raises ValueError
    if eigenValues is not None and len(eigenValues) > 0:
        coeffs = eigenValues
    elif XY is not None:
        coeffs = psf.getLocalKernel(lsst.geom.PointD(XY[0], XY[1])).getKernelParameters()
    else:
        coeffs = None
    haveCoeffs = coeffs is not None and len(coeffs) > 0

    mos = displayUtils.Mosaic(gutter=2, background=-0.1)
    for i, k in enumerate(psf.getKernel().getKernelList()):
        im = afwImage.ImageD(k.getDimensions())
        k.computeImage(im, False)
        if normalize:
            peak = numpy.max(numpy.abs(im.getArray()))
            if peak > 0:  # avoid dividing an all-zero image by zero
                im /= peak

        if haveCoeffs:
            mos.append(im, "%g" % (coeffs[i]/coeffs[0]))
        else:
            mos.append(im)

    if not display:
        display = afwDisplay.Display()
    mos.makeMosaic(display=display, title="Kernel Basis Functions")

    return mos

650 

651 

def _gridPositions(n, length, offset):
    """Return ``n`` integer pixel positions spread evenly over [offset, offset + length - 1].

    A single position (``n == 1``) is placed at the middle of the range.
    """
    if n == 1:
        return [int((length - 1)/2) + offset]
    return [int(i*(length - 1)/(n - 1)) + offset for i in range(n)]


def showPsfMosaic(exposure, psf=None, nx=7, ny=None, showCenter=True, showEllipticity=False,
                  showFwhm=False, stampSize=0, display=None, title=None):
    """Show a mosaic of Psf images. exposure may be an Exposure (optionally with PSF),
    or a tuple (width, height)

    If stampSize is > 0, the psf images will be trimmed to stampSize*stampSize

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure` or pair of `int`
        Exposure (whose PSF is used if ``psf`` is not given) or (width, height).
    psf : `lsst.afw.detection.Psf`, optional
        PSF to image; required if ``exposure`` is a (width, height) pair.
    nx : `int`, optional
        Number of PSF images across the exposure.
    ny : `int`, optional
        Number of PSF images up the exposure; defaults to ``nx`` scaled by the
        aspect ratio (at least 1).
    showCenter : `bool`, optional
        Mark the measured centroid of each PSF image.
    showEllipticity : `bool`, optional
        Draw the measured second moments of each PSF image.
    showFwhm : `bool`, optional
        Draw the moments scaled from sigma^2 to HWHM^2 (implies showEllipticity).
    stampSize : `int`, optional
        Trim size for the PSF images, see above.
    display : `lsst.afw.display.Display`, optional
        Display to draw on; a new default display is created if not given.
    title : `str`, optional
        Title of the mosaic (default "Model Psf").

    Returns
    -------
    mos : `lsst.afw.display.utils.Mosaic`
        The mosaic that was displayed.
    """

    scale = 1.0
    if showFwhm:
        showEllipticity = True
        scale = 2*math.log(2)  # convert sigma^2 to HWHM^2 for a Gaussian

    mos = displayUtils.Mosaic()

    try:  # maybe it's a real Exposure
        width, height = exposure.getWidth(), exposure.getHeight()
        x0, y0 = exposure.getXY0()
        if not psf:
            psf = exposure.getPsf()
    except AttributeError:
        try:  # OK, maybe a list [width, height]
            width, height = exposure[0], exposure[1]
            x0, y0 = 0, 0
        except TypeError:  # I guess not
            raise RuntimeError("Unable to extract width/height from object of type %s" % type(exposure))

    if not ny:
        ny = int(nx*float(height)/width + 0.5)
        if not ny:
            ny = 1

    centroidName = "SdssCentroid"
    shapeName = "base_SdssShape"

    schema = afwTable.SourceTable.makeMinimalSchema()
    schema.getAliasMap().set("slot_Centroid", centroidName)
    schema.getAliasMap().set("slot_Centroid_flag", centroidName+"_flag")

    control = measBase.SdssCentroidControl()
    centroider = measBase.SdssCentroidAlgorithm(control, centroidName, schema)

    sdssShape = measBase.SdssShapeControl()
    shaper = measBase.SdssShapeAlgorithm(sdssShape, shapeName, schema)
    table = afwTable.SourceTable.make(schema)

    table.defineCentroid(centroidName)
    table.defineShape(shapeName)

    stampBBox = None
    if stampSize > 0:
        w, h = psf.computeImage(lsst.geom.PointD(0, 0)).getDimensions()
        if stampSize <= w and stampSize <= h:
            stampBBox = lsst.geom.BoxI(lsst.geom.PointI((w - stampSize)//2, (h - stampSize)//2),
                                       lsst.geom.ExtentI(stampSize, stampSize))

    # One entry per successfully centroided stamp: (mosaic index, (x, y), moments or None).
    # Keeping the mosaic index keeps the markers aligned with their stamps even
    # when some measurements fail.
    measurements = []
    stampIndex = 0
    for y in _gridPositions(ny, height, y0):
        for x in _gridPositions(nx, width, x0):
            im = psf.computeImage(lsst.geom.PointD(x, y)).convertF()
            imPeak = psf.computePeak(lsst.geom.PointD(x, y))
            im /= imPeak
            if stampBBox:
                im = im.Factory(im, stampBBox)
            mos.append(im, "")
            thisIndex = stampIndex
            stampIndex += 1

            exp = afwImage.makeExposure(afwImage.makeMaskedImage(im))
            exp.setPsf(psf)
            w, h = im.getWidth(), im.getHeight()
            centerX = im.getX0() + w//2
            centerY = im.getY0() + h//2
            src = table.makeRecord()
            spans = afwGeom.SpanSet(exp.getBBox())
            foot = afwDet.Footprint(spans)
            foot.addPeak(centerX, centerY, 1)
            src.setFootprint(foot)

            try:
                centroider.measure(src, exp)
            except Exception:
                continue
            center = (src.getX() - im.getX0(), src.getY() - im.getY0())

            try:
                shaper.measure(src, exp)
                shape = (src.getIxx(), src.getIxy(), src.getIyy())
            except Exception:
                shape = None
            measurements.append((thisIndex, center, shape))

    if not display:
        display = afwDisplay.Display()
    mos.makeMosaic(display=display, title=title if title else "Model Psf", mode=nx)

    if measurements:
        with display.Buffering():
            for i, cen, shape in measurements:
                bbox = mos.getBBox(i)
                xc, yc = cen[0] + bbox.getMinX(), cen[1] + bbox.getMinY()
                if showCenter:
                    display.dot("+", xc, yc, ctype=afwDisplay.BLUE)

                if showEllipticity and shape is not None:
                    ixx, ixy, iyy = shape
                    ixx *= scale
                    ixy *= scale
                    iyy *= scale
                    display.dot("@:%g,%g,%g" % (ixx, ixy, iyy), xc, yc, ctype=afwDisplay.RED)

    return mos

763 

764 

def showPsfResiduals(exposure, sourceSet, magType="psf", scale=10, display=None):
    """Build an image of the PSF-subtracted residuals around each source.

    The PSF, scaled by the chosen flux, is subtracted from a copy of the
    exposure at each source position.  A psfWidth x psfHeight cutout around
    each source is then added into a new image, shrunk in position by
    ``scale`` (cutouts may overlap).

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure with a PSF; not modified.
    sourceSet : iterable of `lsst.afw.table.SourceRecord`
        Sources to subtract.
    magType : `str`, optional
        Flux used to scale the PSF: "ap", "model" or "psf".
    scale : `float`, optional
        Factor by which source positions are divided in the output image.
    display : `lsst.afw.display.Display` or `bool`, optional
        If a display object, show the result on it; if some other true value
        (e.g. True), show it on a new default display; if false, don't display.

    Returns
    -------
    im : `lsst.afw.image.MaskedImage`
        The image of residuals.

    Raises
    ------
    RuntimeError
        If ``magType`` is not recognised.
    """
    # Validate up front: raised inside the per-source try block below, this
    # error would only be printed once per source.
    fluxGetterNames = {"ap": "getApInstFlux", "model": "getModelInstFlux", "psf": "getPsfInstFlux"}
    if magType not in fluxGetterNames:
        raise RuntimeError("Unknown flux type %s" % magType)
    fluxGetterName = fluxGetterNames[magType]

    mimIn = exposure.getMaskedImage()
    mimIn = mimIn.Factory(mimIn, True)  # make a copy to subtract from

    psf = exposure.getPsf()
    psfWidth, psfHeight = psf.getLocalKernel().getDimensions()
    #
    # Make the image that we'll paste our residuals into. N.b. they can overlap the edges
    #
    w, h = int(mimIn.getWidth()/scale), int(mimIn.getHeight()/scale)

    im = mimIn.Factory(w + psfWidth, h + psfHeight)

    cenPos = []
    for s in sourceSet:
        x, y = s.getX(), s.getY()

        sx, sy = int(x/scale + 0.5), int(y/scale + 0.5)

        smim = im.Factory(im, lsst.geom.BoxI(lsst.geom.PointI(sx, sy),
                                             lsst.geom.ExtentI(psfWidth, psfHeight)))
        sim = smim.getImage()

        try:
            flux = getattr(s, fluxGetterName)()
            subtractPsf(psf, mimIn, x, y, flux)
        except Exception as e:
            print(e)

        try:
            expIm = mimIn.getImage().Factory(mimIn.getImage(),
                                             lsst.geom.BoxI(lsst.geom.PointI(int(x) - psfWidth//2,
                                                                             int(y) - psfHeight//2),
                                                            lsst.geom.ExtentI(psfWidth, psfHeight)),
                                             )
        except pexExcept.Exception:
            continue

        cenPos.append([x - expIm.getX0() + sx, y - expIm.getY0() + sy])

        sim += expIm

    if display:
        # Previously any true `display` was replaced by a new default display,
        # discarding the caller's; only create one if we weren't given a display.
        if not hasattr(display, "mtv"):
            display = afwDisplay.Display()
        display.mtv(im, title="showPsfResiduals: image")
        with display.Buffering():
            for x, y in cenPos:
                display.dot("+", x, y)

    return im

823 

824 

def saveSpatialCellSet(psfCellSet, fileName="foo.fits", display=None):
    """Write the contents of a SpatialCellSet to a multi-extension FITS file.

    Every candidate (good or bad) is resampled with ``afwMath.offsetImage`` by
    minus the fractional parts returned by ``afwImage.positionToIndex`` for its
    centre.  It is then written as its own FITS entry (the first write uses mode
    "w", later ones "a").  The header carries CELL, ID, XCENTER, YCENTER, BAD,
    AMPL, FLUX and CHI2.

    Parameters
    ----------
    psfCellSet : `lsst.afw.math.SpatialCellSet`
        Cells containing the PSF candidates.
    fileName : `str`, optional
        Output file; overwritten.
    display : `lsst.afw.display.Display`, optional
        If given, show the last image written.
    """
    mode = "w"
    im = None
    for cell in psfCellSet.getCellList():
        for cand in cell.begin(False):  # include bad candidates
            dx = afwImage.positionToIndex(cand.getXCenter(), True)[1]
            dy = afwImage.positionToIndex(cand.getYCenter(), True)[1]
            im = afwMath.offsetImage(cand.getMaskedImage(), -dx, -dy, "lanczos5")

            md = dafBase.PropertySet()
            md.set("CELL", cell.getLabel())
            md.set("ID", cand.getId())
            md.set("XCENTER", cand.getXCenter())
            md.set("YCENTER", cand.getYCenter())
            md.set("BAD", cand.isBad())
            md.set("AMPL", cand.getAmplitude())
            md.set("FLUX", cand.getSource().getPsfInstFlux())
            # chi^2 is a property of the PSF candidate (as used elsewhere in this module),
            # not of its source record
            md.set("CHI2", cand.getChi2())

            im.writeFits(fileName, md, mode)
            mode = "a"

    # Guard: with no candidates there is nothing to show
    if display and im is not None:
        display.mtv(im, title="saveSpatialCellSet: image")