Coverage for python / lsst / obs / subaru / crosstalk.py: 0%

336 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:59 +0000

1# 

2# LSST Data Management System 

3# Copyright 2008-2016 AURA/LSST. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <https://www.lsstcorp.org/LegalNotices/>. 

21# 

22""" 

23Determine and apply crosstalk corrections 

24 

25N.b. This code was written and tested for the 4-amplifier Hamamatsu chips used 

26in (Hyper)?SuprimeCam, and will need to be generalised to handle other 

27amplifier layouts. I don't want to do this until we have an example. 

28 

29N.b. To estimate crosstalk from the SuprimeCam data, the commands are e.g.: 

30 

31.. code-block:: python 

32 

33 import crosstalk 

34 coeffs, coeffsErr = crosstalk.estimateCoeffs(butler, range(131634, 131642), 

35 range(10), threshold=1e5, 

36 plot=True, title="CCD0..9", 

37 fig=1) 

38 crosstalk.fixCcd(butler, 131634, 0, coeffs) 

39""" 

40import sys 

41import math 

42import time 

43 

44import numpy as np 

45 

46import lsst.afw.detection as afwDetect 

47import lsst.afw.image as afwImage 

48import lsst.afw.math as afwMath 

49import lsst.geom as geom 

50import lsst.pex.config as pexConfig 

51import lsst.afw.display as afwDisplay 

52from functools import reduce 

53from lsst.ip.isr.crosstalk import CrosstalkTask 

54 

55 

class CrosstalkCoeffsConfig(pexConfig.Config):
    """Crosstalk coefficients for one CCD, stored as a flattened matrix
    together with the shape needed to rebuild it."""

    values = pexConfig.ListField(
        dtype=float,
        doc="Crosstalk coefficients",
        default=16*[0],
    )
    shape = pexConfig.ListField(
        dtype=int,
        doc="Shape of coeffs array",
        default=[4, 4],
        minLength=1,  # should be 2; kept at 1 to work around a pex_config bug
        maxLength=2,
    )

    def getCoeffs(self):
        """Return the coefficients as a numpy array.

        Returns
        -------
        coeffs : `numpy.ndarray`
            ``values`` reshaped (row-major) to ``shape``.
        """
        flat = np.array(self.values)
        return flat.reshape(self.shape)

79 

80 

class SubaruCrosstalkConfig(CrosstalkTask.ConfigClass):
    """Configuration for the Subaru crosstalk correction task."""

    minPixelToMask = pexConfig.Field(
        dtype=float,
        default=45000,
        doc="Set crosstalk mask plane for pixels over this value",
    )
    crosstalkMaskPlane = pexConfig.Field(
        dtype=str,
        default="CROSSTALK",
        doc="Name for crosstalk mask plane",
    )
    coeffs = pexConfig.ConfigField(
        dtype=CrosstalkCoeffsConfig,
        doc="Crosstalk coefficients",
    )

86 

87 

class SubaruCrosstalkTask(CrosstalkTask):
    """Crosstalk correction for the 4-amplifier Hamamatsu CCDs, using the
    coefficient matrix held in the task configuration."""

    ConfigClass = SubaruCrosstalkConfig

    def run(self, exp, **kwargs):
        """Subtract crosstalk from ``exp`` in place.

        Parameters
        ----------
        exp
            Exposure-like object; its ``getMaskedImage()`` is corrected.
        **kwargs
            Accepted for call compatibility with `CrosstalkTask.run`; ignored.
        """
        self.log.info("Applying crosstalk correction/Subaru")
        cfg = self.config
        subtractXTalk(
            exp.getMaskedImage(),
            cfg.coeffs.getCoeffs(),
            cfg.minPixelToMask,
            cfg.crosstalkMaskPlane,
        )

95 

96 

# Number of amplifiers per CCD.  Several functions in this module hard-wire
# the 4-amplifier layout (getXPos refuses to run otherwise).
nAmp = 4


