Coverage for python / lsst / obs / subaru / crosstalk.py: 0%
336 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:23 +0000
1#
2# LSST Data Management System
3# Copyright 2008-2016 AURA/LSST.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <https://www.lsstcorp.org/LegalNotices/>.
21#
22"""
23Determine and apply crosstalk corrections
25N.b. This code was written and tested for the 4-amplifier Hamamatsu chips used
26in (Hyper)?SuprimeCam, and will need to be generalised to handle other
27amplifier layouts. I don't want to do this until we have an example.
29N.b. To estimate crosstalk from the SuprimeCam data, the commands are e.g.:
31.. code-block:: python
33 import crosstalk
34 coeffs, coeffsErr = crosstalk.estimateCoeffs(butler, range(131634, 131642),
35 range(10), threshold=1e5,
36 plot=True, title="CCD0..9",
37 fig=1)
38 crosstalk.fixCcd(butler, 131634, 0, coeffs)
39"""
40import sys
41import math
42import time
44import numpy as np
46import lsst.afw.detection as afwDetect
47import lsst.afw.image as afwImage
48import lsst.afw.math as afwMath
49import lsst.geom as geom
50import lsst.pex.config as pexConfig
51import lsst.afw.display as afwDisplay
52from functools import reduce
53from lsst.ip.isr.crosstalk import CrosstalkTask
class CrosstalkCoeffsConfig(pexConfig.Config):
    """Crosstalk coefficients for a single CCD.

    The coefficients are stored flattened in ``values``; ``shape`` gives the
    dimensions used by `getCoeffs` to rebuild the matrix.
    """

    values = pexConfig.ListField(
        dtype=float,
        doc="Crosstalk coefficients",
        default=[0]*16,  # a 4x4 matrix of zeros
    )
    shape = pexConfig.ListField(
        dtype=int,
        doc="Shape of coeffs array",
        default=[4, 4],
        minLength=1,  # really 2, but there's a bug in pex_config
        maxLength=2,
    )

    def getCoeffs(self):
        """Return the coefficients as a numpy array reshaped to ``self.shape``."""
        flat = np.array(self.values)
        return flat.reshape(self.shape)
class SubaruCrosstalkConfig(CrosstalkTask.ConfigClass):
    """Configuration for `SubaruCrosstalkTask`.

    Adds the masking threshold, the name of the mask plane used to flag
    crosstalk-corrected pixels, and the coefficient matrix.
    """

    minPixelToMask = pexConfig.Field(
        dtype=float,
        default=45000,
        doc="Set crosstalk mask plane for pixels over this value",
    )
    crosstalkMaskPlane = pexConfig.Field(
        dtype=str,
        default="CROSSTALK",
        doc="Name for crosstalk mask plane",
    )
    coeffs = pexConfig.ConfigField(
        dtype=CrosstalkCoeffsConfig,
        doc="Crosstalk coefficients",
    )
class SubaruCrosstalkTask(CrosstalkTask):
    """Crosstalk correction using the coefficient matrix from the task config."""

    ConfigClass = SubaruCrosstalkConfig

    def run(self, exp, **kwargs):
        """Subtract crosstalk from ``exp``'s masked image in place.

        Any extra keyword arguments are accepted and ignored.
        """
        self.log.info("Applying crosstalk correction/Subaru")
        cfg = self.config
        subtractXTalk(exp.getMaskedImage(), cfg.coeffs.getCoeffs(),
                      cfg.minPixelToMask, cfg.crosstalkMaskPlane)
