Coverage for python/lsst/ip/isr/deferredCharge.py: 16%

278 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-23 11:36 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names exported by this module via ``from ... import *``.
__all__ = ('DeferredChargeConfig', 'DeferredChargeTask', 'SerialTrap', 'DeferredChargeCalib')

23 

24import numpy as np 

25from astropy.table import Table 

26 

27from lsst.afw.cameraGeom import ReadoutCorner 

28from lsst.pex.config import Config, Field 

29from lsst.pipe.base import Task 

30from .isrFunctions import gainContext 

31from .calibType import IsrCalib 

32 

33import scipy.interpolate as interp 

34 

35 

class SerialTrap():
    """Represents a serial register trap.

    Parameters
    ----------
    size : `float`
        Size of the charge trap, in electrons.
    emission_time : `float`
        Trap emission time constant, in inverse transfers.
    pixel : `int`
        Serial pixel location of the trap, including the prescan.
    trap_type : `str`
        Type of trap capture to use.  Should be one of ``linear``,
        ``logistic``, or ``spline``.
    coeffs : `list` [`float`]
        Coefficients for the capture process.  Linear traps need one
        coefficient, logistic traps need two, and spline based traps
        need to have an even number of coefficients that can be split
        into their spline locations and values.

    Raises
    ------
    ValueError
        Raised if the specified parameters are out of expected range,
        the trap type is unknown, or a spline trap has no usable nodes.
    """

    def __init__(self, size, emission_time, pixel, trap_type, coeffs):
        if size < 0.0:
            raise ValueError('Trap size must be greater than or equal to 0.')
        self.size = size

        # NaN compares False against everything, so it slips past the
        # positivity test and needs its own check.
        if emission_time <= 0.0:
            raise ValueError('Emission time must be greater than 0.')
        if np.isnan(emission_time):
            raise ValueError('Emission time must be real-valued, not NaN')
        self.emission_time = emission_time

        if int(pixel) != pixel:
            raise ValueError('Fraction value for pixel not allowed.')
        self.pixel = int(pixel)

        self.trap_type = trap_type
        self.coeffs = coeffs

        if self.trap_type not in ('linear', 'logistic', 'spline'):
            raise ValueError(f'Unknown trap type: {self.trap_type}')

        if self.trap_type == 'spline':
            # Note that ``spline`` is actually a piecewise linear
            # interpolation in the model and the application, and not a
            # true spline.  The first half of ``coeffs`` holds the node
            # locations, the second half the values at those nodes.
            centers, values = np.split(np.array(self.coeffs, dtype=np.float64), 2)
            # Drop any node whose location or value is NaN (e.g. padding
            # added when the calibration was persisted to a table).
            good = ~(np.isnan(centers) | np.isnan(values))
            centers = centers[good]
            values = values[good]
            if centers.size == 0:
                raise ValueError('Spline trap requires at least one node with '
                                 'non-NaN location and value.')
            # Sort by location so the out-of-range fill values below are
            # the values at the lowest and highest nodes.
            order = np.argsort(centers, kind='stable')
            centers = centers[order]
            values = values[order]
            self.interp = interp.interp1d(
                centers,
                values,
                bounds_error=False,
                fill_value=(values[0], values[-1]),
            )

        # Working arrays, only populated by ``initialize``.
        self._trap_array = None
        self._trapped_charge = None

    def __eq__(self, other):
        # A trap is equal to another trap if all of the initialization
        # parameters are equal.  All other properties are only filled
        # during use, and are not persisted into the calibration.
        if not isinstance(other, SerialTrap):
            return NotImplemented
        # ``list`` makes the coefficient comparison element-wise for both
        # lists and numpy arrays (``!=`` on arrays is ambiguous in ``if``).
        return bool(self.size == other.size
                    and self.emission_time == other.emission_time
                    and self.pixel == other.pixel
                    and self.trap_type == other.trap_type
                    and list(self.coeffs) == list(other.coeffs))

    @property
    def trap_array(self):
        """Per-pixel trap capacity array, or `None` before ``initialize``."""
        return self._trap_array

    @property
    def trapped_charge(self):
        """Currently trapped charge array, or `None` before ``initialize``."""
        return self._trapped_charge

    def initialize(self, ny, nx, prescan_width):
        """Initialize trapping arrays for simulated readout.

        Parameters
        ----------
        ny : `int`
            Number of rows to simulate.
        nx : `int`
            Number of columns to simulate.
        prescan_width : `int`
            Additional transfers due to prescan.

        Raises
        ------
        ValueError
            Raised if the trap falls outside of the image.
        """
        width = nx + prescan_width
        # ``pixel`` is used as a 0-based column index below, so it must be
        # strictly less than the width (and not negative, which would
        # silently wrap around to the other end of the row).
        if self.pixel >= width:
            raise ValueError('Trap location {0} must be less than {1}'.format(self.pixel, width))
        if self.pixel < 0:
            raise ValueError(f'Trap location {self.pixel} must be non-negative')

        self._trap_array = np.zeros((ny, width))
        self._trap_array[:, self.pixel] = self.size
        self._trapped_charge = np.zeros((ny, width))

    def release_charge(self):
        """Release charge through exponential decay.

        Returns
        -------
        released_charge : `float`
            Charge released.
        """
        released_charge = self._trapped_charge*(1-np.exp(-1./self.emission_time))
        self._trapped_charge -= released_charge

        return released_charge

    def trap_charge(self, free_charge):
        """Perform charge capture using this trap's capture function.

        The capture model (linear, logistic or spline) is selected by
        ``trap_type``; the result is limited to lie between the charge
        already trapped and the trap capacity set by ``initialize``.

        Parameters
        ----------
        free_charge : `float`
            Charge available to be trapped.

        Returns
        -------
        captured_charge : `float`
            Amount of charge actually trapped.
        """
        captured_charge = (np.clip(self.capture(free_charge), self.trapped_charge, self._trap_array)
                           - self.trapped_charge)
        self._trapped_charge += captured_charge

        return captured_charge

    def capture(self, pixel_signals):
        """Trap capture function.

        Parameters
        ----------
        pixel_signals : `list` [`float`] or `numpy.ndarray`
            Input pixel values.

        Returns
        -------
        captured_charge : `numpy.ndarray`
            Amount of charge captured from each pixel.

        Raises
        ------
        RuntimeError
            Raised if the trap type is invalid.
        """
        # Accept plain sequences as documented; arithmetic below needs an
        # array (``list * float`` is a TypeError).
        signals = np.asarray(pixel_signals)
        if self.trap_type == 'linear':
            scaling = self.coeffs[0]
            return np.minimum(self.size, signals*scaling)
        elif self.trap_type == 'logistic':
            f0, k = (self.coeffs[0], self.coeffs[1])
            return self.size/(1.+np.exp(-k*(signals-f0)))
        elif self.trap_type == 'spline':
            return self.interp(signals)
        else:
            raise RuntimeError(f"Invalid trap capture type: {self.trap_type}.")