def getXPos(width, hwidth, x):
    """Return the amp that x is in, and the positions of its image in each
    amplifier.

    The CCD is split into 4 equal-width amplifiers side by side.  Between
    amplifiers whose indices have the same parity a column maps across with
    the same offset; between amplifiers of opposite parity it is mirrored
    (this matches the left-right flip applied in `subtractXTalk`).

    Parameters
    ----------
    width : `int`
        Width of the CCD in pixels.
    hwidth : `int`
        Half the width, ``width//2``.
    x : `int`
        Column of interest.

    Returns
    -------
    amp : `int`
        Index of the amplifier containing ``x``.
    xx : `tuple` of `int`
        ``xx[a]`` is the column in amplifier ``a`` corresponding to ``x``;
        ``xx[amp] == x``.

    Raises
    ------
    NotImplementedError
        If ``nAmp`` is not 4.
    ValueError
        If ``x`` does not lie in one of the 4 amplifiers.
    """
    if nAmp != 4:
        raise NotImplementedError("getXPos only supports 4 amplifiers, not %d" % nAmp)

    amp = x//(hwidth//2)  # which amp am I in?
    # Use an explicit check rather than assert: asserts vanish under -O and
    # an out-of-range x would then fail obscurely below.
    if not 0 <= amp < nAmp:
        raise ValueError("Column %d is not inside any of the %d amplifiers of a %d-pixel wide CCD"
                         % (x, nAmp, width))

    if amp == 0:
        xs = hwidth - x - 1  # mirror position within this half of the chip
        return amp, (x, xs, hwidth + x, hwidth + xs)
    if amp == 1:
        xs = hwidth - x  # mirror position (+1) within this half of the chip
        return amp, (xs - 1, x, hwidth + xs - 1, hwidth + x)

    xa = x - hwidth  # offset from the start of the right half of the chip
    if amp == 2:
        return amp, (xa, width - x - 1, x, width - xa - 1)
    return amp, (width - x - 1, xa, width - xa - 1, x)

125 

126 

def getAmplitudeRatios(mi, threshold=45000, bkgd=None, rats=None):
    """Accumulate crosstalk amplitude ratios measured around bright pixels.

    Every pixel above ``threshold`` is treated as a crosstalk source.  For
    each other amplifier, the background-subtracted value at the
    corresponding pixel (see `getXPos`) is divided by the source value.
    Target pixels flagged BAD, EDGE, SAT or INTRP are skipped.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to measure.  The ``DETECTED`` mask plane is set on the bright
        pixels as a side effect of the detection.
    threshold : `float`
        Pixels above this value are taken as crosstalk sources.
    bkgd : `float`, optional
        Background level.  If `None`, the median of the pixels not flagged
        ``DETECTED`` is used.
    rats : `list`, optional
        ``rats[ampIn][ampOut]`` lists from a previous call, to be extended.
        If `None`, a new ``nAmp x nAmp`` structure is created.

    Returns
    -------
    rats : `list` of `list` of `list` of `float`
        ``rats[ampIn][ampOut]`` lists the ratios measured in ``ampOut`` for
        sources in ``ampIn``.  Newly created diagonal lists contain a single 0.
    """
    img = mi.getImage()
    msk = mi.getMask()
    width = mi.getWidth()
    hwidth = width//2

    if rats is None:
        rats = [[[] for _ in range(nAmp)] for _ in range(nAmp)]
        for i in range(nAmp):
            rats[i][i].append(0)  # placeholder so the diagonal is never empty

    fs = afwDetect.FootprintSet(mi, afwDetect.Threshold(threshold), "DETECTED")

    if bkgd is None:
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(msk.getPlaneBitMask("DETECTED"))

        bkgd = afwMath.makeStatistics(mi, afwMath.MEDIAN, sctrl).getValue()

    badMask = msk.getPlaneBitMask(["BAD", "EDGE", "SAT", "INTRP"])

    for foot in fs.getFootprints():
        for s in foot.getSpans():
            y, x0, x1 = s.getY(), s.getX0(), s.getX1()
            # A Span's x1 is inclusive, so the range must include it
            for x in range(x0, x1 + 1):
                val = img.get(x, y)
                amp, xx = getXPos(width, hwidth, x)
                for a, _x in enumerate(xx):
                    if a == amp or msk.get(_x, y) & badMask:
                        continue
                    rats[amp][a].append((img.get(_x, y) - bkgd)/val)

    return rats

