Coverage for examples/utils.py: 0%
301 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 03:39 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 03:39 -0700
1import math
2import os
3import pylab as plt
4import numpy as np
5from matplotlib.patches import Ellipse
7import lsst.afw.math as afwMath
8import lsst.afw.image as afwImage
9import lsst.afw.geom as afwGeom
10import lsst.afw.detection as afwDet
11import lsst.afw.table as afwTable
14# To use multiprocessing, we need the plot elements to be picklable. Swig objects are not
15# picklable, so in preprocessing we pull out the items we need for plotting, putting them in
16# a _MockSource object.
18class _MockSource:
20 def __init__(self, src, mi, psfkey, fluxkey, xkey, ykey, flagKeys, ellipses=True,
21 maskbit=None):
22 # flagKeys: list of (key, string) tuples
23 self.sid = src.getId()
24 aa = {}
25 if maskbit is not None:
26 aa.update(mask=True)
27 self.im = footprintToImage(src.getFootprint(), mi, **aa).getArray()
28 if maskbit is not None:
29 self.im = ((self.im & maskbit) > 0)
31 self.x0 = mi.getX0()
32 self.y0 = mi.getY0()
33 self.ext = getExtent(src.getFootprint().getBBox())
34 self.ispsf = src.get(psfkey)
35 self.psfflux = src.get(fluxkey)
36 self.flags = [nm for key, nm in flagKeys if src.get(key)]
37 # self.cxy = (src.get(xkey), src.get(ykey))
38 self.cx = src.get(xkey)
39 self.cy = src.get(ykey)
40 pks = src.getFootprint().getPeaks()
41 self.pix = [pk.getIx() for pk in pks]
42 self.piy = [pk.getIy() for pk in pks]
43 self.pfx = [pk.getFx() for pk in pks]
44 self.pfy = [pk.getFy() for pk in pks]
45 if ellipses:
46 self.ell = (src.getX(), src.getY(), src.getIxx(), src.getIyy(), src.getIxy())
47 # for getEllipses()
49 def getX(self):
50 return self.ell[0] + 0.5
52 def getY(self):
53 return self.ell[1] + 0.5
55 def getIxx(self):
56 return self.ell[2]
58 def getIyy(self):
59 return self.ell[3]
61 def getIxy(self):
62 return self.ell[4]
def plotDeblendFamily(*args, **kwargs):
    """Preprocess a deblend family and plot it in a single call.

    Positional arguments are those of ``plotDeblendFamilyPre``.  ``kwargs`` are
    forwarded to both stages, so every keyword must also be accepted by
    ``plotDeblendFamilyReal`` (which has no ``**kwargs`` catch-all).
    """
    parent, kids, dkids, sigma1 = plotDeblendFamilyPre(*args, **kwargs)
    plotDeblendFamilyReal(parent, kids, dkids, sigma1, **kwargs)
69# Preprocessing: returns _MockSources for the parent and kids
def plotDeblendFamilyPre(mi, parent, kids, dkids, srcs, sigma1, ellipses=True, maskbit=None, **kwargs):
    """Convert a deblend family into picklable ``_MockSource`` objects.

    Parameters
    ----------
    mi : masked image
        Image the footprints are rendered from.
    parent : source record
        The deblended parent.
    kids, dkids : iterables of source records
        Kept children and dropped children.
    srcs : source catalog
        Catalog whose schema supplies the deblend/centroid/flag field keys.
    sigma1 : float
        Per-pixel noise level; passed through unchanged.
    ellipses, maskbit
        Forwarded to every ``_MockSource``.
    **kwargs
        Ignored.  They are accepted so that ``plotDeblendFamily`` can pass the
        same keywords to this function and to ``plotDeblendFamilyReal``.

    Returns
    -------
    tuple
        ``(parent, kids, dkids, sigma1)`` with sources replaced by
        ``_MockSource`` objects, ready to splat into ``plotDeblendFamilyReal``.
    """
    schema = srcs.getSchema()

    def key(name):
        # One place for schema lookups; field names are case-sensitive.
        return schema.find(name).key

    psfkey = key('deblend_deblendedAsPsf')
    fluxkey = key('deblend_psfFlux')
    xkey = key('base_NaiveCentroid_x')
    # Must use the same 'NaiveCentroid' capitalization as the x field.
    ykey = key('base_NaiveCentroid_y')
    flagKeys = [(key(keynm), nm)
                for nm, keynm in [('EDGE', 'base_PixelFlags_flag_edge'),
                                  ('INTERP', 'base_PixelFlags_flag_interpolated'),
                                  ('INT-C', 'base_PixelFlags_flag_interpolatedCenter'),
                                  ('SAT', 'base_PixelFlags_flag_saturated'),
                                  ('SAT-C', 'base_PixelFlags_flag_saturatedCenter'),
                                  ]]

    def mock(src):
        return _MockSource(src, mi, psfkey, fluxkey, xkey, ykey, flagKeys,
                           ellipses=ellipses, maskbit=maskbit)

    p = mock(parent)
    ch = [mock(kid) for kid in kids]
    dch = [mock(kid) for kid in dkids]
    return (p, ch, dch, sigma1)
