Coverage for python/lsst/ip/isr/deferredCharge.py: 16%

270 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-30 12:54 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = (
    'DeferredChargeConfig',
    'DeferredChargeTask',
    'SerialTrap',
    'DeferredChargeCalib',
)

23 

24import numpy as np 

25from astropy.table import Table 

26 

27from lsst.afw.cameraGeom import ReadoutCorner 

28from lsst.pex.config import Config, Field 

29from lsst.pipe.base import Task 

30from .isrFunctions import gainContext 

31from .calibType import IsrCalib 

32 

33import scipy.interpolate as interp 

34 

35 

class SerialTrap():
    """Represents a serial register trap.

    Parameters
    ----------
    size : `float`
        Size of the charge trap, in electrons.
    emission_time : `float`
        Trap emission time constant, in inverse transfers.
    pixel : `int`
        Serial pixel location of the trap, including the prescan.
    trap_type : `str`
        Type of trap capture to use.  Should be one of ``linear``,
        ``logistic``, or ``spline``.
    coeffs : `list` [`float`]
        Coefficients for the capture process.  Linear traps need one
        coefficient, logistic traps need two, and spline based traps
        need to have an even number of coefficients that can be split
        into their spline locations and values.

    Raises
    ------
    ValueError
        Raised if the specified parameters are out of expected range.
    """

    def __init__(self, size, emission_time, pixel, trap_type, coeffs):
        if size < 0.0:
            raise ValueError('Trap size must be greater than or equal to 0.')
        self.size = size

        if emission_time <= 0.0:
            raise ValueError('Emission time must be greater than 0.')
        # NaN compares False against 0.0 above, so it needs its own check.
        if np.isnan(emission_time):
            raise ValueError('Emission time must be real-valued, not NaN')
        self.emission_time = emission_time

        if int(pixel) != pixel:
            raise ValueError('Fraction value for pixel not allowed.')
        self.pixel = int(pixel)

        self.trap_type = trap_type
        self.coeffs = coeffs

        if self.trap_type not in ('linear', 'logistic', 'spline'):
            # Exceptions do not apply logging-style ``%s`` substitution,
            # so the message must be formatted explicitly.
            raise ValueError(f'Unknown trap type: {self.trap_type}')

        if self.trap_type == 'spline':
            # Note that ``spline`` is actually a piecewise linear interpolation
            # in the model and the application, and not a true spline.
            # The first half of the coefficients are the node locations,
            # the second half the node values.
            centers, values = np.split(np.array(self.coeffs, dtype=np.float64), 2)
            # Drop every (center, value) pair in which either entry is NaN.
            good = ~(np.isnan(centers) | np.isnan(values))
            centers = centers[good]
            values = values[good]
            # Outside the node range, clamp to the first/last node value.
            self.interp = interp.interp1d(
                centers,
                values,
                bounds_error=False,
                fill_value=(values[0], values[-1]),
            )

        # Simulation state; only populated by ``initialize``.
        self._trap_array = None
        self._trapped_charge = None

    def __eq__(self, other):
        """Return True if both traps share all initialization parameters.

        All other properties are only filled during use, and are not
        persisted into the calibration.  Comparisons against objects that
        are not `SerialTrap` instances defer to Python's default handling.
        """
        if not isinstance(other, SerialTrap):
            return NotImplemented
        if self.size != other.size:
            return False
        if self.emission_time != other.emission_time:
            return False
        if self.pixel != other.pixel:
            return False
        if self.trap_type != other.trap_type:
            return False
        # Compare as lists so that numpy-array coefficients do not produce
        # an ambiguous element-wise truth value.
        return list(self.coeffs) == list(other.coeffs)

    @property
    def trap_array(self):
        """Per-pixel trap capacity array (`numpy.ndarray` or `None`)."""
        return self._trap_array

    @property
    def trapped_charge(self):
        """Per-pixel currently trapped charge (`numpy.ndarray` or `None`)."""
        return self._trapped_charge

    def initialize(self, ny, nx, prescan_width):
        """Initialize trapping arrays for simulated readout.

        Parameters
        ----------
        ny : `int`
            Number of rows to simulate.
        nx : `int`
            Number of columns to simulate.
        prescan_width : `int`
            Additional transfers due to prescan.

        Raises
        ------
        ValueError
            Raised if the trap falls outside of the image.
        """
        width = nx + prescan_width
        # Valid column indices are 0 .. width-1, so a trap located at
        # ``pixel == width`` is already outside the simulated array.
        if self.pixel >= width:
            raise ValueError('Trap location {0} must be less than {1}'.format(self.pixel, width))

        self._trap_array = np.zeros((ny, width))
        self._trap_array[:, self.pixel] = self.size
        self._trapped_charge = np.zeros((ny, width))

    def release_charge(self):
        """Release charge through exponential decay.

        Returns
        -------
        released_charge : `float` or `numpy.ndarray`
            Charge released; the same amount is removed from the trapped
            charge.
        """
        released_charge = self._trapped_charge*(1-np.exp(-1./self.emission_time))
        self._trapped_charge -= released_charge

        return released_charge

    def trap_charge(self, free_charge):
        """Perform charge capture using this trap's capture function.

        The amount captured is limited so the trapped charge never
        decreases and never exceeds the trap capacity.

        Parameters
        ----------
        free_charge : `float` or `numpy.ndarray`
            Charge available to be trapped.

        Returns
        -------
        captured_charge : `float` or `numpy.ndarray`
            Amount of charge actually trapped.
        """
        captured_charge = (np.clip(self.capture(free_charge), self.trapped_charge, self._trap_array)
                           - self.trapped_charge)
        self._trapped_charge += captured_charge

        return captured_charge

    def capture(self, pixel_signals):
        """Trap capture function.

        Parameters
        ----------
        pixel_signals : `numpy.ndarray` [`float`]
            Input pixel values.

        Returns
        -------
        captured_charge : `numpy.ndarray` [`float`]
            Amount of charge captured from each pixel.

        Raises
        ------
        RuntimeError
            Raised if the trap type is invalid.
        """
        if self.trap_type == 'linear':
            scaling = self.coeffs[0]
            return np.minimum(self.size, pixel_signals*scaling)
        elif self.trap_type == 'logistic':
            f0, k = (self.coeffs[0], self.coeffs[1])
            return self.size/(1.+np.exp(-k*(pixel_signals-f0)))
        elif self.trap_type == 'spline':
            return self.interp(pixel_signals)
        else:
            raise RuntimeError(f"Invalid trap capture type: {self.trap_type}.")