170 

171 

def calculateCoeffs(rats, nsigma, plot=False, fig=None, title=None):
    """Calculate cross-talk coefficients from accumulated amplitude ratios.

    Each pair's ratios are clipped, and the coefficient is the median of the
    surviving ratios.
    - Clipping uses up to 3 passes.
    - Each pass keeps ratios within ``nsigma`` sigma of the median.
    - Sigma is estimated from the interquartile range (IQR).

    Parameters
    ----------
    rats : sequence
        ``rats[ampIn][ampOut]`` is a sequence of ratios, e.g. as returned by
        `getAmplitudeRatios`.
    nsigma : `float`
        Clipping threshold in units of the IQR-estimated standard deviation.
    plot : `bool`
        Plot histograms of the ratios (requires matplotlib).
    fig : `int` or figure, optional
        Figure to plot into.  By default it is one more than the last
        character of ``title`` when that is a digit, else 1.
    title : `str`, optional
        Title for the plot.

    Returns
    -------
    coeffs : `numpy.ndarray`, shape ``(nAmp, nAmp)``
        ``coeffs[ampIn][ampOut]``; NaN where no ratios were available.
    coeffsErr : `numpy.ndarray`, shape ``(nAmp, nAmp)``
        Estimated standard error of each coefficient; NaN where no ratios
        were available.
    """
    coeffs = np.empty((nAmp, nAmp))
    coeffsErr = np.empty_like(coeffs)

    if plot:
        if fig is None:
            fig = int(title[-1]) + 1 if title and title[-1].isdigit() else 1

        fig = getMpFigure(fig, clear=True)
        subplots = makeSubplots(fig, nAmp, nAmp)

        rMin = 2e-3
        bins = np.arange(-rMin, rMin, 0.05*rMin)

        xMajorLocator = ticker.MaxNLocator(nbins=3)  # steps=(-rMin/2, 0, rMin/2))

    for ain in range(nAmp):
        for aout in range(nAmp):
            tmp = np.sort(np.asarray(rats[ain][aout], dtype=float))
            for _ in range(3):
                n = len(tmp)
                if n == 0:
                    break
                med = tmp[n//2]
                sigma = 0.741*(tmp[3*n//4] - tmp[n//4])  # s.d. from IQR
                keep = np.abs(tmp - med) < nsigma*sigma
                if not keep.any():
                    break
                tmp = tmp[keep]

            if len(tmp) == 0:
                # No ratios were measured for this pair of amplifiers
                coeffs[ain][aout] = np.nan
                coeffsErr[ain][aout] = np.nan
            else:
                coeffs[ain][aout] = tmp[len(tmp)//2]

                err = 0.741*(tmp[3*len(tmp)//4] - tmp[len(tmp)//4])  # estimate s.d. from IQR
                err /= math.sqrt(len(tmp))  # standard error of mean
                err *= math.sqrt(math.pi/2)  # standard error of median
                coeffsErr[ain][aout] = err

            if plot:
                axes = next(subplots)
                axes.xaxis.set_major_locator(xMajorLocator)

                if ain != aout:
                    hist = np.histogram(rats[ain][aout], bins)[0]
                    axes.bar(bins[0:-1], hist, width=bins[1]-bins[0], color="red", linewidth=0, alpha=0.8)

                axes.axvline(0, linestyle="--", color="green")
                if np.isfinite(coeffs[ain][aout]):
                    axes.axvline(coeffs[ain][aout], linestyle="-", color="blue")
                    for sign in (-1, 1):
                        axes.axvline(coeffs[ain][aout] + sign*coeffsErr[ain][aout], linestyle=":", color="cyan")
                    axes.text(-0.9*rMin, 0.8*axes.get_ylim()[1], r"%.1e" % coeffs[ain][aout],
                              fontsize="smaller")
                axes.set_xlim(-1.05*rMin, 1.05*rMin)

    if plot:
        if title:
            fig.suptitle(title)
        fig.show()

    return coeffs, coeffsErr

230 

231 

def subtractXTalk(mi, coeffs, minPixelToMask=45000, crosstalkStr="CROSSTALK"):
    """Subtract the crosstalk from MaskedImage mi given a set of coefficients

    The image is split into ``nAmp`` equal-width amplifiers.  For every pair
    of distinct amplifiers ``(j, i)``, a scaled copy of amplifier ``j`` is
    subtracted from amplifier ``i``:
    - The copy is background-subtracted and multiplied by ``coeffs[j][i]``.
    - It is flipped left-right when ``i + j`` is odd.

    Amplifiers whose pixels are all flagged ``BAD`` are not used as sources.

    Pixels that receive crosstalk from a source pixel above
    ``minPixelToMask`` get the ``crosstalkStr`` mask bit.  The bright source
    pixels themselves are left without it.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to correct; modified in place.
    coeffs : `numpy.ndarray`
        ``coeffs[ampIn][ampOut]`` crosstalk coefficients.
    minPixelToMask : `float`
        Source pixels above this level flag the pixels they contaminate.
    crosstalkStr : `str`
        Name of the mask plane to set (added if not already present).
    """
    sctrl = afwMath.StatisticsControl()
    sctrl.setAndMask(mi.getMask().getPlaneBitMask("BAD"))
    bkgd = afwMath.makeStatistics(mi, afwMath.MEDIAN, sctrl).getValue()
    #
    # These are the pixels that are bright enough to cause crosstalk (more
    # precisely, the ones that we label as causing crosstalk; in reality all
    # pixels cause crosstalk)
    #
    tempStr = "TEMP"  # mask plane used to record the bright pixels that we need to mask
    # The full image's mask.  Never rebind this name below: the finally
    # clause must remove the temporary plane from this mask.
    msk = mi.getMask()
    msk.addMaskPlane(tempStr)
    try:
        fs = afwDetect.FootprintSet(mi, afwDetect.Threshold(minPixelToMask), tempStr)

        msk.addMaskPlane(crosstalkStr)
        afwDisplay.getDisplay().setMaskPlaneColor(crosstalkStr, afwDisplay.MAGENTA)
        # the crosstalkStr bit will now be set whenever we subtract crosstalk
        fs.setMask(msk, crosstalkStr)
        crosstalk = msk.getPlaneBitMask(crosstalkStr)

        width, height = mi.getDimensions()
        ampWidth = width//nAmp
        for i in range(nAmp):
            bbox = geom.BoxI(geom.PointI(i*ampWidth, 0), geom.ExtentI(ampWidth, height))
            ampI = mi.Factory(mi, bbox)
            for j in range(nAmp):
                if i == j:
                    continue

                bbox = geom.BoxI(geom.PointI(j*ampWidth, 0), geom.ExtentI(ampWidth, height))
                if (i + j)%2 == 1:
                    ampJ = afwMath.flipImage(mi.Factory(mi, bbox), True, False)  # no need for a deep copy
                else:
                    ampJ = mi.Factory(mi, bbox, afwImage.LOCAL, True)

                ampJMask = ampJ.getMask()
                if np.all(ampJMask.getArray() & ampJMask.getPlaneBitMask("BAD")):
                    # Bad amplifier; ignore it completely --- its effect will
                    # come out in the bias
                    continue
                # Keep only the crosstalk bit, so that subtracting ampJ
                # propagates just that bit into ampI's mask
                ampJMask &= crosstalk

                ampJ -= bkgd
                ampJ *= coeffs[j][i]

                ampI -= ampJ
        #
        # Clear the crosstalkStr bit in the original bright pixels, where
        # tempStr is set
        #
        temp = msk.getPlaneBitMask(tempStr)
        xtalk_temp = crosstalk | temp
        np_msk = msk.getArray()
        bright = np.bitwise_and(np_msk, xtalk_temp) == xtalk_temp
        np_msk[bright] &= np_msk.dtype.type(~crosstalk)

    finally:
        msk.removeAndClearMaskPlane(tempStr, True)  # added in afw #1853

296 

297 

def printCoeffs(coeffs, coeffsErr=None, LaTeX=False, ppm=False):
    """Print a table of cross-talk coefficients on stdout.

    Parameters
    ----------
    coeffs : array-like
        ``coeffs[ampIn][ampOut]`` for ``nAmp`` amplifiers.
    coeffsErr : array-like, optional
        Errors with the same layout; printed next to each value when given.
    LaTeX : `bool`
        Emit a LaTeX ``tabular`` environment instead of a plain-text table.
    ppm : `bool`
        In LaTeX mode, print values and errors as parts per million and
        leave the diagonal blank.
    """
    haveErrors = coeffsErr is not None
    amps = range(nAmp)

    if LaTeX:
        print(r"""\begin{tabular}{l|*{4}{l}}
ampIn & \multicolumn{4}{c}{ampOut} \\
 & 0 & 1 & 2 & 3 \\
\hline""")
        for ain in amps:
            cells = ["%-4d " % ain]
            for aout in amps:
                onDiagonal = (ain == aout)
                if not ppm:
                    cells.append("& %9.2e " % (coeffs[ain][aout]))
                elif onDiagonal:
                    cells.append("& ")
                else:
                    cells.append("& %5.0f " % (1e6*coeffs[ain][aout]))

                if not haveErrors:
                    continue
                if not ppm:
                    cells.append(r"$\pm$ %7.1e " % (coeffsErr[ain][aout]))
                elif not onDiagonal:
                    val = int(1e6*coeffsErr[ain][aout] + 0.5)
                    pad = r"$\phantom{0}$" if val < 10 else ""
                    cells.append(r"$\pm$ %s%d " % (pad, val))
            print("".join(cells) + r" \\")
        print(r"\end{tabular}")
        return

    # Plain text: title line, column numbers, then one row per input amp
    titleWords = ["ampIn "]
    if haveErrors:
        titleWords.append(" ")
    print(*titleWords, "ampOut", sep=" ")

    errPad = "%11s" % "" if haveErrors else ""
    print("%-4s " % "" + "".join(" %d " % aout + errPad for aout in amps))

    for ain in amps:
        entries = []
        for aout in amps:
            entries.append(" %9.2e" % coeffs[ain][aout])
            if haveErrors:
                entries.append(" +- %7.1e" % coeffsErr[ain][aout])
        print("%-4d " % ain + "".join(entries))

344 

345 

#
# Code to simulate crosstalk
#
# xTalkAmplitudes[src][dst] is the simulated crosstalk amplitude from
# amplifier src into amplifier dst; the diagonal is zero.
xTalkAmplitudes = np.array([
    [0.0, -1.0e-4, -2.0e-4, -3.0e-4],  # from amp 0 into amps 0..3
    [-1.5e-4, 0.0, -2.5e-4, -2.9e-4],  # from amp 1
    [-2.2e-4, -3.1e-4, 0.0, -0.9e-4],  # from amp 2
    [-2.7e-4, -3.3e-4, -1.9e-4, 0.0],  # from amp 3
])

353 

354 

def addTrail(mi, val, x0, y0, pix, addCrosstalk=True):
    """Add a simulated saturation trail, and optionally its crosstalk, to mi.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to modify in place.
    val : `float`
        Value written into every pixel of the trail.
    x0, y0 : `int`
        Reference position of the trail.
    pix : sequence of ``(start, stop)`` pairs
        Row ``k`` of the trail (at ``y0 + k``) covers columns
        ``x0 + start`` up to, but excluding, ``x0 + stop``.
    addCrosstalk : `bool`
        Also add the crosstalk predicted by ``xTalkAmplitudes``.  It is
        accumulated in a separate image and added to ``mi`` once, at the end.
    """
    width = mi.getWidth()
    hwidth = width//2

    if addCrosstalk:
        xtalk = mi.Factory(mi.getDimensions())

    # Mask value written into the trail pixels.  The SAT|INTRP bit mask
    # variant was disabled (guarded by `if False`), so no bits are set.
    SAT = 0

    for _y, _x12 in enumerate(pix):
        y = y0 + _y
        for _x in range(*_x12):
            x = x0 + _x
            mi.set(x, y, (val, SAT,))

            if addCrosstalk:
                amp, xx = getXPos(width, hwidth, x)
                for target, xt in enumerate(xx):
                    xtalk.set(xt, y, (xTalkAmplitudes[amp][target]*val, ))

    if addCrosstalk:
        mi += xtalk

375 

376 

def addSaturated(mi, addCrosstalk=True):
    """Add a fixed set of eight simulated saturation trails to ``mi``.

    Two trail shapes are used, four copies of each at different positions
    and levels; see `addTrail` for the meaning of the arguments.
    """
    shortTrail = 6*[(0, 2)] + 4*[(-1, 3)] + 4*[(-2, 4)] + 3*[(-1, 3)] + 4*[(0, 2)]
    longTrail = (12*[(0, 2)] + 8*[(-1, 3)] + 4*[(-2, 4)] + 4*[(-3, 6)]
                 + 3*[(-2, 5)] + 3*[(-1, 3)] + 10*[(0, 2)])

    # (shape, level, x0, y0), applied in this order
    placements = (
        (shortTrail, 48000, 300, 350),
        (shortTrail, 50000, 100, 450),
        (shortTrail, 60000, 50, 550),
        (shortTrail, 52000, 450, 650),
        (longTrail, 60000, 100, 300),
        (longTrail, 50000, 200, 400),
        (longTrail, 46000, 300, 500),
        (longTrail, 48000, 400, 600),
    )
    for shape, level, xStart, yStart in placements:
        addTrail(mi, level, xStart, yStart, shape, addCrosstalk)

390 

391 

def makeImage(width=500, height=1000, seed=None):
    """Return a simulated image with saturated trails, crosstalk and noise.

    Parameters
    ----------
    width, height : `int`
        Dimensions of the image.
    seed : `int`, optional
        Seed for the noise generator.  By default the current time is used,
        so each call gives different noise.

    Returns
    -------
    mi : `lsst.afw.image.MaskedImageF`
        Image with a level of 1000, variance 50, the trails from
        `addSaturated` (with crosstalk), and Gaussian noise of that variance.
    """
    mi = afwImage.MaskedImageF(width, height)
    var = 50
    mi.set(1000, 0x0, var)  # (image, mask, variance) for every pixel

    addSaturated(mi, addCrosstalk=True)

    ralg = "MT19937"
    rseed = int(time.time()) if seed is None else seed

    noise = afwImage.ImageF(width, height)
    afwMath.randomGaussianImage(noise, afwMath.Random(ralg, rseed))
    noise *= math.sqrt(var)  # unit Gaussian -> standard deviation sqrt(var)
    mi += noise

    return mi

407 

408 

def readImage(butler, **kwargs):
    """Read the ``calexp`` identified by ``kwargs`` and return its MaskedImage.

    Parameters
    ----------
    butler
        Data butler; ``butler.get("calexp", **kwargs)`` is used to read.
    **kwargs
        Data ID, e.g. ``visit=..., ccd=...``.

    Raises
    ------
    Exception
        Whatever ``butler.get`` raises.  The data ID and error are reported on
        stderr first, then the original exception is re-raised (rather than
        returning `None`, which would only make callers fail later).
    """
    try:
        exp = butler.get("calexp", **kwargs)
    except Exception as e:
        print("Unable to read calexp %s: %s" % (kwargs, e), file=sys.stderr)
        raise
    return exp.getMaskedImage()

416 

417 

def makeList(x):
    """Return ``x`` if it is already a sequence, otherwise ``[x]``.

    Strings are treated as single values (e.g. one CCD name), not as
    sequences of characters.  Empty sequences are returned unchanged.
    Scalars without a length (ints, numpy scalars, `None`) are wrapped, as
    are sized but non-indexable containers such as sets.
    """
    if isinstance(x, (str, bytes)):
        return [x]
    try:
        n = len(x)
    except TypeError:
        return [x]  # a scalar
    if n == 0:
        return x  # an empty sequence; nothing to index
    try:
        x[0]
    except TypeError:
        return [x]  # sized but not indexable, e.g. a set
    return x

424 

425 

def estimateCoeffs(butler, visitList, ccdList, threshold=45000, nSample=1, plot=False, fig=None, title=None,
                   nsigma=2):
    """Estimate crosstalk coefficients from bright pixels in a set of images.

    Parameters
    ----------
    butler
        Data butler used to read ``calexp`` images (not used for simulated
        data).
    visitList, ccdList : iterable
        Visits and CCDs to process; every combination is used.  A CCD of
        ``"simulated"`` uses a freshly simulated image (`makeImage`) instead
        of reading one.
    threshold : `float`
        Pixels above this value are used as crosstalk sources.
    nSample : `int`
        Unused; kept for backwards compatibility.
    plot, fig, title
        Passed through to `calculateCoeffs`.
    nsigma : `float`
        Clipping threshold passed to `calculateCoeffs`.

    Returns
    -------
    coeffs, coeffsErr : `numpy.ndarray`
        As returned by `calculateCoeffs`.

    Raises
    ------
    ValueError
        If ``visitList`` or ``ccdList`` is empty, so that no image was used.
    """
    rats = None
    for v in visitList:
        for ccd in ccdList:
            if ccd == "simulated":
                mi = makeImage()
            else:
                mi = readImage(butler, visit=v, ccd=ccd)

            rats = getAmplitudeRatios(mi, threshold, rats=rats)

    if rats is None:
        raise ValueError("No images to process: visitList and ccdList must both be non-empty")

    return calculateCoeffs(rats, nsigma=nsigma, plot=plot, title=title, fig=fig)

438 

439 

def main(butler, visit=131634, ccd=None, threshold=45000, nSample=1, showCoeffs=True, fixXTalk=True,
         plot=False, title=None):
    """Estimate crosstalk coefficients, then optionally correct an image.

    Parameters
    ----------
    butler
        Data butler; not used when ``ccd`` is `None`.
    visit : `int` or sequence of `int`
        Visit(s) to use when ``ccd`` is given.
    ccd : `int` or sequence of `int`, optional
        CCD(s) to use.  If `None`, ``nSample`` simulated images are used
        instead of real data.
    threshold : `float`
        Level above which pixels are crosstalk sources.  It is also used as
        the masking level when correcting.
    nSample : `int`
        Number of simulated images when ``ccd`` is `None`.
    showCoeffs : `bool`
        Print the estimated coefficients.
    fixXTalk : `bool`
        Subtract the crosstalk from the returned image.
    plot, title
        Passed to `estimateCoeffs`.

    Returns
    -------
    mi : `lsst.afw.image.MaskedImage`
        The first real image, or a new simulated one.
    coeffs : `numpy.ndarray`
        The estimated coefficients.
    """
    simulated = ccd is None
    if simulated:
        visitList = list(range(nSample))
        ccdList = ["simulated", ]
    else:
        ccdList = makeList(ccd)
        visitList = makeList(visit)

    coeffs, coeffsErr = estimateCoeffs(butler, visitList, ccdList, threshold=threshold, plot=plot, title=title)

    if showCoeffs:
        printCoeffs(coeffs, coeffsErr)

    if simulated:
        mi = makeImage()  # "simulated" is not a CCD the butler could read
    else:
        mi = readImage(butler, visit=visitList[0], ccd=ccdList[0])
    if fixXTalk:
        subtractXTalk(mi, coeffs, threshold)

    return mi, coeffs

459 

460 

# Plotting is optional: only getMpFigure/calculateCoeffs(plot=True) need
# matplotlib, so a failed import leaves both names defined as None.
try:
    import matplotlib.ticker as ticker
    import matplotlib.pyplot as pyplot
except ImportError:
    ticker = None
    pyplot = None

# Figures created by getMpFigure, keyed by 1-indexed figure number; key 0 is a
# placeholder.  An existing dict is kept if this module is re-executed
# (e.g. reloaded).
try:
    mpFigures
except NameError:
    mpFigures = {0: None}  # matplotlib (actually pyplot) figures

470 

471 

def makeSubplots(figure, nx=2, ny=2):
    """Lazily yield the ``nx*ny`` subplots of an ``nx``-row by ``ny``-column
    grid on ``figure``, using matplotlib's 1-based subplot indices."""
    total = nx*ny
    index = 1
    while index <= total:
        yield figure.add_subplot(nx, ny, index)
        index += 1

476 

477 

def _liftMpFigure(fig):
    """Raise fig's window to the top (works only with Tk-based backends)."""
    fig.canvas._tkcanvas._root().lift()  # == Tk's raise, but raise is a python reserved word


def _makeMpFigure():
    """Create a pyplot figure whose show() also raises its window."""
    import types

    figure = pyplot.figure()
    plainShow = figure.show  # already bound to figure

    def show(self, *args, **kwargs):
        plainShow(*args, **kwargs)
        try:
            _liftMpFigure(self)
        except Exception:
            pass  # not a Tk backend: showing the figure is enough

    # Python 3 MethodType takes (function, instance) only
    figure.show = types.MethodType(show, figure)
    return figure


def _getNumberedMpFigure(i):
    """Return the figure registered as number i, creating it if needed."""
    if i == 0:
        raise RuntimeError("I'm sorry, but matplotlib uses 1-indexed figures")
    if i < 0:
        known = sorted(k for k in mpFigures if k > 0)  # key 0 is a placeholder
        try:
            i = known[i]  # simulate list's [-n] syntax
        except IndexError:
            print("Illegal index: %d" % i, file=sys.stderr)
            i = 1

    fig = mpFigures.get(i)
    if fig is not None:
        if pyplot.fignum_exists(fig.number):
            try:
                _liftMpFigure(fig)
            except Exception:
                pass  # raising the window is cosmetic
            return fig
        del mpFigures[i]  # its window was closed; make a new one

    # Keep the numbering dense, but don't clear figures that already exist
    for j in range(1, i):
        if j not in mpFigures:
            getMpFigure(j, clear=False)

    mpFigures[i] = _makeMpFigure()
    return mpFigures[i]


def getMpFigure(fig=None, clear=True):
    """Return a pyplot figure(), creating it if necessary.

    Parameters
    ----------
    fig : `bool`, `int`, figure or `None`
        - A `bool`: make a new figure, numbered one past the highest existing.
        - An `int`: return figure number ``fig`` (1-indexed), creating it and
          any missing lower-numbered figures.  Negative values index the
          existing figures python-list style (-1 is the highest-numbered).
        - A figure object: use it as is.
        - `None`: use the highest-numbered figure, or create figure 1.
    clear : `bool`
        Clear the figure before returning it.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure, which is also made pyplot's current figure.

    Raises
    ------
    RuntimeError
        If matplotlib could not be imported, or if ``fig == 0``.
    """
    if not pyplot:
        raise RuntimeError("I am unable to plot as I failed to import matplotlib")

    if isinstance(fig, bool):  # we want a new one (test before int: bool is an int)
        fig = max(mpFigures, default=0) + 1  # matplotlib is 1-indexed

    if isinstance(fig, int):
        fig = _getNumberedMpFigure(fig)

    if not fig:
        latest = max(mpFigures, default=0)
        fig = mpFigures[latest] if latest > 0 else getMpFigure(1, clear=False)

    if clear:
        fig.clf()

    pyplot.figure(fig.number)  # make it active

    return fig

547 

548 

def fixCcd(butler, visit, ccd, coeffs, display=True):
    """Apply cross-talk correction to one CCD read through ``butler``.

    When ``display`` is true, the image is shown before correction in
    display frame 1 and the corrected masked image in frame 2.
    """
    mi = readImage(butler, visit=visit, ccd=ccd)
    if display:
        before = afwDisplay.getDisplay(frame=1)
        before.mtv(mi.getImage(), title="CCD %d" % ccd)

    subtractXTalk(mi, coeffs)

    if not display:
        return
    after = afwDisplay.getDisplay(frame=2)
    after.mtv(mi, title="corrected %d" % ccd)

560 

561 

if __name__ == "__main__":
    # main() requires a data butler, which the command line cannot supply.
    # Run the estimation on a simulated image instead, since that path
    # needs no butler.
    coeffs, coeffsErr = estimateCoeffs(None, [0], ["simulated"])
    printCoeffs(coeffs, coeffsErr)