Coverage for examples/utils.py: 0%

301 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-22 10:45 +0000

1import math 

2import os 

3import pylab as plt 

4import numpy as np 

5from matplotlib.patches import Ellipse 

6 

7import lsst.afw.math as afwMath 

8import lsst.afw.image as afwImage 

9import lsst.afw.geom as afwGeom 

10import lsst.afw.detection as afwDet 

11import lsst.afw.table as afwTable 

12 

13 

# Plot elements must be picklable so the plotting can run under
# multiprocessing.  SWIG-wrapped objects cannot be pickled, so the
# preprocessing step copies just the values each plot needs into a
# plain-Python _MockSource instead.
class _MockSource:
    """Picklable snapshot of the source-record values used for plotting.

    Attributes filled in by ``__init__``:
      sid             source id (``src.getId()``)
      im              footprint pixels, ``footprintToImage(...).getArray()``;
                      a boolean "bit set" array when ``maskbit`` is given
      x0, y0          XY0 of ``mi``
      ext             footprint bbox as (x0, x1, y0, y1) from ``getExtent``
      ispsf, psfflux  values of the ``psfkey`` / ``fluxkey`` fields
      flags           names from ``flagKeys`` whose field is set on ``src``
      cx, cy          values of the ``xkey`` / ``ykey`` fields
      pix, piy        integer peak positions (``getIx`` / ``getIy``)
      pfx, pfy        floating peak positions (``getFx`` / ``getFy``)
      ell             (x, y, Ixx, Iyy, Ixy), only set when ``ellipses`` is true
    """

    def __init__(self, src, mi, psfkey, fluxkey, xkey, ykey, flagKeys, ellipses=True,
                 maskbit=None):
        # flagKeys is a list of (key, name) tuples.
        self.sid = src.getId()
        useMask = maskbit is not None
        footprint = src.getFootprint()
        pixels = footprintToImage(footprint, mi, mask=useMask).getArray()
        # In mask mode keep only whether the requested bit(s) are set per pixel.
        self.im = ((pixels & maskbit) > 0) if useMask else pixels

        self.x0, self.y0 = mi.getX0(), mi.getY0()
        self.ext = getExtent(footprint.getBBox())
        self.ispsf = src.get(psfkey)
        self.psfflux = src.get(fluxkey)
        self.flags = [name for key, name in flagKeys if src.get(key)]
        self.cx = src.get(xkey)
        self.cy = src.get(ykey)

        peaks = footprint.getPeaks()
        self.pix = [peak.getIx() for peak in peaks]
        self.piy = [peak.getIy() for peak in peaks]
        self.pfx = [peak.getFx() for peak in peaks]
        self.pfy = [peak.getFy() for peak in peaks]

        if ellipses:
            # Moments read back by getEllipses() through the accessors below.
            self.ell = (src.getX(), src.getY(), src.getIxx(), src.getIyy(), src.getIxy())

    # Accessors mirroring the source-record API that getEllipses() calls.
    # The centre gets +0.5, presumably to move from pixel index to pixel
    # centre in the imshow extent built by getExtent(addHigh=1) -- confirm.

    def getX(self):
        """Ellipse centre x, shifted by half a pixel."""
        x = self.ell[0]
        return x + 0.5

    def getY(self):
        """Ellipse centre y, shifted by half a pixel."""
        y = self.ell[1]
        return y + 0.5

    def getIxx(self):
        """Second moment Ixx."""
        ixx = self.ell[2]
        return ixx

    def getIyy(self):
        """Second moment Iyy."""
        iyy = self.ell[3]
        return iyy

    def getIxy(self):
        """Cross moment Ixy."""
        ixy = self.ell[4]
        return ixy

63 

64 

def plotDeblendFamily(*args, **kwargs):
    """Preprocess a deblend family and plot it in a single call.

    Positional and keyword arguments go to plotDeblendFamilyPre.  Its
    returned tuple, together with the same keyword arguments, is then passed
    to plotDeblendFamilyReal.  Every keyword given here must therefore be
    accepted by plotDeblendFamilyReal.
    """
    prepared = plotDeblendFamilyPre(*args, **kwargs)
    plotDeblendFamilyReal(*prepared, **kwargs)