92# Real thing: make plots given the _MockSources
def _extentOutline(ext):
    """Return ``(xs, ys)`` vertex lists tracing the closed rectangle of extent *ext*.

    *ext* is an ``(x0, x1, y0, y1)`` tuple like those from ``getExtent``.  The
    first vertex is repeated so that ``plt.plot`` draws a closed box.
    """
    xx = [ext[0], ext[1], ext[1], ext[0], ext[0]]
    yy = [ext[2], ext[2], ext[3], ext[3], ext[2]]
    return xx, yy


def plotDeblendFamilyReal(parent, kids, dkids, sigma1, plotb=False, idmask=None, ellipses=True,
                          arcsinh=True, maskbit=None):
    """Plot a deblend family from the ``_MockSource`` objects of ``plotDeblendFamilyPre``.

    Creates a new 8x8 figure with a near-square grid of panels:
    - Panel 1 shows the parent.  Every child's bounding box, peaks, centroid
      and (optionally) ellipse are overlaid on it, plus the parent's own
      centroid/ellipse in blue.
    - Each following panel shows one child.  PSF-like children are drawn in
      green, the others in red.
    - Dropped children (*dkids*) are outlined in yellow on the parent panel only.

    Parameters
    ----------
    parent : _MockSource
    kids, dkids : list of _MockSource
    sigma1 : float
        Per-pixel noise level; sets the arcsinh softening and default display range.
    plotb : bool
        If True, each child panel gets its own display max and is zoomed to
        the child's bbox.  Otherwise all panels share the parent's range and limits.
    idmask : int or None
        Bitmask ANDed with source ids in panel titles; None keeps all bits.
    ellipses : bool
        Draw moment ellipses for non-PSF children and for the parent.
    arcsinh : bool
        Display pixel values through an arcsinh stretch.
    maskbit : int or None
        Set when the _MockSources hold boolean mask images; the display minimum
        is then fixed at 0.
    """
    if idmask is None:
        idmask = ~0  # all bits set: ids shown unmodified
    pim = parent.im
    pext = parent.ext

    # Near-square grid: one panel for the parent plus one per child.
    N = 1 + len(kids)
    C = math.ceil(math.sqrt(N))
    R = math.ceil(float(N) / C)

    def nlmap(X):
        # arcsinh stretch: ~linear well below 3*sigma1, ~logarithmic above.
        return np.arcsinh(X / (3. * sigma1))

    def myimshow(im, **kwargs):
        # 'arcsinh' is our own option, so it must be removed before imshow sees kwargs.
        stretch = kwargs.pop('arcsinh', True)
        if stretch:
            mn = kwargs.get('vmin', -5 * sigma1)
            kwargs['vmin'] = nlmap(mn)
            mx = kwargs.get('vmax', 100 * sigma1)
            kwargs['vmax'] = nlmap(mx)
            plt.imshow(nlmap(im), **kwargs)
        else:
            plt.imshow(im, **kwargs)

    imargs = dict(interpolation='nearest', origin='lower',
                  vmax=pim.max(), arcsinh=arcsinh)
    # Same test as _MockSource uses to decide whether images are masks.
    if maskbit is not None:
        imargs.update(vmin=0)

    plt.figure(figsize=(8, 8))
    plt.clf()
    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9,
                        wspace=0.05, hspace=0.1)

    # --- parent panel ---
    plt.subplot(R, C, 1)
    myimshow(pim, extent=pext, **imargs)
    plt.gray()
    plt.xticks([])
    plt.yticks([])
    m = 0.25
    pax = [pext[0] - m, pext[1] + m, pext[2] - m, pext[3] + m]
    x, y = parent.pix[0], parent.piy[0]
    tt = 'parent %i @ (%i,%i)' % (parent.sid & idmask,
                                   x - parent.x0, y - parent.y0)
    if len(parent.flags):
        tt += ', ' + ', '.join(parent.flags)
    plt.title(tt)

    # Overlays collected while drawing the children, replayed on the parent panel.
    Rx, Ry = [], []
    stys = []
    xys = []

    # --- one panel per child ---
    for i, kid in enumerate(kids):
        ext = kid.ext
        plt.subplot(R, C, i + 2)
        if plotb:
            ima = imargs.copy()
            ima.update(vmax=max(3. * sigma1, kid.im.max()))
        else:
            ima = imargs
        myimshow(kid.im, extent=ext, **ima)
        plt.gray()
        plt.xticks([])
        plt.yticks([])
        tt = 'child %i' % (kid.sid & idmask)
        if kid.ispsf:
            sty1 = dict(color='g')
            sty2 = dict(color=(0.1, 0.5, 0.1), lw=2, alpha=0.5)
            tt += ' (psf: flux %.1f)' % kid.psfflux
        else:
            sty1 = dict(color='r')
            sty2 = dict(color=(0.8, 0.1, 0.1), lw=2, alpha=0.5)
        if len(kid.flags):
            tt += ', ' + ', '.join(kid.flags)
        stys.append(sty1)
        plt.title(tt)
        # bounding box
        xx, yy = _extentOutline(ext)
        plt.plot(xx, yy, '-', **sty1)
        Rx.append(xx)
        Ry.append(yy)
        # peak(s)
        plt.plot(kid.pfx, kid.pfy, 'x', **sty2)
        xys.append((kid.pfx, kid.pfy, sty2))
        # centroid
        plt.plot([kid.cx], [kid.cy], 'x', **sty1)
        xys.append(([kid.cx], [kid.cy], sty1))
        # ellipse
        if ellipses and not kid.ispsf:
            drawEllipses(kid, ec=sty1['color'], fc='none', alpha=0.7)
        if plotb:
            plt.axis(ext)
        else:
            plt.axis(pax)

    # --- back to the parent panel: child bboxes, centres and ellipses ---
    plt.subplot(R, C, 1)
    for rx, ry, sty in zip(Rx, Ry, stys):
        plt.plot(rx, ry, '-', **sty)
    for x, y, sty in xys:
        plt.plot(x, y, 'x', **sty)
    if ellipses:
        for kid, sty in zip(kids, stys):
            if kid.ispsf:
                continue
            drawEllipses(kid, ec=sty['color'], fc='none', alpha=0.7)
    plt.plot([parent.cx], [parent.cy], 'x', color='b')
    if ellipses:
        drawEllipses(parent, ec='b', fc='none', alpha=0.7)

    # Dropped kids: yellow bbox and peaks on the parent panel.
    for kid in dkids:
        xx, yy = _extentOutline(kid.ext)
        plt.plot(xx, yy, 'y-')
        plt.plot(kid.pfx, kid.pfy, 'yx')
    # Always reset the parent limits (even with no dropped kids) so the parent
    # panel matches the shared child-panel limits.
    plt.axis(pax)