211 

212 

class DeferredChargeCalib(IsrCalib):
    r"""Calibration containing deferred charge/CTI parameters.

    Parameters
    ----------
    **kwargs :
        Additional parameters to pass to parent constructor.

    Notes
    -----
    The charge transfer inefficiency attributes stored are:

    driftScale : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the local electronic
        offset drift scale parameter, A_L in Snyder+2021.
    decayTime : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the local electronic
        offset decay time, \tau_L in Snyder+2021.
    globalCti : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the mean global CTI
        parameter, b in Snyder+2021.
    serialTraps : `dict` [`str`, `lsst.ip.isr.SerialTrap`]
        A dictionary, keyed by amplifier name, containing a single
        serial trap for each amplifier.
    """
    _OBSTYPE = 'CTI'
    _SCHEMA = 'Deferred Charge'
    _VERSION = 1.0

    def __init__(self, **kwargs):
        self.driftScale = {}
        self.decayTime = {}
        self.globalCti = {}
        self.serialTraps = {}

        super().__init__(**kwargs)
        self.requiredAttributes.update(['driftScale', 'decayTime', 'globalCti', 'serialTraps'])

    def fromDetector(self, detector):
        """Read metadata parameters from a detector.

        No parameters of this calibration are currently derived from the
        detector, so the calibration is left unchanged.

        Parameters
        ----------
        detector : `lsst.afw.cameraGeom.detector`
            Input detector with parameters to use.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            This calibration, unmodified.
        """
        return self

    @classmethod
    def fromDict(cls, dictionary):
        """Construct a calibration from a dictionary of properties.

        Parameters
        ----------
        dictionary : `dict`
            Dictionary of properties.

        Returns
        -------
        calib : `lsst.ip.isr.CalibType`
            Constructed calibration.

        Raises
        ------
        RuntimeError
            Raised if the supplied dictionary is for a different
            calibration.
        """
        calib = cls()

        if calib._OBSTYPE != dictionary['metadata']['OBSTYPE']:
            raise RuntimeError(f"Incorrect CTI supplied. Expected {calib._OBSTYPE}, "
                               f"found {dictionary['metadata']['OBSTYPE']}")

        calib.setMetadata(dictionary['metadata'])

        calib.driftScale = dictionary['driftScale']
        calib.decayTime = dictionary['decayTime']
        calib.globalCti = dictionary['globalCti']

        for ampName, ampTraps in dictionary['serialTraps'].items():
            calib.serialTraps[ampName] = SerialTrap(ampTraps['size'], ampTraps['emissionTime'],
                                                    ampTraps['pixel'], ampTraps['trap_type'],
                                                    ampTraps['coeffs'])
        calib.updateMetadata()
        return calib

    def toDict(self):
        """Return a dictionary containing the calibration properties.

        The dictionary should be able to be round-tripped through
        ``fromDict``.

        Returns
        -------
        dictionary : `dict`
            Dictionary of properties.
        """
        self.updateMetadata()
        outDict = {}
        outDict['metadata'] = self.getMetadata()

        outDict['driftScale'] = self.driftScale
        outDict['decayTime'] = self.decayTime
        outDict['globalCti'] = self.globalCti

        outDict['serialTraps'] = {}
        for ampName, trap in self.serialTraps.items():
            outDict['serialTraps'][ampName] = {'size': trap.size,
                                               'emissionTime': trap.emission_time,
                                               'pixel': trap.pixel,
                                               'trap_type': trap.trap_type,
                                               'coeffs': trap.coeffs}

        return outDict

    @classmethod
    def fromTable(cls, tableList):
        """Construct calibration from a list of tables.

        This method uses the ``fromDict`` method to create the
        calibration, after constructing an appropriate dictionary from
        the input tables.

        Parameters
        ----------
        tableList : `list` [`astropy.table.Table`]
            List of tables to use to construct the deferred charge
            calibration.  Two tables are expected in this list, the
            first containing the per-amplifier CTI parameters, and the
            second containing the parameters for serial traps.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            The calibration defined in the tables.

        Raises
        ------
        ValueError
            Raised if the trap type or trap coefficients are not
            defined properly.
        """
        ampTable = tableList[0]

        inDict = {}
        inDict['metadata'] = ampTable.meta

        amps = ampTable['AMPLIFIER']
        inDict['driftScale'] = dict(zip(amps, ampTable['DRIFT_SCALE']))
        inDict['decayTime'] = dict(zip(amps, ampTable['DECAY_TIME']))
        inDict['globalCti'] = dict(zip(amps, ampTable['GLOBAL_CTI']))

        inDict['serialTraps'] = {}
        trapTable = tableList[1]

        amps = trapTable['AMPLIFIER']
        sizes = trapTable['SIZE']
        emissionTimes = trapTable['EMISSION_TIME']
        pixels = trapTable['PIXEL']
        trapTypes = trapTable['TYPE']
        coeffs = trapTable['COEFFS']

        for index, amp in enumerate(amps):
            ampTrap = {}
            ampTrap['size'] = sizes[index]
            ampTrap['emissionTime'] = emissionTimes[index]
            ampTrap['pixel'] = pixels[index]
            ampTrap['trap_type'] = trapTypes[index]

            # ``toTable`` pads coefficient rows with trailing NaNs to a
            # common length; strip that padding again.  The first entry
            # is never stripped, matching the historical behavior.
            # NOTE(review): trailing NaNs that were genuinely part of the
            # coefficients cannot be distinguished from padding here.
            inCoeffs = coeffs[index]
            isNan = np.isnan(inCoeffs)
            nKeep = len(inCoeffs)
            while nKeep > 1 and isNan[nKeep - 1]:
                nKeep -= 1
            ampTrap['coeffs'] = inCoeffs[:nKeep].tolist()

            trapType = ampTrap['trap_type']
            nCoeffs = len(ampTrap['coeffs'])
            if trapType == 'linear':
                badLength = nCoeffs < 1
            elif trapType == 'logistic':
                badLength = nCoeffs < 2
            elif trapType == 'spline':
                badLength = nCoeffs % 2 != 0
            else:
                raise ValueError(f'Unknown trap type: {trapType}')
            if badLength:
                raise ValueError(f"CTI Amplifier {amp} coefficients for trap has illegal length {nCoeffs}.")

            inDict['serialTraps'][amp] = ampTrap

        return cls.fromDict(inDict)

    def toTable(self):
        """Construct a list of tables containing the information in this
        calibration.

        The list of tables should create an identical calibration
        after being passed to this class's fromTable method.

        Returns
        -------
        tableList : `list` [`astropy.table.Table`]
            List of tables containing the deferred charge calibration
            information.  Two tables are generated for this list, the
            first containing the per-amplifier CTI parameters, and the
            second containing the parameters for serial traps.
        """
        tableList = []
        self.updateMetadata()

        # The per-amplifier parameters are keyed by the driftScale amps;
        # decayTime and globalCti must contain the same amplifiers.
        ampList = list(self.driftScale.keys())
        ampTable = Table({'AMPLIFIER': ampList,
                          'DRIFT_SCALE': [self.driftScale[amp] for amp in ampList],
                          'DECAY_TIME': [self.decayTime[amp] for amp in ampList],
                          'GLOBAL_CTI': [self.globalCti[amp] for amp in ampList],
                          })

        ampTable.meta = self.getMetadata().toDict()
        tableList.append(ampTable)

        ampList = []
        sizeList = []
        timeList = []
        pixelList = []
        typeList = []
        coeffList = []

        # Table columns must be rectangular, so every coefficient row is
        # padded to the longest one.
        maxCoeffLength = max((len(trap.coeffs) for trap in self.serialTraps.values()), default=0)

        # Pack and pad the end of the coefficients with NaN values.
        for amp, trap in self.serialTraps.items():
            ampList.append(amp)
            sizeList.append(trap.size)
            timeList.append(trap.emission_time)
            pixelList.append(trap.pixel)
            typeList.append(trap.trap_type)

            coeffs = trap.coeffs
            if len(coeffs) != maxCoeffLength:
                # Cast to float first: NaN padding cannot be stored in an
                # integer array (e.g. linear coefficients given as ints).
                coeffs = np.pad(np.asarray(coeffs, dtype=np.float64),
                                (0, maxCoeffLength - len(coeffs)),
                                constant_values=np.nan).tolist()
            coeffList.append(coeffs)

        trapTable = Table({'AMPLIFIER': ampList,
                           'SIZE': sizeList,
                           'EMISSION_TIME': timeList,
                           'PIXEL': pixelList,
                           'TYPE': typeList,
                           'COEFFS': coeffList})

        tableList.append(trapTable)

        return tableList