68 

# Preprocessing: copy the parent and children into picklable _MockSources.
def plotDeblendFamilyPre(mi, parent, kids, dkids, srcs, sigma1, ellipses=True, maskbit=None, **kwargs):
    """Build picklable _MockSource snapshots for one deblend family.

    Parameters
    ----------
    mi
        Image passed to _MockSource (and on to footprintToImage) for pixels.
    parent, kids, dkids
        Parent record, kept children and dropped children.
    srcs
        Catalog whose schema supplies the field keys looked up below.
    sigma1
        Noise scale; returned unchanged for plotDeblendFamilyReal.
    ellipses, maskbit
        Forwarded to every _MockSource.
    **kwargs
        Ignored here; accepted so plotDeblendFamily can pass plotting options
        to both stages.

    Returns
    -------
    tuple
        ``(parent_mock, [kid_mocks], [dropped_kid_mocks], sigma1)``, matching
        plotDeblendFamilyReal's leading positional parameters.
    """
    schema = srcs.getSchema()
    psfkey = schema.find('deblend_deblendedAsPsf').key
    fluxkey = schema.find('deblend_psfFlux').key
    xkey = schema.find('base_NaiveCentroid_x').key
    # Must use the same 'NaiveCentroid' spelling as the x field above.
    ykey = schema.find('base_NaiveCentroid_y').key
    flagKeys = [(schema.find(keynm).key, nm)
                for nm, keynm in [('EDGE', 'base_PixelFlags_flag_edge'),
                                  ('INTERP', 'base_PixelFlags_flag_interpolated'),
                                  ('INT-C', 'base_PixelFlags_flag_interpolatedCenter'),
                                  ('SAT', 'base_PixelFlags_flag_saturated'),
                                  ('SAT-C', 'base_PixelFlags_flag_saturatedCenter'),
                                  ]]

    def mock(src):
        # Every family member is snapshotted with identical settings.
        return _MockSource(src, mi, psfkey, fluxkey, xkey, ykey, flagKeys,
                           ellipses=ellipses, maskbit=maskbit)

    p = mock(parent)
    ch = [mock(kid) for kid in kids]
    dch = [mock(kid) for kid in dkids]
    return (p, ch, dch, sigma1)

91 