def footprintToImage(fp, mi=None, mask=False):
    """Render footprint *fp* into a new image covering exactly its bounding box.

    If *fp* is not already heavy, a heavy footprint is first built from *mi*
    (so *mi* is required in that case).  The new image's XY0 is set to the
    bbox's minimum corner before the footprint is inserted.

    With ``mask=True`` a ``MaskedImageF`` is filled and only its mask plane is
    returned.  Otherwise the filled ``ImageF`` is returned.
    """
    heavy = fp if fp.isHeavy() else afwDet.makeHeavyFootprint(fp, mi)
    box = heavy.getBBox()
    imageClass = afwImage.MaskedImageF if mask else afwImage.ImageF
    out = imageClass(box.getWidth(), box.getHeight())
    out.setXY0(box.getMinX(), box.getMinY())
    heavy.insert(out)
    return out.getMask() if mask else out
def getFamilies(cat):
    '''
    Group the children of *cat* by parent.

    Returns [ (parent0, children0), (parent1, children1), ...], ordered by
    parent id.  Each parent is looked up with ``cat.find``.  Sources whose
    parent id is 0/falsy are treated as top-level and are not listed as
    children.  Children keep their catalog order.
    '''
    kidsByParent = {}
    for src in cat:
        parentId = src.getParent()
        if parentId:
            kidsByParent.setdefault(parentId, []).append(src)
    return [(cat.find(parentId), kidsByParent[parentId])
            for parentId in sorted(kidsByParent)]