506 

507 

class DeferredChargeConfig(Config):
    """Configuration options for the deferred charge (CTI) correction.
    """

    # Window lengths (in pixels) for the two correction stages.
    nPixelOffsetCorrection = Field(
        doc="Number of prior pixels to use for local offset correction.",
        dtype=int,
        default=15,
    )
    nPixelTrapCorrection = Field(
        doc="Number of prior pixels to use for trap correction.",
        dtype=int,
        default=6,
    )

    # Behavioral switches.
    useGains = Field(
        doc="If true, scale by the gain.",
        dtype=bool,
        default=False,
    )
    zeroUnusedPixels = Field(
        doc="If true, set serial prescan and parallel overscan to zero before correction.",
        dtype=bool,
        default=False,
    )

531 

532 

class DeferredChargeTask(Task):
    """Task to correct an exposure for charge transfer inefficiency.

    This uses the methods described by Snyder et al. 2021, Journal of
    Astronomical Telescopes, Instruments, and Systems, 7,
    048002. doi:10.1117/1.JATIS.7.4.048002 (Snyder+21).
    """
    ConfigClass = DeferredChargeConfig
    _DefaultName = 'isrDeferredCharge'

    def run(self, exposure, ctiCalib, gains=None):
        """Correct deferred charge/CTI issues.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to correct the deferred charge on.  The corrected
            pixels are written back into the exposure's image.
        ctiCalib : `lsst.ip.isr.DeferredChargeCalib`
            Calibration object containing the charge transfer
            inefficiency model.
        gains : `dict` [`str`, `float`]
            A dictionary, keyed by amplifier name, of the gains to
            use. If gains is None, the nominal gains in the amplifier
            object are used.

        Returns
        -------
        exposure : `lsst.afw.image.Exposure`
            The corrected exposure.
        """
        image = exposure.getMaskedImage().image
        detector = exposure.getDetector()

        # Externally supplied gains take precedence.  If useGains is true
        # but no gains were supplied, fall back to the nominal gains listed
        # in the detector.  The useGains flag itself is handed to
        # ``gainContext``, which decides whether any gain scaling happens.
        if self.config.useGains and gains is None:
            gains = {amp.getName(): amp.getGain() for amp in detector.getAmplifiers()}

        with gainContext(exposure, image, self.config.useGains, gains):
            for amp in detector.getAmplifiers():
                ampName = amp.getName()

                ampImage = image[amp.getRawBBox()]
                if self.config.zeroUnusedPixels:
                    # We don't apply overscan subtraction, so zero these
                    # out for now.
                    ampImage[amp.getRawParallelOverscanBBox()].array[:, :] = 0.0
                    ampImage[amp.getRawSerialPrescanBBox()].array[:, :] = 0.0

                # The algorithm expects that the readout corner is in
                # the lower left corner.  Flip it to be so:
                ampData = self.flipData(ampImage.array, amp)

                # Local electronic offset correction is skipped when the
                # drift scale is not positive.
                if ctiCalib.driftScale[ampName] > 0.0:
                    correctedAmpData = self.local_offset_inverse(ampData,
                                                                 ctiCalib.driftScale[ampName],
                                                                 ctiCalib.decayTime[ampName],
                                                                 self.config.nPixelOffsetCorrection)
                else:
                    correctedAmpData = ampData.copy()

                correctedAmpData = self.local_trap_inverse(correctedAmpData,
                                                           ctiCalib.serialTraps[ampName],
                                                           ctiCalib.globalCti[ampName],
                                                           self.config.nPixelTrapCorrection)

                # Undo flips here.  The method is symmetric.
                correctedAmpData = self.flipData(correctedAmpData, amp)
                image[amp.getRawBBox()].array[:, :] = correctedAmpData[:, :]

        return exposure

    @staticmethod
    def flipData(ampData, amp):
        """Flip data array such that readout corner is at lower-left.

        Parameters
        ----------
        ampData : `numpy.ndarray`, (nx, ny)
            Image data to flip.
        amp : `lsst.afw.cameraGeom.Amplifier`
            Amplifier to get readout corner information.

        Returns
        -------
        ampData : `numpy.ndarray`, (nx, ny)
            Flipped image data.
        """
        X_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: True,
                  ReadoutCorner.UL: False,
                  ReadoutCorner.UR: True}
        Y_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: False,
                  ReadoutCorner.UL: True,
                  ReadoutCorner.UR: True}

        if X_FLIP[amp.getReadoutCorner()]:
            ampData = np.fliplr(ampData)
        if Y_FLIP[amp.getReadoutCorner()]:
            ampData = np.flipud(ampData)

        return ampData

    @staticmethod
    def local_offset_inverse(inputArr, drift_scale, decay_time, num_previous_pixels=15):
        r"""Remove CTI effects from local offsets.

        This implements equation 10 of Snyder+21.  For an image with
        CTI, s'(m, n), the correction factor is equal to the maximum
        value of the set of:

        .. code-block::

            {A_L s'(m, n - j) exp(-j t / \tau_L)}_j=0^jmax

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (nx, ny)
            Input image data to correct.
        drift_scale : `float`
            Drift scale (Snyder+21 A_L value) to use in correction.
        decay_time : `float`
            Decay time (Snyder+21 \tau_L) of the correction.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction.  As the
            CTI has an exponential decay, this essentially truncates
            the correction where that decay scales the input charge to
            near zero.

        Returns
        -------
        outputArr : `numpy.ndarray`, (nx, ny)
            Corrected image data.
        """
        r = np.exp(-1/decay_time)
        Ny, Nx = inputArr.shape
        # Only positive signal contributes to the offset; compute it once.
        positive = np.maximum(0, inputArr)

        # j = 0 term:
        offset = np.zeros((num_previous_pixels, Ny, Nx))
        offset[0, :, :] = drift_scale*positive

        # j = 1..jmax terms: the signal n columns upstream, decayed by r**n.
        for n in range(1, num_previous_pixels):
            offset[n, :, n:] = drift_scale*positive[:, :-n]*(r**n)

        Linv = np.amax(offset, axis=0)
        outputArr = inputArr - Linv

        return outputArr

    @staticmethod
    def local_trap_inverse(inputArr, trap, global_cti=0.0, num_previous_pixels=6):
        r"""Apply localized trapping inverse operator to pixel signals.

        This implements equation 13 of Snyder+21.  With c(s) the trap
        capture function and r = exp(-1 / \tau) for the trap emission
        time \tau, the trap occupancy after reading pixel n - 1 is
        estimated as the maximum of:

        .. code-block::

            {c(s'(m, n - j - 1)) r^j}_j=0^jmax-1

        The captured charge is C = max(0, c(s'(m, n)) - occupancy * r),
        the released charge is R = occupancy * (1 - r) (zero in the
        first column), and the corrected signal is
        s'(m, n) - (1 - b) (R - C).

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (nx, ny)
            Input image data to correct.
        trap : `lsst.ip.isr.SerialTrap`
            Serial trap describing the capture and release of charge.
        global_cti: `float`
            Mean charge transfer inefficiency, b from Snyder+21.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction.

        Returns
        -------
        outputArr : `numpy.ndarray`, (nx, ny)
            Corrected image data.
        """
        Ny, Nx = inputArr.shape
        a = 1 - global_cti
        r = np.exp(-1/trap.emission_time)

        # The capture function only depends on the input, so evaluate it
        # once instead of once per previous pixel.
        captured = trap.capture(np.maximum(0, inputArr))

        # Estimate trap occupancies during readout
        trap_occupancy = np.zeros((num_previous_pixels, Ny, Nx))
        for n in range(num_previous_pixels):
            trap_occupancy[n, :, n+1:] = captured[:, :-(n+1)]*(r**n)
        trap_occupancy = np.amax(trap_occupancy, axis=0)

        # Estimate captured charge
        C = captured - trap_occupancy*r
        C[C < 0] = 0.

        # Estimate released charge
        R = np.zeros(inputArr.shape)
        R[:, 1:] = trap_occupancy[:, 1:]*(1-r)
        T = R - C

        outputArr = inputArr - a*T

        return outputArr