# Plotting: draw the family from the _MockSources made by plotDeblendFamilyPre.
def plotDeblendFamilyReal(parent, kids, dkids, sigma1, plotb=False, idmask=None, ellipses=True,
                          arcsinh=True, maskbit=None):
    """Plot a deblend family on a new 8x8-inch figure.

    Panel layout:

    - Panel 1 shows the parent image.  Every child's bbox, peaks, centroid
      and (optionally) moment ellipse are overlaid on it, plus dropped
      children's bboxes and peaks in yellow.
    - Each following panel shows one kept child.

    Parameters
    ----------
    parent, kids, dkids
        _MockSource-like objects: the parent, kept children, dropped children.
    sigma1 : float
        Noise scale used by the arcsinh stretch and its default limits.
    plotb : bool
        If true, each child panel gets its own vmax and is zoomed to the
        child's bbox.  Otherwise all panels share the parent's vmax and limits.
    idmask : int or None
        ANDed with source ids before they are shown in titles; None keeps
        every bit.
    ellipses : bool
        Draw moment ellipses for non-PSF children and for the parent.
    arcsinh : bool
        Display ``arcsinh(pixel / (3 * sigma1))`` instead of raw pixels.
    maskbit : int or None
        When not None, vmin is fixed at 0.  Presumably this is because
        plotDeblendFamilyPre then produced boolean mask images.
    """
    if idmask is None:
        idmask = ~0
    pim = parent.im
    pext = parent.ext

    # Near-square grid: one panel for the parent plus one per child.
    N = 1 + len(kids)
    C = math.ceil(math.sqrt(N))
    R = math.ceil(N / C)

    def nlmap(X):
        return np.arcsinh(X / (3.*sigma1))

    def myimshow(im, **kwargs):
        # 'arcsinh' is this helper's own option, not an imshow one; strip it.
        useArcsinh = kwargs.pop('arcsinh', True)
        if useArcsinh:
            # Send the limits (default -5 and 100 sigma) through the same
            # stretch as the pixels so the colour scale stays consistent.
            kwargs['vmin'] = nlmap(kwargs.get('vmin', -5*sigma1))
            kwargs['vmax'] = nlmap(kwargs.get('vmax', 100*sigma1))
            plt.imshow(nlmap(im), **kwargs)
        else:
            plt.imshow(im, **kwargs)

    def bboxOutline(ext):
        # Closed rectangle through the corners of an (x0, x1, y0, y1) extent.
        return ([ext[0], ext[1], ext[1], ext[0], ext[0]],
                [ext[2], ext[2], ext[3], ext[3], ext[2]])

    imargs = dict(interpolation='nearest', origin='lower',
                  vmax=pim.max(), arcsinh=arcsinh)
    # Same "is not None" test _MockSource uses to decide on mask mode.
    if maskbit is not None:
        imargs.update(vmin=0)

    plt.figure(figsize=(8, 8))
    plt.clf()
    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9,
                        wspace=0.05, hspace=0.1)

    # Parent panel.
    plt.subplot(R, C, 1)
    myimshow(pim, extent=pext, **imargs)
    plt.gray()
    plt.xticks([])
    plt.yticks([])
    m = 0.25
    pax = [pext[0]-m, pext[1]+m, pext[2]-m, pext[3]+m]
    # Title shows the first peak's position relative to the image XY0.
    x, y = parent.pix[0], parent.piy[0]
    tt = 'parent %i @ (%i,%i)' % (parent.sid & idmask,
                                  x - parent.x0, y - parent.y0)
    if parent.flags:
        tt += ', ' + ', '.join(parent.flags)
    plt.title(tt)

    # Child panels; collect overlays to repeat on the parent panel afterwards.
    Rx, Ry = [], []
    stys = []
    xys = []
    for i, kid in enumerate(kids):
        ext = kid.ext
        plt.subplot(R, C, i+2)
        if plotb:
            ima = imargs.copy()
            ima.update(vmax=max(3.*sigma1, kid.im.max()))
        else:
            ima = imargs

        myimshow(kid.im, extent=ext, **ima)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        tt = 'child %i' % (kid.sid & idmask)
        if kid.ispsf:
            # Green for children deblended as point sources, red otherwise.
            sty1 = dict(color='g')
            sty2 = dict(color=(0.1, 0.5, 0.1), lw=2, alpha=0.5)
            tt += ' (psf: flux %.1f)' % kid.psfflux
        else:
            sty1 = dict(color='r')
            sty2 = dict(color=(0.8, 0.1, 0.1), lw=2, alpha=0.5)
        if kid.flags:
            tt += ', ' + ', '.join(kid.flags)

        stys.append(sty1)
        plt.title(tt)
        # bounding box
        xx, yy = bboxOutline(ext)
        plt.plot(xx, yy, '-', **sty1)
        Rx.append(xx)
        Ry.append(yy)
        # peak(s)
        plt.plot(kid.pfx, kid.pfy, 'x', **sty2)
        xys.append((kid.pfx, kid.pfy, sty2))
        # centroid
        plt.plot([kid.cx], [kid.cy], 'x', **sty1)
        xys.append(([kid.cx], [kid.cy], sty1))
        # ellipse
        if ellipses and not kid.ispsf:
            drawEllipses(kid, ec=sty1['color'], fc='none', alpha=0.7)
        plt.axis(ext if plotb else pax)

    # Back on the parent panel: child bboxes, peaks, centroids and ellipses.
    plt.subplot(R, C, 1)
    for rx, ry, sty in zip(Rx, Ry, stys):
        plt.plot(rx, ry, '-', **sty)
    for x, y, sty in xys:
        plt.plot(x, y, 'x', **sty)
    if ellipses:
        for kid, sty in zip(kids, stys):
            if not kid.ispsf:
                drawEllipses(kid, ec=sty['color'], fc='none', alpha=0.7)
    plt.plot([parent.cx], [parent.cy], 'x', color='b')
    if ellipses:
        drawEllipses(parent, ec='b', fc='none', alpha=0.7)

    # Dropped children: yellow bbox and peaks on the parent panel.
    for kid in dkids:
        xx, yy = bboxOutline(kid.ext)
        plt.plot(xx, yy, 'y-')
        plt.plot(kid.pfx, kid.pfy, 'yx')
    plt.axis(pax)