# Number of amplifiers per CCD; getXPos only handles the 4-amplifier layout
nAmp: int = 4
def getXPos(width, hwidth, x):
    """Return the amp that x is in, and the positions of its image in each
    amplifier

    The CCD is taken to consist of nAmp == 4 amplifiers, each ``hwidth//2``
    columns wide, side by side. Amps 0 and 2 count columns from their left
    edge and amps 1 and 3 from their right edge. The image of column ``x`` in
    amp ``k`` is therefore the column at the same distance from amp ``k``'s
    amplifier.

    Parameters
    ----------
    width : `int`
        Width of the CCD.
    hwidth : `int`
        Half the width (``width//2``).
    x : `int`
        Column of interest.

    Returns
    -------
    amp : `int`
        The amplifier containing column ``x``.
    xx : `tuple` [`int`]
        ``xx[k]`` is the column of x's image in amp ``k``; ``xx[amp] == x``.

    Raises
    ------
    ValueError
        If ``nAmp != 4``, or ``x`` does not lie within one of the amplifiers.
    """
    # Explicit checks rather than asserts, which vanish under ``python -O``
    if nAmp != 4:
        raise ValueError("getXPos assumes 4 amplifiers, not %d" % nAmp)

    amp = x//(hwidth//2)  # which amp am I in?
    if amp not in range(nAmp):
        raise ValueError("x = %d is not within any of the %d amplifiers" % (x, nAmp))

    if amp == 0:    # x columns from amp 0's amplifier (left edge)
        xx = (x, hwidth - x - 1, hwidth + x, 2*hwidth - x - 1)
    elif amp == 1:  # hwidth - x - 1 columns from amp 1's amplifier (right edge)
        xx = (hwidth - x - 1, x, 2*hwidth - x - 1, hwidth + x)
    elif amp == 2:  # x - hwidth columns from amp 2's amplifier (left edge)
        xx = (x - hwidth, width - x - 1, x, width + hwidth - x - 1)
    else:           # width - x - 1 columns from amp 3's amplifier (right edge)
        xx = (width - x - 1, x - hwidth, width + hwidth - x - 1, x)

    return amp, xx
def getAmplitudeRatios(mi, threshold=45000, bkgd=None, rats=None):
    """Accumulate the ratios of crosstalk images to the pixels causing them.

    For every pixel in a footprint above ``threshold`` the positions of its
    images in the other amplifiers are found with `getXPos`. The ratio
    ``(image - bkgd)/source`` is appended to ``rats[sourceAmp][imageAmp]``,
    skipping images on BAD, EDGE, SAT or INTRP pixels.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to analyse. The footprint detection is asked to set the
        DETECTED mask plane.
    threshold : `float`
        Pixel value above which a pixel is treated as a crosstalk source.
    bkgd : `float`, optional
        Background level; if None, the median of pixels not flagged DETECTED.
    rats : `list` [`list` [`list` [`float`]]], optional
        Accumulator returned by a previous call. If None a new nAmp x nAmp
        one is made, with each diagonal entry seeded with a single 0.

    Returns
    -------
    rats : `list` [`list` [`list` [`float`]]]
        The accumulator with the new ratios appended.
    """
    width = mi.getWidth()
    hwidth = width//2
    msk = mi.getMask()

    if rats is None:
        # Ratios are never appended to the diagonal, so seed it with 0 to
        # keep every list non-empty for calculateCoeffs
        rats = [[[0] if i == j else [] for j in range(nAmp)] for i in range(nAmp)]

    fs = afwDetect.FootprintSet(mi, afwDetect.Threshold(threshold), "DETECTED")

    if bkgd is None:
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(msk.getPlaneBitMask("DETECTED"))
        bkgd = afwMath.makeStatistics(mi, afwMath.MEDIAN, sctrl).getValue()

    badMask = msk.getPlaneBitMask(["BAD", "EDGE", "SAT", "INTRP"])

    imgArr = mi.getImage().getArray()  # indexed [y, x], local coordinates
    mskArr = msk.getArray()
    x0, y0 = mi.getX0(), mi.getY0()  # footprints are in parent coordinates

    for foot in fs.getFootprints():
        for span in foot.getSpans():
            y = span.getY() - y0
            # A span's X1 is inclusive, so the range must run to X1 + 1
            for x in range(span.getX0() - x0, span.getX1() - x0 + 1):
                val = float(imgArr[y, x])
                amp, xx = getXPos(width, hwidth, x)
                for a, xImage in enumerate(xx):
                    if a == amp or mskArr[y, xImage] & badMask:
                        continue
                    rats[amp][a].append((float(imgArr[y, xImage]) - bkgd)/val)

    return rats