211 

212 

class DeferredChargeCalib(IsrCalib):
    r"""Calibration containing deferred charge/CTI parameters.

    Parameters
    ----------
    useGains : `bool`, optional
        If `True` the parameters are in electrons, otherwise in ADU.
        Recorded in the metadata as ``USEGAINS`` and ``UNITS``.
    **kwargs :
        Additional parameters to pass to parent constructor.

    Notes
    -----
    The charge transfer inefficiency attributes stored are:

    driftScale : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the local electronic
        offset drift scale parameter, A_L in Snyder+2021.
    decayTime : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the local electronic
        offset decay time, \tau_L in Snyder+2021.
    globalCti : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the mean global CTI
        parameter, b in Snyder+2021.
    serialTraps : `dict` [`str`, `lsst.ip.isr.SerialTrap`]
        A dictionary, keyed by amplifier name, containing a single
        serial trap for each amplifier.
    """
    _OBSTYPE = 'CTI'
    _SCHEMA = 'Deferred Charge'
    _VERSION = 1.0

    def __init__(self, useGains=True, **kwargs):
        self.driftScale = {}
        self.decayTime = {}
        self.globalCti = {}
        self.serialTraps = {}

        super().__init__(**kwargs)

        units = 'electrons' if useGains else 'ADU'
        self.updateMetadata(USEGAINS=useGains, UNITS=units)

        self.requiredAttributes.update(['driftScale', 'decayTime', 'globalCti', 'serialTraps'])

    def fromDetector(self, detector):
        """Read metadata parameters from a detector.

        Parameters
        ----------
        detector : `lsst.afw.cameraGeom.detector`
            Input detector with parameters to use.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            The calibration constructed from the detector.
        """
        # Nothing is currently read from the detector; return the
        # calibration itself so callers receive the documented value
        # instead of `None`.
        return self

    @classmethod
    def fromDict(cls, dictionary):
        """Construct a calibration from a dictionary of properties.

        Parameters
        ----------
        dictionary : `dict`
            Dictionary of properties.

        Returns
        -------
        calib : `lsst.ip.isr.CalibType`
            Constructed calibration.

        Raises
        ------
        RuntimeError
            Raised if the supplied dictionary is for a different
            calibration.
        """
        calib = cls()

        if calib._OBSTYPE != dictionary['metadata']['OBSTYPE']:
            raise RuntimeError(f"Incorrect CTI supplied. Expected {calib._OBSTYPE}, "
                               f"found {dictionary['metadata']['OBSTYPE']}")

        calib.setMetadata(dictionary['metadata'])

        calib.driftScale = dictionary['driftScale']
        calib.decayTime = dictionary['decayTime']
        calib.globalCti = dictionary['globalCti']

        for ampName, ampTraps in dictionary['serialTraps'].items():
            calib.serialTraps[ampName] = SerialTrap(ampTraps['size'], ampTraps['emissionTime'],
                                                    ampTraps['pixel'], ampTraps['trap_type'],
                                                    ampTraps['coeffs'])
        calib.updateMetadata()
        return calib

    def toDict(self):
        """Return a dictionary containing the calibration properties.

        The dictionary should be able to be round-tripped through
        ``fromDict``.

        Returns
        -------
        dictionary : `dict`
            Dictionary of properties.
        """
        self.updateMetadata()
        outDict = {}
        outDict['metadata'] = self.getMetadata()

        outDict['driftScale'] = self.driftScale
        outDict['decayTime'] = self.decayTime
        outDict['globalCti'] = self.globalCti

        # Key names here must match those read back in ``fromDict``.
        outDict['serialTraps'] = {
            ampName: {'size': trap.size,
                      'emissionTime': trap.emission_time,
                      'pixel': trap.pixel,
                      'trap_type': trap.trap_type,
                      'coeffs': trap.coeffs}
            for ampName, trap in self.serialTraps.items()
        }

        return outDict

    @staticmethod
    def _stripTrailingNans(inCoeffs):
        """Remove the trailing NaN padding added by ``toTable``.

        Interior NaN values are kept; only the contiguous run of NaNs at
        the end of the row is removed.

        Parameters
        ----------
        inCoeffs : `numpy.ndarray`
            One row of the ``COEFFS`` table column.

        Returns
        -------
        coeffs : `list` [`float`]
            The unpadded coefficients.
        """
        notNan = np.flatnonzero(~np.isnan(inCoeffs))
        end = notNan[-1] + 1 if notNan.size else 0
        return inCoeffs[:end].tolist()

    @classmethod
    def fromTable(cls, tableList):
        """Construct calibration from a list of tables.

        This method uses the ``fromDict`` method to create the
        calibration, after constructing an appropriate dictionary from
        the input tables.

        Parameters
        ----------
        tableList : `list` [`astropy.table.Table`]
            List of tables to use to construct the deferred charge
            calibration.  Two tables are expected in this list, the
            first containing the per-amplifier CTI parameters, and the
            second containing the parameters for serial traps.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            The calibration defined in the tables.

        Raises
        ------
        ValueError
            Raised if the trap type or trap coefficients are not
            defined properly.
        """
        ampTable = tableList[0]

        inDict = {}
        inDict['metadata'] = ampTable.meta

        amps = ampTable['AMPLIFIER']
        inDict['driftScale'] = dict(zip(amps, ampTable['DRIFT_SCALE']))
        inDict['decayTime'] = dict(zip(amps, ampTable['DECAY_TIME']))
        inDict['globalCti'] = dict(zip(amps, ampTable['GLOBAL_CTI']))

        inDict['serialTraps'] = {}
        trapTable = tableList[1]

        amps = trapTable['AMPLIFIER']
        sizes = trapTable['SIZE']
        emissionTimes = trapTable['EMISSION_TIME']
        pixels = trapTable['PIXEL']
        trapTypes = trapTable['TYPE']
        coeffs = trapTable['COEFFS']

        for index, amp in enumerate(amps):
            ampTrap = {}
            ampTrap['size'] = sizes[index]
            ampTrap['emissionTime'] = emissionTimes[index]
            ampTrap['pixel'] = pixels[index]
            ampTrap['trap_type'] = trapTypes[index]
            # Rows are NaN-padded to a common length by ``toTable``.
            ampTrap['coeffs'] = cls._stripTrailingNans(coeffs[index])

            trapType = ampTrap['trap_type']
            nCoeffs = len(ampTrap['coeffs'])
            if trapType == 'linear':
                valid = nCoeffs >= 1
            elif trapType == 'logistic':
                valid = nCoeffs >= 2
            elif trapType == 'spline':
                # Spline coefficients are (locations, values) halves.
                valid = nCoeffs % 2 == 0
            else:
                raise ValueError(f'Unknown trap type: {trapType}')
            if not valid:
                raise ValueError(f"CTI Amplifier {amp} coefficients for trap has illegal "
                                 f"length {nCoeffs}.")

            inDict['serialTraps'][amp] = ampTrap

        return cls.fromDict(inDict)

    def toTable(self):
        """Construct a list of tables containing the information in this
        calibration.

        The list of tables should create an identical calibration
        after being passed to this class's fromTable method.

        Returns
        -------
        tableList : `list` [`astropy.table.Table`]
            List of tables containing the deferred charge calibration
            information.  Two tables are generated for this list, the
            first containing the per-amplifier CTI parameters, and the
            second containing the parameters for serial traps.
        """
        tableList = []
        self.updateMetadata()

        # ``driftScale`` defines the amplifier order; the other
        # per-amplifier dictionaries must contain the same keys.
        ampList = list(self.driftScale.keys())
        ampTable = Table({'AMPLIFIER': ampList,
                          'DRIFT_SCALE': [self.driftScale[amp] for amp in ampList],
                          'DECAY_TIME': [self.decayTime[amp] for amp in ampList],
                          'GLOBAL_CTI': [self.globalCti[amp] for amp in ampList],
                          })

        ampTable.meta = self.getMetadata().toDict()
        tableList.append(ampTable)

        ampList = []
        sizeList = []
        timeList = []
        pixelList = []
        typeList = []
        coeffList = []

        # The table column needs rows of equal length.
        maxCoeffLength = max((len(trap.coeffs) for trap in self.serialTraps.values()), default=0)

        # Pack and pad the end of the coefficients with NaN values.
        for amp, trap in self.serialTraps.items():
            ampList.append(amp)
            sizeList.append(trap.size)
            timeList.append(trap.emission_time)
            pixelList.append(trap.pixel)
            typeList.append(trap.trap_type)

            # Convert to float first: NaN padding cannot be stored in an
            # integer array.
            coeffs = np.asarray(trap.coeffs, dtype=np.float64)
            coeffs = np.pad(coeffs, (0, maxCoeffLength - len(coeffs)),
                            constant_values=np.nan)
            coeffList.append(coeffs.tolist())

        trapTable = Table({'AMPLIFIER': ampList,
                           'SIZE': sizeList,
                           'EMISSION_TIME': timeList,
                           'PIXEL': pixelList,
                           'TYPE': typeList,
                           'COEFFS': coeffList})

        tableList.append(trapTable)

        return tableList