222 

223 

def footprintToImage(fp, mi=None, mask=False):
    """Render a footprint's pixels into a new image covering its bounding box.

    Parameters
    ----------
    fp
        Footprint to render.  If ``fp.isHeavy()`` is false, it is first made
        heavy with ``afwDet.makeHeavyFootprint(fp, mi)``.
    mi
        Pixel source for that conversion.  It is only used when ``fp`` is
        not heavy; presumably it must then be non-None (confirm).
    mask : bool
        If true, render into an ``afwImage.MaskedImageF`` and return only its
        mask plane.  Otherwise render into an ``afwImage.ImageF``.

    Returns
    -------
    The rendered image (or mask), with XY0 set to the bbox minimum corner.
    """
    heavy = fp if fp.isHeavy() else afwDet.makeHeavyFootprint(fp, mi)
    box = heavy.getBBox()
    ImageClass = afwImage.MaskedImageF if mask else afwImage.ImageF
    canvas = ImageClass(box.getWidth(), box.getHeight())
    canvas.setXY0(box.getMinX(), box.getMinY())
    heavy.insert(canvas)
    return canvas.getMask() if mask else canvas

237 

238 

def getFamilies(cat):
    '''
    Group the child sources of ``cat`` under their parents.

    Sources whose ``getParent()`` is falsy (0) are treated as having no
    parent and are not listed as children.  Families are ordered by parent
    id.  Each parent is looked up with ``cat.find(pid)``.  Presumably ``cat``
    must be sorted by id for that lookup (readCatalog sorts it) -- confirm.

    Returns [ (parent0, children0), (parent1, children1), ...]
    '''
    byParent = {}
    for src in cat:
        pid = src.getParent()
        if pid:
            byParent.setdefault(pid, []).append(src)
    return [(cat.find(pid), byParent[pid]) for pid in sorted(byParent)]

255 

256 

def getExtent(bb, addHigh=1):
    """Return ``(xmin, xmax, ymin, ymax)`` of ``bb`` in imshow ``extent`` order.

    ``addHigh`` is added to both maximum coordinates.  Presumably the default
    of 1 turns bb's inclusive max pixel index into the far pixel edge --
    confirm against the box type.
    """
    xlo, xhi = bb.getMinX(), bb.getMaxX() + addHigh
    ylo, yhi = bb.getMinY(), bb.getMaxY() + addHigh
    return (xlo, xhi, ylo, yhi)

260 

261 