def getExtent(bb, addHigh=1):
    """Return bounding box *bb* as an imshow-style ``(x0, x1, y0, y1)`` extent.

    *addHigh* is added to the maximum x and y.  With the default of 1 this
    presumably turns inclusive max pixel indices into the outer pixel edge,
    so the extent covers the last row/column in full -- TODO confirm.
    """
    xLo, xHi = bb.getMinX(), bb.getMaxX() + addHigh
    yLo, yHi = bb.getMinY(), bb.getMaxY() + addHigh
    return (xLo, xHi, yLo, yHi)
def cutCatalog(cat, ndeblends, keepids=None, keepxys=None):
    """Return a new catalog holding only selected deblend families.

    Families (parent + children, from ``getFamilies``) are filtered in turn:
      1. keep those whose parent id is in *keepids*;
      2. keep those whose parent footprint contains any of the integer
         ``(x, y)`` points in *keepxys*;
      3. keep only the first *ndeblends* families.
    A falsy argument skips its filter.  The surviving parents and children are
    appended to a new ``SourceCatalog`` sharing *cat*'s table, which is sorted
    and returned.
    """
    families = getFamilies(cat)

    if keepids:
        families = [fam for fam in families if fam[0].getId() in keepids]

    if keepxys:
        points = [afwGeom.Point2I(x, y) for x, y in keepxys]
        families = [(parent, children) for parent, children in families
                    if any(parent.getFootprint().contains(pt) for pt in points)]

    if ndeblends:
        families = families[:ndeblends]

    out = afwTable.SourceCatalog(cat.getTable())
    for parent, children in families:
        for rec in [parent] + list(children):
            out.append(rec)
    out.sort()
    return out
def readCatalog(sourcefn, heavypat, ndeblends=0, dataref=None,
                keepids=None, keepxys=None,
                patargs=None):
    """Load a source catalog, optionally cut it, and attach heavy footprints.

    Parameters
    ----------
    sourcefn : str or None
        FITS file read with ``afwTable.SourceCatalog.readFits``.  If None, the
        catalog is fetched with ``dataref.get('src')`` instead.
    heavypat : str or None
        %-style pattern naming one heavy-footprint image file per child.  It is
        filled from *patargs* plus ``id=<source id>``.  If None, footprints are
        left untouched.
    ndeblends, keepids, keepxys
        Passed to ``cutCatalog`` when any of them is set.
    dataref : optional
        Data reference used when *sourcefn* is None.
    patargs : dict or None
        Extra substitution values for *heavypat*.  It is copied, never
        modified.  Defaults to no extra values.

    Returns
    -------
    The catalog.  Returns None if the catalog file is missing, the dataref
    catalog is empty, or any needed heavy-footprint file is missing.
    """
    if patargs is None:
        patargs = {}

    if sourcefn is None:
        cat = dataref.get('src')
        # Truth-testing the returned object may itself raise; treat that as
        # "no catalog" rather than propagating.
        try:
            if not cat:
                return None
        except Exception:
            return None
    else:
        if not os.path.exists(sourcefn):
            print('No source catalog:', sourcefn)
            return None
        print('Reading catalog:', sourcefn)
        cat = afwTable.SourceCatalog.readFits(sourcefn)
        print(len(cat), 'sources')
    # cutCatalog() looks parents up with cat.find(), which presumably needs
    # the catalog sorted by id -- so sort regardless of where it came from.
    cat.sort()
    cat.defineCentroid('base_SdssCentroid')

    if ndeblends or keepids or keepxys:
        cat = cutCatalog(cat, ndeblends, keepids, keepxys)
        print('Cut to', len(cat), 'sources')

    if heavypat is not None:
        print('Reading heavyFootprints...')
        for src in cat:
            # Heavy footprints are only read for children (non-zero parent id).
            if not src.getParent():
                continue
            dd = patargs.copy()
            dd.update(id=src.getId())
            heavyfn = heavypat % dd
            if not os.path.exists(heavyfn):
                print('No heavy footprint:', heavyfn)
                return None
            mim = afwImage.MaskedImageF(heavyfn)
            heavy = afwDet.makeHeavyFootprint(src.getFootprint(), mim)
            src.setFootprint(heavy)
    return cat
def datarefToMapper(dr):
    """Return the mapper of the butler behind data reference *dr*."""
    butler = dr.butlerSubset.butler
    return butler.mapper