510 

511 

class DeferredChargeConfig(Config):
    """Configuration for the deferred charge correction task.

    The two ``nPixel*`` fields set how many preceding serial pixels
    contribute to each correction term; the boolean fields control gain
    handling and optional zeroing of unused pixels before correcting.
    """

    # Look-back window for the local electronic offset correction.
    nPixelOffsetCorrection = Field(
        doc="Number of prior pixels to use for local offset correction.",
        dtype=int,
        default=15,
    )

    # Look-back window for the serial trap correction.
    nPixelTrapCorrection = Field(
        doc="Number of prior pixels to use for trap correction.",
        dtype=int,
        default=6,
    )

    # Only consulted when the calibration metadata has no USEGAINS entry.
    useGains = Field(
        doc="If true, scale by the gain.",
        dtype=bool,
        default=False,
    )

    # Overscan subtraction is not applied by the task, so these regions
    # can optionally be cleared first.
    zeroUnusedPixels = Field(
        doc="If true, set serial prescan and parallel overscan to zero before correction.",
        dtype=bool,
        default=False,
    )

535 

536 

class DeferredChargeTask(Task):
    """Task to correct an exposure for charge transfer inefficiency.

    This uses the methods described by Snyder et al. 2021, Journal of
    Astronomical Telescopes, Instruments, and Systems, 7,
    048002. doi:10.1117/1.JATIS.7.4.048002 (Snyder+21).
    """
    ConfigClass = DeferredChargeConfig
    _DefaultName = 'isrDeferredCharge'

    def run(self, exposure, ctiCalib, gains=None):
        """Correct deferred charge/CTI issues.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to correct the deferred charge on.  The image
            pixels are overwritten with the corrected values.
        ctiCalib : `lsst.ip.isr.DeferredChargeCalib`
            Calibration object containing the charge transfer
            inefficiency model.
        gains : `dict` [`str`, `float`], optional
            A dictionary, keyed by amplifier name, of the gains to
            use.  If gains is None and gains are to be used, the
            nominal gains in the amplifier object are used.

        Returns
        -------
        exposure : `lsst.afw.image.Exposure`
            The corrected exposure.
        """
        image = exposure.getMaskedImage().image
        detector = exposure.getDetector()

        # Whether to use gains comes from the calibration's USEGAINS
        # metadata when present; otherwise fall back to the task config.
        metadata = ctiCalib.getMetadata()
        if "USEGAINS" in metadata.keys():
            useGains = metadata["USEGAINS"]
            self.log.info("useGains = %s from calibration metadata.", useGains)
        else:
            useGains = self.config.useGains
            self.log.info("USEGAINS not found in calibration metadata. Using %s from config.",
                          useGains)

        # Externally supplied gains take precedence; otherwise use the
        # nominal amplifier gains.  When useGains is false, ``gains`` is
        # passed through unchanged and the ``useGains`` flag tells
        # gainContext whether to apply it (see isrFunctions.gainContext).
        if useGains and gains is None:
            gains = {amp.getName(): amp.getGain() for amp in detector.getAmplifiers()}

        with gainContext(exposure, image, useGains, gains):
            for amp in detector.getAmplifiers():
                ampName = amp.getName()

                ampImage = image[amp.getRawBBox()]
                if self.config.zeroUnusedPixels:
                    # We don't apply overscan subtraction, so zero these
                    # out for now.
                    ampImage[amp.getRawParallelOverscanBBox()].array[:, :] = 0.0
                    ampImage[amp.getRawSerialPrescanBBox()].array[:, :] = 0.0

                # The algorithm expects that the readout corner is in
                # the lower left corner.  Flip it to be so:
                ampData = self.flipData(ampImage.array, amp)

                if ctiCalib.driftScale[ampName] > 0.0:
                    correctedAmpData = self.local_offset_inverse(ampData,
                                                                 ctiCalib.driftScale[ampName],
                                                                 ctiCalib.decayTime[ampName],
                                                                 self.config.nPixelOffsetCorrection)
                else:
                    correctedAmpData = ampData.copy()

                correctedAmpData = self.local_trap_inverse(correctedAmpData,
                                                           ctiCalib.serialTraps[ampName],
                                                           ctiCalib.globalCti[ampName],
                                                           self.config.nPixelTrapCorrection)

                # Undo flips here.  The method is symmetric.
                correctedAmpData = self.flipData(correctedAmpData, amp)
                image[amp.getRawBBox()].array[:, :] = correctedAmpData[:, :]

        return exposure

    @staticmethod
    def flipData(ampData, amp):
        """Flip data array such that readout corner is at lower-left.

        Flipping is its own inverse, so calling this twice restores the
        original orientation.

        Parameters
        ----------
        ampData : `numpy.ndarray`, (ny, nx)
            Image data to flip.
        amp : `lsst.afw.cameraGeom.Amplifier`
            Amplifier to get readout corner information.

        Returns
        -------
        ampData : `numpy.ndarray`, (ny, nx)
            Flipped image data (may be a view of the input).
        """
        X_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: True,
                  ReadoutCorner.UL: False,
                  ReadoutCorner.UR: True}
        Y_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: False,
                  ReadoutCorner.UL: True,
                  ReadoutCorner.UR: True}

        corner = amp.getReadoutCorner()
        if X_FLIP[corner]:
            ampData = np.fliplr(ampData)
        if Y_FLIP[corner]:
            ampData = np.flipud(ampData)

        return ampData

    @staticmethod
    def local_offset_inverse(inputArr, drift_scale, decay_time, num_previous_pixels=15):
        r"""Remove CTI effects from local offsets.

        This implements equation 10 of Snyder+21.  For an image with
        CTI, s'(m, n), the correction factor is equal to the maximum
        value of the set of:

        .. code-block::

            {A_L s'(m, n - j) exp(-j t / \tau_L)}_j=0^jmax

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (ny, nx)
            Input image data to correct.
        drift_scale : `float`
            Drift scale (Snyder+21 A_L value) to use in correction.
        decay_time : `float`
            Decay time (Snyder+21 \tau_L) of the correction.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction (the j = 0
            term is always included).  As the CTI has an exponential
            decay, this essentially truncates the correction where that
            decay scales the input charge to near zero.

        Returns
        -------
        outputArr : `numpy.ndarray`, (ny, nx)
            Corrected image data.
        """
        r = np.exp(-1/decay_time)
        Ny, Nx = inputArr.shape

        # j = 0 term, accumulated in float64 regardless of input dtype.
        Linv = np.zeros((Ny, Nx))
        Linv[:, :] = drift_scale*np.maximum(0, inputArr)

        # j = 1..jmax terms, folded into a running maximum rather than
        # stacking all of them (which needs num_previous_pixels copies of
        # the image).  Columns with fewer than j predecessors contribute
        # zero, as in the stacked form.
        for n in range(1, num_previous_pixels):
            term = np.zeros((Ny, Nx))
            term[:, n:] = drift_scale*np.maximum(0, inputArr[:, :-n])*(r**n)
            np.maximum(Linv, term, out=Linv)

        outputArr = inputArr - Linv

        return outputArr

    @staticmethod
    def local_trap_inverse(inputArr, trap, global_cti=0.0, num_previous_pixels=6):
        r"""Apply localized trapping inverse operator to pixel signals.

        This implements equation 13 of Snyder+21.  With c(s) the trap
        capture function, r = exp(-1 / emission_time) and a = 1 - b:

        - the trap occupancy at each pixel is estimated as the maximum,
          over n = 0 .. num_previous_pixels - 1, of c(s) from n + 1
          pixels earlier scaled by r**n;
        - the captured charge is C = max(c(s) - occupancy * r, 0);
        - the released charge is R = occupancy * (1 - r) (zero in the
          first column);
        - the corrected signal is s - a * (R - C).

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (ny, nx)
            Input image data to correct.
        trap : `lsst.ip.isr.SerialTrap`
            Serial trap describing the capture and release of charge.
        global_cti: `float`
            Mean charge transfer inefficiency, b from Snyder+21.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction.

        Returns
        -------
        outputArr : `numpy.ndarray`, (ny, nx)
            Corrected image data.
        """
        Ny, Nx = inputArr.shape
        a = 1 - global_cti
        r = np.exp(-1/trap.emission_time)

        # Capture depends only on the (non-negative) input signal, so it
        # is evaluated once instead of once per lag.
        captured = trap.capture(np.maximum(0, inputArr))

        # Estimate trap occupancies during readout as a running maximum
        # over the lagged, decayed capture terms.
        trap_occupancy = np.zeros((Ny, Nx))
        for n in range(num_previous_pixels):
            lagged = np.zeros((Ny, Nx))
            lagged[:, n+1:] = captured[:, :-(n+1)]*(r**n)
            if n == 0:
                trap_occupancy = lagged
            else:
                np.maximum(trap_occupancy, lagged, out=trap_occupancy)

        # Estimate captured charge
        C = captured - trap_occupancy*r
        C[C < 0] = 0.

        # Estimate released charge
        R = np.zeros(inputArr.shape)
        R[:, 1:] = trap_occupancy[:, 1:]*(1-r)
        T = R - C

        outputArr = inputArr - a*T

        return outputArr