def cutCatalog(cat, ndeblends, keepids=None, keepxys=None):
    """Return a new catalog holding only selected deblend families.

    Families come from getFamilies(cat).  The filters below are applied in
    order, each only when its argument is truthy.

    Parameters
    ----------
    cat
        Source catalog; the result is built on ``cat.getTable()``.
    ndeblends : int
        Keep only the first ``ndeblends`` families remaining after the
        other filters.
    keepids : iterable of ids, optional
        Keep families whose parent id is in this collection.
    keepxys : iterable of (x, y), optional
        Keep families whose parent footprint contains at least one of these
        integer points.

    Returns
    -------
    A sorted SourceCatalog with each kept parent followed by its children.
    """
    fams = getFamilies(cat)
    if keepids:
        # Build the lookup set once rather than scanning a list per family.
        wanted = set(keepids)
        fams = [(p, kids) for (p, kids) in fams if p.getId() in wanted]
    if keepxys:
        pts = [afwGeom.Point2I(x, y) for x, y in keepxys]
        fams = [(p, kids) for (p, kids) in fams
                if any(p.getFootprint().contains(pt) for pt in pts)]

    if ndeblends:
        # Select the first "ndeblends" parents and all of their children.
        fams = fams[:ndeblends]

    keepcat = afwTable.SourceCatalog(cat.getTable())
    for p, kids in fams:
        keepcat.append(p)
        for k in kids:
            keepcat.append(k)
    keepcat.sort()
    return keepcat

289 

290 

def readCatalog(sourcefn, heavypat, ndeblends=0, dataref=None,
                keepids=None, keepxys=None,
                patargs=None):
    """Load a source catalog, optionally cut it, and attach heavy footprints.

    Parameters
    ----------
    sourcefn : str or None
        FITS file read with ``afwTable.SourceCatalog.readFits``.  If None,
        the catalog is fetched with ``dataref.get('src')`` instead.
    heavypat : str or None
        %-format pattern naming a per-source heavy-footprint file.  It is
        formatted with a dict of ``patargs`` plus ``id`` (the source id).
        Only sources with a parent are processed.  If None, footprints are
        left untouched.
    ndeblends, keepids, keepxys
        Passed to cutCatalog when any of them is truthy.
    dataref
        Only used when ``sourcefn`` is None.
    patargs : dict, optional
        Extra substitution keys for ``heavypat``; it is never modified.

    Returns
    -------
    The sorted catalog.  Returns None instead in any of these cases:

    - the dataref catalog is empty or its truth value cannot be evaluated;
    - ``sourcefn`` does not exist;
    - any required heavy-footprint file is missing.
    """
    if patargs is None:
        patargs = {}
    if sourcefn is None:
        cat = dataref.get('src')
        # Best effort: treat a catalog whose truth value cannot even be
        # evaluated the same as an empty one.
        try:
            empty = not cat
        except Exception:
            return None
        if empty:
            return None
    else:
        if not os.path.exists(sourcefn):
            print('No source catalog:', sourcefn)
            return None
        print('Reading catalog:', sourcefn)
        cat = afwTable.SourceCatalog.readFits(sourcefn)
        print(len(cat), 'sources')
    cat.sort()
    cat.defineCentroid('base_SdssCentroid')

    if ndeblends or keepids or keepxys:
        cat = cutCatalog(cat, ndeblends, keepids, keepxys)
        print('Cut to', len(cat), 'sources')

    if heavypat is not None:
        print('Reading heavyFootprints...')
        for src in cat:
            if not src.getParent():
                continue
            heavyfn = heavypat % dict(patargs, id=src.getId())
            if not os.path.exists(heavyfn):
                print('No heavy footprint:', heavyfn)
                return None
            mim = afwImage.MaskedImageF(heavyfn)
            heavy = afwDet.makeHeavyFootprint(src.getFootprint(), mim)
            src.setFootprint(heavy)
    return cat

330 

331 

def datarefToMapper(dr):
    """Return the mapper of the butler behind data reference ``dr``."""
    butler = dr.butlerSubset.butler
    return butler.mapper

334 

335 

def datarefToButler(dr):
    """Return the butler reached through ``dr.butlerSubset``."""
    subset = dr.butlerSubset
    return subset.butler

338 

339 