def datarefToButler(dr):
    """Return the butler behind data reference *dr*."""
    subset = dr.butlerSubset
    return subset.butler
class WrapperMapper:
    """Proxy around a butler mapper that relays calls and prints ``map`` traffic.

    Every ``bypass_*`` attribute of the wrapped mapper is exposed on the
    wrapper as a ``RelayBypass`` callable.  A fixed set of other mapper methods
    is relayed explicitly.  ``map`` also prints its arguments and result.
    All relays forward both positional and keyword arguments.
    """

    class RelayBypass:
        """Callable forwarding to ``getattr(real, attr)``, looked up once at construction."""

        def __init__(self, real, attr):
            self.func = getattr(real, attr)
            self.attr = attr

        def __call__(self, *args, **kwargs):
            return self.func(*args, **kwargs)

    def __init__(self, real):
        self.real = real
        for name in dir(real):
            if name.startswith('bypass_'):
                setattr(self, name, self.RelayBypass(self.real, name))

    def map(self, *args, **kwargs):
        print('Mapping', args, kwargs)
        R = self.real.map(*args, **kwargs)
        print('->', R)
        return R

    # --- plain relays to the wrapped mapper ---

    def isAggregate(self, *args, **kwargs):
        return self.real.isAggregate(*args, **kwargs)

    def getKeys(self, *args, **kwargs):
        return self.real.getKeys(*args, **kwargs)

    def getDatasetTypes(self, *args, **kwargs):
        return self.real.getDatasetTypes(*args, **kwargs)

    def queryMetadata(self, *args, **kwargs):
        return self.real.queryMetadata(*args, **kwargs)

    def canStandardize(self, *args, **kwargs):
        return self.real.canStandardize(*args, **kwargs)

    def standardize(self, *args, **kwargs):
        return self.real.standardize(*args, **kwargs)

    def validate(self, *args, **kwargs):
        return self.real.validate(*args, **kwargs)

    def getDefaultLevel(self, *args, **kwargs):
        return self.real.getDefaultLevel(*args, **kwargs)
def getEllipses(src, nsigs=(1.,), **kwargs):
    """Build matplotlib ``Ellipse`` patches from a source's second moments.

    Parameters
    ----------
    src : object
        Must provide ``getX``/``getY`` (centre) and ``getIxx``/``getIyy``/
        ``getIxy`` (second moments), e.g. a ``_MockSource``.
    nsigs : sequence of float
        One ellipse is made per entry, with both semi-axes scaled by it.
        A tuple default avoids a shared mutable default list.
    **kwargs
        Passed to ``matplotlib.patches.Ellipse`` (e.g. ``ec``, ``fc``, ``alpha``).

    Returns
    -------
    list of Ellipse
        The patches, not yet added to any axes.
    """
    xc = src.getX()
    yc = src.getY()
    x2 = src.getIxx()
    y2 = src.getIyy()
    xy = src.getIxy()
    # SExtractor manual v2.5, pg 29: a2/b2 are the eigenvalues of the moment
    # matrix (squared semi-major/minor axes).  theta is the major-axis
    # position angle in degrees.
    mean = (x2 + y2) / 2.
    half = np.sqrt(((x2 - y2) / 2.)**2 + xy**2)
    a2 = mean + half
    b2 = mean - half
    theta = np.rad2deg(np.arctan2(2. * xy, (x2 - y2)) / 2.)
    a = np.sqrt(a2)
    b = np.sqrt(b2)
    return [Ellipse([xc, yc], 2. * a * nsig, 2. * b * nsig, angle=theta, **kwargs)
            for nsig in nsigs]
def drawEllipses(src, **kwargs):
    """Add the ``getEllipses(src, **kwargs)`` patches to the current axes.

    The current axes are looked up for each patch.  The list of patches is
    returned so callers can adjust them afterwards.
    """
    patches = getEllipses(src, **kwargs)
    for patch in patches:
        ax = plt.gca()
        ax.add_artist(patch)
    return patches
def get_sigma1(mi):
    """Return a per-pixel noise estimate for *mi*: sqrt of the median variance.

    The median is computed with ``afwMath.makeStatistics`` over the variance
    plane, with the image's mask passed along.  How masked pixels are handled
    depends on afwMath defaults not shown here.
    """
    medianVariance = afwMath.makeStatistics(
        mi.getVariance(), mi.getMask(), afwMath.MEDIAN).getValue(afwMath.MEDIAN)
    return math.sqrt(medianVariance)