def calculateCoeffs(rats, nsigma, plot=False, fig=None, title=None):
    """Calculate cross-talk coefficients

    For each (ain, aout) pair, the coefficient is the median of
    ``rats[ain][aout]`` after up to three rounds of clipping at ``nsigma``
    robust standard deviations (0.741*IQR). Its error is the standard error
    of that median.

    Parameters
    ----------
    rats : `list` [`list` [sequence of `float`]]
        nAmp x nAmp lists of ratios, e.g. as accumulated by
        `getAmplitudeRatios`.
    nsigma : `float`
        Clipping threshold in units of the robust standard deviation.
    plot : `bool`
        Plot a histogram of each set of ratios.
    fig : optional
        Passed to `getMpFigure`. If None, the figure number is one more than
        a trailing digit of ``title``, or 1.
    title : `str`, optional
        Title for the plot.

    Returns
    -------
    coeffs, coeffsErr : `numpy.ndarray`
        nAmp x nAmp arrays of coefficients and their errors.

    Raises
    ------
    ValueError
        If some (ain, aout) pair has no ratios at all.
    """
    coeffs = np.empty((nAmp, nAmp))
    coeffsErr = np.empty_like(coeffs)

    if plot:
        if fig is None:
            fig = int(title[-1]) + 1 if title and title[-1].isdecimal() else 1

        fig = getMpFigure(fig, clear=True)
        subplots = makeSubplots(fig, nAmp, nAmp)

        rMin = 2e-3
        bins = np.arange(-rMin, rMin, 0.05*rMin)

    for ain in range(nAmp):
        for aout in range(nAmp):
            tmp = np.sort(np.asarray(rats[ain][aout], dtype=float))
            if tmp.size == 0:
                raise ValueError("No crosstalk ratios for amp %d -> amp %d" % (ain, aout))

            for _ in range(3):
                n = len(tmp)
                med = tmp[int(0.5*n)]
                sigma = 0.741*(tmp[int(0.75*n)] - tmp[int(0.25*n)])  # robust s.d. from the IQR
                keep = np.abs(tmp - med) < nsigma*sigma
                if not keep.any():  # e.g. sigma == 0: keep the previous sample
                    break
                tmp = tmp[keep]  # still sorted

            coeffs[ain][aout] = tmp[len(tmp)//2]

            err = 0.741*(tmp[3*len(tmp)//4] - tmp[len(tmp)//4])  # estimate s.d. from IQR
            err /= math.sqrt(len(tmp))  # standard error of mean
            err *= math.sqrt(math.pi/2)  # standard error of median
            coeffsErr[ain][aout] = err

            if plot:
                axes = next(subplots)
                # A Locator must not be shared between Axes, so make one per subplot
                axes.xaxis.set_major_locator(ticker.MaxNLocator(nbins=3))

                if ain != aout:
                    hist = np.histogram(rats[ain][aout], bins)[0]
                    # bins[:-1] are the left bin edges, so align the bars on them
                    axes.bar(bins[:-1], hist, width=bins[1] - bins[0], align="edge",
                             color="red", linewidth=0, alpha=0.8)

                    axes.axvline(0, linestyle="--", color="green")
                    axes.axvline(coeffs[ain][aout], linestyle="-", color="blue")
                    for sign in (-1, 1):
                        axes.axvline(coeffs[ain][aout] + sign*coeffsErr[ain][aout], linestyle=":", color="cyan")
                    axes.text(-0.9*rMin, 0.8*axes.get_ylim()[1], r"%.1e" % coeffs[ain][aout],
                              fontsize="smaller")
                    axes.set_xlim(-1.05*rMin, 1.05*rMin)

    if plot:
        if title:
            fig.suptitle(title)
        fig.show()

    return coeffs, coeffsErr
def subtractXTalk(mi, coeffs, minPixelToMask=45000, crosstalkStr="CROSSTALK"):
    """Subtract the crosstalk from MaskedImage mi given a set of coefficients

    The image is split into nAmp equal-width amplifiers side by side. Amps
    whose indices differ in parity are mirror images of each other, so one is
    flipped left-right before subtraction. For each ordered pair (j, i) with
    j != i, ``coeffs[j][i]*(amp j - background)`` is subtracted from amp i.
    Every source amp is taken from an unmodified copy of ``mi``.

    The pixels affected by signal over minPixelToMask have the crosstalkStr
    bit set. The bright source pixels themselves do not.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to correct in place.
    coeffs : 2-D array-like
        nAmp x nAmp crosstalk coefficients.
    minPixelToMask : `float`
        Pixels above this value are treated as crosstalk sources for masking.
    crosstalkStr : `str`
        Name of the mask plane marking crosstalk-affected pixels.
    """
    mask = mi.getMask()  # always mi's own mask, also used in ``finally``

    sctrl = afwMath.StatisticsControl()
    sctrl.setAndMask(mask.getPlaneBitMask("BAD"))
    bkgd = afwMath.makeStatistics(mi, afwMath.MEDIAN, sctrl).getValue()
    #
    # These are the pixels that are bright enough to cause crosstalk (more
    # precisely, the ones that we label as causing crosstalk; in reality all
    # pixels cause crosstalk)
    #
    tempStr = "TEMP"  # mask plane used to record the bright pixels that we need to mask
    mask.addMaskPlane(tempStr)
    try:
        fs = afwDetect.FootprintSet(mi, afwDetect.Threshold(minPixelToMask), tempStr)

        mask.addMaskPlane(crosstalkStr)
        afwDisplay.getDisplay().setMaskPlaneColor(crosstalkStr, afwDisplay.MAGENTA)
        # the crosstalkStr bit will now be set whenever we subtract crosstalk
        fs.setMask(mask, crosstalkStr)
        crosstalk = mask.getPlaneBitMask(crosstalkStr)
        bad = mask.getPlaneBitMask("BAD")

        width, height = mi.getDimensions()
        ampWidth = width//nAmp

        def ampBox(k):
            """Local-coordinate bounding box of amplifier k."""
            return geom.BoxI(geom.PointI(k*ampWidth, 0), geom.ExtentI(ampWidth, height))

        # Read source amps from an untouched snapshot. Otherwise amps that are
        # already corrected would feed back into later amps, along with the
        # CROSSTALK bits just set on their victim pixels.
        source = mi.Factory(mi, True)

        for i in range(nAmp):
            ampI = mi.Factory(mi, ampBox(i), afwImage.LOCAL)  # a view: edits go into mi
            for j in range(nAmp):
                if i == j:
                    continue

                if (i + j)%2 == 1:
                    # flipImage returns a new image, so no need for a deep copy
                    ampJ = afwMath.flipImage(mi.Factory(source, ampBox(j), afwImage.LOCAL), True, False)
                else:
                    ampJ = mi.Factory(source, ampBox(j), afwImage.LOCAL, True)

                mskJ = ampJ.getMask()
                if np.all(mskJ.getArray() & bad):
                    # Bad amplifier; ignore it completely --- its effect will
                    # come out in the bias
                    continue
                mskJ &= crosstalk  # only the crosstalk bit propagates into ampI

                ampJ -= bkgd
                ampJ *= coeffs[j][i]

                ampI -= ampJ
        #
        # Clear the crosstalkStr bit in the original bright pixels, where
        # tempStr is set
        #
        temp = mask.getPlaneBitMask(tempStr)
        both = crosstalk | temp
        arr = mask.getArray()
        arr[(arr & both) == both] &= arr.dtype.type(~crosstalk)
    finally:
        mask.removeAndClearMaskPlane(tempStr, True)  # added in afw #1853
def printCoeffs(coeffs, coeffsErr=None, LaTeX=False, ppm=False):
    """Print cross-talk coefficients as a table on stdout.

    Parameters
    ----------
    coeffs : 2-D array-like
        Square matrix of coefficients indexed ``coeffs[ain][aout]``. Its size
        sets the number of rows and columns printed.
    coeffsErr : 2-D array-like, optional
        Errors with the same shape as ``coeffs``, printed after each value.
    LaTeX : `bool`
        Print a LaTeX ``tabular`` environment instead of plain text.
    ppm : `bool`
        In LaTeX mode, print values and errors in parts per million, leaving
        the diagonal blank.
    """
    n = len(coeffs)  # number of amplifiers; coeffs is n x n

    if LaTeX:
        print(r"\begin{tabular}{l|*{%d}{l}}" % n)
        print(r"ampIn & \multicolumn{%d}{c}{ampOut} \\" % n)
        print("     " + "".join(" & %d" % aout for aout in range(n)) + r" \\")
        print(r"\hline")
        for ain in range(n):
            msg = "%-4d " % ain
            for aout in range(n):
                if ppm:
                    msg += "& " if ain == aout else "& %5.0f " % (1e6*coeffs[ain][aout])
                else:
                    msg += "& %9.2e " % (coeffs[ain][aout])
                if coeffsErr is not None:
                    if ppm:
                        if ain != aout:
                            val = int(1e6*coeffsErr[ain][aout] + 0.5)  # round to the nearest ppm
                            # pad one-digit errors so that the columns line up
                            msg += r"$\pm$ %s%d " % (r"$\phantom{0}$" if val < 10 else "", val)
                    else:
                        msg += r"$\pm$ %7.1e " % (coeffsErr[ain][aout])
            print(msg + r" \\")
        print(r"\end{tabular}")

        return

    # Each data cell is " %9.2e" (10 characters), plus " +- %7.1e" (11) if errors are shown
    cellWidth = 10 + (11 if coeffsErr is not None else 0)
    print("ampIn  ampOut")
    print("%-4s " % "" + "".join("%10d%s" % (aout, " "*(cellWidth - 10)) for aout in range(n)))
    for ain in range(n):
        msg = "%-4d " % ain
        for aout in range(n):
            msg += " %9.2e" % coeffs[ain][aout]
            if coeffsErr is not None:
                msg += " +- %7.1e" % coeffsErr[ain][aout]
        print(msg)
#
# Code to simulate crosstalk
#
# xTalkAmplitudes[src][dst] is the fraction of a pixel's value in amp ``src``
# that addTrail adds at the pixel's image in amp ``dst``; the diagonal is zero.
xTalkAmplitudes = np.array([
    [0, -1.0e-4, -2.0e-4, -3.0e-4],     # cross talk from amp0 to amp1, 2, 3
    [-1.5e-4, 0, -2.5e-4, -2.9e-4],
    [-2.2e-4, -3.1e-4, 0, -0.9e-4],
    [-2.7e-4, -3.3e-4, -1.9e-4, 0],     # ... from amp 3
])
def addTrail(mi, val, x0, y0, pix, addCrosstalk=True):
    """Paint a simulated saturation trail into ``mi``.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to modify in place.
    val : `float`
        Value written into the trail's pixels (their mask is set to 0).
    x0, y0 : `int`
        Origin of the trail.
    pix : sequence of `tuple`
        One entry per row, starting at row ``y0``. Each entry holds the
        ``range`` arguments for the column offsets (relative to ``x0``) to
        fill in that row.
    addCrosstalk : `bool`
        For each painted pixel, also add ``xTalkAmplitudes[amp][k]*val`` at
        its image in every amp ``k`` (see `getXPos`).
    """
    width = mi.getWidth()
    hwidth = width//2

    if addCrosstalk:
        xtalk = mi.Factory(mi.getDimensions())

    for dy, xRange in enumerate(pix):
        y = y0 + dy
        for dx in range(*xRange):
            x = x0 + dx
            mi.set(x, y, (val, 0,))

            if addCrosstalk:
                amp, xx = getXPos(width, hwidth, x)
                for k, xk in enumerate(xx):
                    xtalk.set(xk, y, (xTalkAmplitudes[amp][k]*val, ))

    # xtalk only exists when crosstalk was requested
    if addCrosstalk:
        mi += xtalk
def addSaturated(mi, addCrosstalk=True):
    """Add eight simulated saturation trails of two shapes to ``mi`` using `addTrail`."""
    shortTrail = 6*[(0, 2)] + 4*[(-1, 3)] + 4*[(-2, 4)] + 3*[(-1, 3)] + 4*[(0, 2)]
    longTrail = (12*[(0, 2)] + 8*[(-1, 3)] + 4*[(-2, 4)] + 4*[(-3, 6)] + 3*[(-2, 5)] +
                 3*[(-1, 3)] + 10*[(0, 2)])

    # (value, x0, y0, shape), painted in this order
    trails = [
        (48000, 300, 350, shortTrail),
        (50000, 100, 450, shortTrail),
        (60000, 50, 550, shortTrail),
        (52000, 450, 650, shortTrail),
        (60000, 100, 300, longTrail),
        (50000, 200, 400, longTrail),
        (46000, 300, 500, longTrail),
        (48000, 400, 600, longTrail),
    ]
    for value, xStart, yStart, shape in trails:
        addTrail(mi, value, xStart, yStart, shape, addCrosstalk)
def makeImage(width=500, height=1000, seed=None):
    """Return a simulated MaskedImageF with saturation trails, crosstalk and noise.

    Pixels start at 1000 with variance 50. Saturation trails with crosstalk
    are added by `addSaturated`, then Gaussian noise of variance 50.

    Parameters
    ----------
    width, height : `int`
        Dimensions of the image.
    seed : `int`, optional
        Seed for the MT19937 generator. If None, it is derived from the
        nanosecond clock, so images made in quick succession get different
        noise.
    """
    var = 50
    mi = afwImage.MaskedImageF(width, height)
    mi.set(1000, 0x0, var)

    addSaturated(mi, addCrosstalk=True)

    if seed is None:
        # A seconds-resolution seed would repeat the same noise for every
        # image made within one second (e.g. successive calls from estimateCoeffs)
        seed = time.time_ns() % 2**31

    noise = afwImage.ImageF(width, height)
    afwMath.randomGaussianImage(noise, afwMath.Random("MT19937", seed))
    noise *= math.sqrt(var)
    mi += noise

    return mi
def readImage(butler, **kwargs):
    """Return the MaskedImage of the ``calexp`` identified by ``kwargs``.

    Errors from the butler are reported on stderr and then re-raised. They are
    not swallowed, and the function no longer drops into an interactive
    debugger, which would hang non-interactive runs.
    """
    try:
        exposure = butler.get("calexp", **kwargs)
    except Exception as e:
        print("Unable to read calexp %s: %s" % (kwargs, e), file=sys.stderr)
        raise
    return exposure.getMaskedImage()
def makeList(x):
    """Return ``x`` if it is a sequence, else a one-element list ``[x]``.

    Strings and bytes are treated as scalars, e.g. a single CCD name, rather
    than as sequences of characters. An empty sequence is returned unchanged.
    """
    if isinstance(x, (str, bytes)):
        return [x]
    try:
        x[0]
    except TypeError:  # not subscriptable: a scalar
        return [x]
    except IndexError:  # subscriptable but empty
        pass
    return x
def estimateCoeffs(butler, visitList, ccdList, threshold=45000, nSample=1, plot=False, fig=None, title=None):
    """Estimate crosstalk coefficients from a set of images.

    Ratios are accumulated with `getAmplitudeRatios` over every (visit, ccd)
    pair and reduced with `calculateCoeffs` (2-sigma clipping). A ccd of
    ``"simulated"`` uses a fresh `makeImage` instead of reading from the butler.

    ``nSample`` is accepted for interface compatibility but not used here.

    Returns
    -------
    coeffs, coeffsErr : `numpy.ndarray`
        As returned by `calculateCoeffs`.
    """
    ratios = None
    for visit in visitList:
        for ccd in ccdList:
            image = makeImage() if ccd == "simulated" else readImage(butler, visit=visit, ccd=ccd)
            ratios = getAmplitudeRatios(image, threshold, rats=ratios)

    return calculateCoeffs(ratios, nsigma=2, plot=plot, title=title, fig=fig)
def main(butler=None, visit=131634, ccd=None, threshold=45000, nSample=1, showCoeffs=True, fixXTalk=True,
         plot=False, title=None):
    """Estimate crosstalk coefficients and optionally apply them to an image.

    Parameters
    ----------
    butler : optional
        Data butler used to read calexps. Not needed for simulated data.
    visit : `int` or sequence of `int`
        Visit(s) to analyse. Ignored when ``ccd`` is None.
    ccd : `int` or sequence of `int`, optional
        CCD(s) to analyse. If None, ``nSample`` simulated images are used.
    threshold : `float`
        Pixel value above which a pixel is treated as a crosstalk source, both
        when estimating and when correcting.
    nSample : `int`
        Number of simulated images to analyse when ``ccd`` is None.
    showCoeffs : `bool`
        Print the coefficients.
    fixXTalk : `bool`
        Subtract the crosstalk from the returned image.
    plot, title
        Passed to `estimateCoeffs`.

    Returns
    -------
    mi : `lsst.afw.image.MaskedImage`
        The first image (a new simulation if ``ccd`` is None), corrected if
        ``fixXTalk``.
    coeffs : `numpy.ndarray`
        The estimated coefficients.
    """
    if ccd is None:
        visitList = list(range(nSample))
        ccdList = ["simulated", ]
    else:
        ccdList = makeList(ccd)
        visitList = makeList(visit)

    coeffs, coeffsErr = estimateCoeffs(butler, visitList, ccdList, threshold=threshold, plot=plot, title=title)

    if showCoeffs:
        printCoeffs(coeffs, coeffsErr)

    # Simulated data can't be read from the butler; make a fresh image instead
    if ccdList[0] == "simulated":
        mi = makeImage()
    else:
        mi = readImage(butler, visit=visitList[0], ccd=ccdList[0])
    if fixXTalk:
        subtractXTalk(mi, coeffs, threshold)

    return mi, coeffs
try:
    import matplotlib.ticker as ticker
    import matplotlib.pyplot as pyplot
except ImportError:
    # Plotting is optional; getMpFigure raises RuntimeError when pyplot is None.
    # Define ticker too, so that neither name is ever left unbound.
    ticker = None
    pyplot = None
try:
    mpFigures  # presumably kept so that re-executing the module (reload) preserves the registry
except NameError:
    mpFigures = {0: None}  # matplotlib (actually pyplot) figures; key 0 is a placeholder (figures are 1-indexed)
def makeSubplots(figure, nx=2, ny=2):
    """Lazily yield the ``nx*ny`` subplots of ``figure``.

    Each panel is created with ``figure.add_subplot(nx, ny, index)`` only when
    the generator is advanced.
    """
    for panel in range(1, nx*ny + 1):  # add_subplot numbers panels from 1
        yield figure.add_subplot(nx, ny, panel)
def getMpFigure(fig=None, clear=True):
    """Return a pyplot figure()

    ``fig`` may be:

    - a figure, which is returned;
    - a bool: make a new figure;
    - a positive int: return (or make) that 1-indexed figure;
    - a negative int: index the figures made here, python-list style
      (``-1`` is the highest-numbered);
    - None: return the highest-numbered figure made here, or figure 1.

    Parameters
    ----------
    fig : figure, `bool`, `int` or None
        Which figure to return.
    clear : `bool`
        Clear the figure before returning it.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure, which is also made pyplot's current figure.

    Raises
    ------
    RuntimeError
        If matplotlib could not be imported, or ``fig == 0``.
    """
    import types

    if not pyplot:
        raise RuntimeError("I am unable to plot as I failed to import matplotlib")

    def knownFigures():
        """Sorted numbers of the real figures (key 0 is only a placeholder)."""
        return sorted(k for k in mpFigures if k > 0)

    if isinstance(fig, bool):  # we want a new one
        fig = max(mpFigures, default=0) + 1  # an unused number; matplotlib is 1-indexed

    if isinstance(fig, int):
        i = fig
        if i == 0:
            raise RuntimeError("I'm sorry, but matplotlib uses 1-indexed figures")
        if i < 0:
            known = knownFigures()
            try:
                i = known[i]  # simulate list's [-n] syntax
            except IndexError:
                if known:
                    print("Illegal index: %d" % i, file=sys.stderr)
                i = 1

        def lift(fig):
            fig.canvas._tkcanvas._root().lift()  # == Tk's raise, but raise is a python reserved word

        if i in mpFigures and not pyplot.fignum_exists(mpFigures[i].number):
            del mpFigures[i]  # its window has been closed; make a new one

        if i in mpFigures:
            try:
                lift(mpFigures[i])
            except Exception:
                pass  # raising the window is cosmetic and only works with a Tk backend

        if i not in mpFigures:
            for j in range(1, i):
                getMpFigure(j, clear=False)  # make lower-numbered figures first without wiping them

            mpFigures[i] = pyplot.figure()
            #
            # Modify pyplot.figure().show() to make it raise the plot too
            #

            def show(self, *args, _show=mpFigures[i].show, **kwargs):
                _show(*args, **kwargs)  # _show is already bound to the figure
                try:
                    lift(self)
                except Exception:
                    pass
            # create a bound method (Python 3's MethodType takes exactly two arguments)
            mpFigures[i].show = types.MethodType(show, mpFigures[i])

        fig = mpFigures[i]

    if fig is None:
        known = knownFigures()
        fig = mpFigures[known[-1]] if known else getMpFigure(1)

    if clear:
        fig.clf()

    pyplot.figure(fig.number)  # make it active

    return fig
def fixCcd(butler, visit, ccd, coeffs, display=True):
    """Apply cross-talk correction to a CCD, given the cross-talk coefficients

    The calexp for (visit, ccd) is read and corrected in place with
    `subtractXTalk`, using its default masking threshold.

    Parameters
    ----------
    butler
        Data butler used by `readImage`.
    visit, ccd
        Data id of the calexp. ``ccd`` may be an int or a string.
    coeffs : 2-D array-like
        Crosstalk coefficients.
    display : `bool`
        Show the image before (frame 1) and after (frame 2) correction.

    Returns
    -------
    mi : `lsst.afw.image.MaskedImage`
        The corrected image.
    """
    mi = readImage(butler, visit=visit, ccd=ccd)
    if display:
        afwDisplay.getDisplay(frame=1).mtv(mi.getImage(), title="CCD %s" % ccd)

    subtractXTalk(mi, coeffs)

    if display:
        afwDisplay.getDisplay(frame=2).mtv(mi, title="corrected %s" % ccd)

    return mi
if __name__ == "__main__":
    # No butler is available from the command line; with ccd=None main()
    # analyses simulated data. main's first parameter (butler) is passed explicitly.
    main(None)