class WrapperMapper:
    """Debugging proxy around a butler mapper.

    - ``map`` calls print their arguments and result, then are relayed.
    - Every ``bypass_*`` attribute found on the wrapped mapper is relayed.
    - A fixed set of other mapper methods is forwarded unchanged, with
      positional and keyword arguments.
    """

    class _RelayBypass:
        """Callable forwarding to ``getattr(real, attr)``.

        Defined once at class level instead of being re-created for every
        attribute inside ``__init__``.
        """

        def __init__(self, real, attr):
            self.func = getattr(real, attr)
            self.attr = attr

        def __call__(self, *args, **kwargs):
            return self.func(*args, **kwargs)

    def __init__(self, real):
        self.real = real
        for name in dir(real):
            if name.startswith('bypass_'):
                setattr(self, name, self._RelayBypass(real, name))

    def map(self, *args, **kwargs):
        print('Mapping', args, kwargs)
        R = self.real.map(*args, **kwargs)
        print('->', R)
        return R

    # Plain relays to the wrapped mapper.

    def isAggregate(self, *args, **kwargs):
        return self.real.isAggregate(*args, **kwargs)

    def getKeys(self, *args, **kwargs):
        return self.real.getKeys(*args, **kwargs)

    def getDatasetTypes(self):
        return self.real.getDatasetTypes()

    def queryMetadata(self, *args, **kwargs):
        return self.real.queryMetadata(*args, **kwargs)

    def canStandardize(self, *args, **kwargs):
        return self.real.canStandardize(*args, **kwargs)

    def standardize(self, *args, **kwargs):
        return self.real.standardize(*args, **kwargs)

    def validate(self, *args, **kwargs):
        return self.real.validate(*args, **kwargs)

    def getDefaultLevel(self, *args, **kwargs):
        return self.real.getDefaultLevel(*args, **kwargs)

391 

392 

def getEllipses(src, nsigs=(1.,), **kwargs):
    """Build matplotlib Ellipse patches from a source's second moments.

    Parameters
    ----------
    src
        Object with ``getX``, ``getY``, ``getIxx``, ``getIyy``, ``getIxy``.
    nsigs : sequence of float
        One ellipse is built per entry.  Its full width and height are
        ``2 * nsig`` times the semi-axes derived from the moments.
    **kwargs
        Passed to every ``Ellipse`` (e.g. ``ec``, ``fc``, ``alpha``).

    Returns
    -------
    list of Ellipse
        Centred on ``(src.getX(), src.getY())``, rotated by the moments'
        position angle in degrees.
    """
    xc = src.getX()
    yc = src.getY()
    x2 = src.getIxx()
    y2 = src.getIyy()
    xy = src.getIxy()
    # SExtractor manual v2.5, pg 29.
    half_sum = (x2 + y2) / 2.
    root = np.sqrt(((x2 - y2) / 2.)**2 + xy**2)
    a2 = half_sum + root
    b2 = half_sum - root
    theta = np.rad2deg(np.arctan2(2.*xy, (x2 - y2)) / 2.)
    # Round-off can leave a squared axis slightly below zero (e.g. a
    # near-degenerate minor axis).  Clamp it so sqrt gives 0 rather than NaN
    # plus a RuntimeWarning.
    a = np.sqrt(np.maximum(a2, 0.))
    b = np.sqrt(np.maximum(b2, 0.))
    return [Ellipse([xc, yc], 2.*a*nsig, 2.*b*nsig, angle=theta, **kwargs)
            for nsig in nsigs]

409 

410 

def drawEllipses(src, **kwargs):
    """Add the patches from ``getEllipses(src, **kwargs)`` to the current axes.

    Returns the list of added patches.
    """
    patches = getEllipses(src, **kwargs)
    if patches:
        axes = plt.gca()
        for patch in patches:
            axes.add_artist(patch)
    return patches

416 

417 

def get_sigma1(mi):
    """Return the square root of the median of ``mi``'s variance plane.

    The median comes from ``afwMath.makeStatistics``, which is also given
    ``mi``'s mask.  Presumably this excludes masked pixels -- confirm.
    """
    medianVariance = afwMath.makeStatistics(
        mi.getVariance(), mi.getMask(), afwMath.MEDIAN
    ).getValue(afwMath.MEDIAN)
    return math.sqrt(medianVariance)