Coverage for python / lsst / ip / isr / deferredCharge.py: 11%

477 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-25 08:29 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = (
    "DeferredChargeConfig",
    "DeferredChargeTask",
    "SerialTrap",
    "OverscanModel",
    "SimpleModel",
    "SimulatedModel",
    "SegmentSimulator",
    "FloatingOutputAmplifier",
    "DeferredChargeCalib",
)

32 

33import copy 

34import numpy as np 

35import warnings 

36from astropy.table import Table 

37 

38from lsst.afw.cameraGeom import ReadoutCorner 

39from lsst.pex.config import Config, Field 

40from lsst.pipe.base import Task 

41from .isrFunctions import gainContext 

42from .calibType import IsrCalib 

43 

44import scipy.interpolate as interp 

45 

46 

class SerialTrap:
    """Represents a serial register trap.

    Parameters
    ----------
    size : `float`
        Size of the charge trap, in electrons.
    emission_time : `float`
        Trap emission time constant, in inverse transfers.
    pixel : `int`
        Serial pixel location of the trap, including the prescan.
    trap_type : `str`
        Type of trap capture to use. Should be one of ``linear``,
        ``logistic``, or ``spline``.
    coeffs : `list` [`float`]
        Coefficients for the capture process. Linear traps need one
        coefficient, logistic traps need two, and spline based traps
        need to have an even number of coefficients that can be split
        into their spline locations and values.

    Raises
    ------
    ValueError
        Raised if the specified parameters are out of expected range.
    """

    def __init__(self, size, emission_time, pixel, trap_type, coeffs):
        if size < 0.0:
            raise ValueError('Trap size must be greater than or equal to 0.')
        self.size = size

        if emission_time <= 0.0:
            raise ValueError('Emission time must be greater than 0.')
        if np.isnan(emission_time):
            raise ValueError('Emission time must be real-valued, not NaN')
        self.emission_time = emission_time

        if int(pixel) != pixel:
            raise ValueError('Fraction value for pixel not allowed.')
        self.pixel = int(pixel)

        self.trap_type = trap_type
        self.coeffs = coeffs

        if self.trap_type not in ('linear', 'logistic', 'spline'):
            # Format the message explicitly: exceptions do not perform
            # logging-style ``%s`` substitution on extra arguments.
            raise ValueError(f'Unknown trap type: {self.trap_type}')

        if self.trap_type == 'spline':
            self.interp = self._build_spline_interp(self.coeffs)

        self._trap_array = None
        self._trapped_charge = None

    @staticmethod
    def _build_spline_interp(coeffs):
        """Build the capture interpolator for a ``spline`` trap.

        Note that ``spline`` is actually a piecewise linear interpolation
        in the model and the application, and not a true spline.

        Parameters
        ----------
        coeffs : `list` [`float`]
            Spline centers followed by spline values (equal halves).
            NaN entries are treated as padding and ignored.

        Returns
        -------
        interpolator : `scipy.interpolate.interp1d`
            Linear interpolator that is held constant outside the range
            of the centers.

        Raises
        ------
        ValueError
            Raised if the coefficients cannot be split into two equal
            halves, or if no valid (center, value) pair remains.
        """
        centers, values = np.split(np.array(coeffs, dtype=np.float64), 2)
        # Drop every (center, value) pair in which either entry is NaN.
        good = ~(np.isnan(centers) | np.isnan(values))
        centers = centers[good]
        values = values[good]
        if centers.size == 0:
            raise ValueError('Spline trap requires at least one non-NaN (center, value) pair.')
        # Sort by center so that the out-of-range fill values are the
        # values at the lowest and highest centers, whatever the input order.
        order = np.argsort(centers, kind='stable')
        centers = centers[order]
        values = values[order]
        return interp.interp1d(
            centers,
            values,
            bounds_error=False,
            fill_value=(values[0], values[-1]),
            assume_sorted=True,
        )

    def __eq__(self, other):
        # A trap is equal to another trap if all of the initialization
        # parameters are equal. All other properties are only filled
        # during use, and are not persisted into the calibration.
        if not isinstance(other, SerialTrap):
            return NotImplemented
        if self.size != other.size:
            return False
        if self.emission_time != other.emission_time:
            return False
        if self.pixel != other.pixel:
            return False
        if self.trap_type != other.trap_type:
            return False
        # ``coeffs`` may be a list or an ndarray; a plain ``!=`` on arrays
        # yields an array whose truth value is ambiguous.
        return bool(np.array_equal(self.coeffs, other.coeffs))

    @property
    def trap_array(self):
        """Per-pixel trap capacity array (`np.ndarray` or `None`)."""
        return self._trap_array

    @property
    def trapped_charge(self):
        """Per-pixel currently trapped charge (`np.ndarray` or `None`)."""
        return self._trapped_charge

    def initialize(self, ny, nx, prescan_width):
        """Initialize trapping arrays for simulated readout.

        Parameters
        ----------
        ny : `int`
            Number of rows to simulate.
        nx : `int`
            Number of columns to simulate.
        prescan_width : `int`
            Additional transfers due to prescan.

        Raises
        ------
        ValueError
            Raised if the trap falls outside of the image.
        """
        width = nx + prescan_width
        # Valid column indices are 0 .. width-1, so ``pixel == width``
        # is already outside the array.
        if self.pixel >= width:
            raise ValueError('Trap location {0} must be less than {1}'.format(self.pixel, width))
        if self.pixel < 0:
            raise ValueError(f'Trap location {self.pixel} must be non-negative.')

        self._trap_array = np.zeros((ny, width))
        self._trap_array[:, self.pixel] = self.size
        self._trapped_charge = np.zeros((ny, width))

    def release_charge(self):
        """Release charge through exponential decay.

        Returns
        -------
        released_charge : `np.ndarray`
            Charge released from each pixel during this transfer.
        """
        released_charge = self._trapped_charge*(1-np.exp(-1./self.emission_time))
        self._trapped_charge -= released_charge

        return released_charge

    def trap_charge(self, free_charge):
        """Perform charge capture using this trap's capture function.

        ``initialize`` must have been called first. The capture is
        limited so the trapped charge never decreases here and never
        exceeds the trap capacity at each pixel.

        Parameters
        ----------
        free_charge : `float` or `np.ndarray`
            Charge available to be trapped.

        Returns
        -------
        captured_charge : `np.ndarray`
            Amount of charge actually trapped.
        """
        captured_charge = (np.clip(self.capture(free_charge), self.trapped_charge, self._trap_array)
                           - self.trapped_charge)
        self._trapped_charge += captured_charge

        return captured_charge

    def capture(self, pixel_signals):
        """Trap capture function.

        Parameters
        ----------
        pixel_signals : `float`, `list` [`float`], or `np.ndarray`
            Input pixel values.

        Returns
        -------
        captured_charge : `np.ndarray`
            Amount of charge captured from each pixel.

        Raises
        ------
        RuntimeError
            Raised if the trap type is invalid.
        """
        # Convert so list inputs work with the arithmetic below.
        pixel_signals = np.asarray(pixel_signals)
        if self.trap_type == 'linear':
            scaling = self.coeffs[0]
            return np.minimum(self.size, pixel_signals*scaling)
        elif self.trap_type == 'logistic':
            f0, k = (self.coeffs[0], self.coeffs[1])
            return self.size/(1.+np.exp(-k*(pixel_signals-f0)))
        elif self.trap_type == 'spline':
            return self.interp(pixel_signals)
        else:
            raise RuntimeError(f"Invalid trap capture type: {self.trap_type}.")

222 

223 

class OverscanModel:
    """Base class for handling model/data fit comparisons.

    This handles all of the methods needed for the lmfit Minimizer to
    run. Subclasses provide ``model_results``; the objective functions
    below all compare its output against the measured overscan data.
    """

    @staticmethod
    def model_results(params, signal, num_transfers, start=1, stop=10):
        """Generate a realization of the overscan model, using the specified
        fit parameters and input signal.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters.
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        num_transfers : `int`
            Number of serial transfers that the charge undergoes.
        start : `int`, optional
            First overscan column to fit. This number includes the
            last imaging column, and needs to be adjusted by one when
            using the overscan bounding box.
        stop : `int`, optional
            Last overscan column to fit. This number includes the
            last imaging column, and needs to be adjusted by one when
            using the overscan bounding box.

        Returns
        -------
        results : `np.ndarray`, (nMeasurements, nCols)
            Model results.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement the model calculation.")

    def loglikelihood(self, params, signal, data, error, *args, **kwargs):
        """Calculate log likelihood of the model.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters.
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        data : `np.ndarray`, (nMeasurements, nCols)
            Array of overscan column means from each measurement.
        error : `float`
            Fixed error value.
        *args :
            Additional positional arguments passed to ``model_results``.
        **kwargs :
            Additional keyword arguments passed to ``model_results``.

        Returns
        -------
        logL : `float`
            The Gaussian log-likelihood (up to an additive constant) of
            the observed data given the model parameters.
        """
        model_results = self.model_results(params, signal, *args, **kwargs)

        inv_sigma2 = 1.0/(error**2.0)
        diff = model_results - data

        return -0.5*(np.sum(inv_sigma2*(diff)**2.))

    def negative_loglikelihood(self, params, signal, data, error, *args, **kwargs):
        """Calculate negative log likelihood of the model.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters.
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        data : `np.ndarray`, (nMeasurements, nCols)
            Array of overscan column means from each measurement.
        error : `float`
            Fixed error value.
        *args :
            Additional positional arguments passed to ``model_results``.
        **kwargs :
            Additional keyword arguments passed to ``model_results``.

        Returns
        -------
        negativelogL : `float`
            The negative log-likelihood of the observed data given the
            model parameters.
        """
        ll = self.loglikelihood(params, signal, data, error, *args, **kwargs)

        return -ll

    def rms_error(self, params, signal, data, error, *args, **kwargs):
        """Calculate RMS error between model and data.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters.
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        data : `np.ndarray`, (nMeasurements, nCols)
            Array of overscan column means from each measurement.
        error : `float`
            Fixed error value. Not used by this metric; accepted so all
            objective functions share one signature.
        *args :
            Additional positional arguments passed to ``model_results``.
        **kwargs :
            Additional keyword arguments passed to ``model_results``.

        Returns
        -------
        rms : `float`
            The rms error between the model and input data.
        """
        model_results = self.model_results(params, signal, *args, **kwargs)

        diff = model_results - data
        rms = np.sqrt(np.mean(np.square(diff)))

        return rms

    def difference(self, params, signal, data, error, *args, **kwargs):
        """Calculate the flattened difference array between model and data.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters.
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        data : `np.ndarray`, (nMeasurements, nCols)
            Array of overscan column means from each measurement.
        error : `float`
            Fixed error value. Not used by this function; accepted so all
            objective functions share one signature.
        *args :
            Additional positional arguments passed to ``model_results``.
        **kwargs :
            Additional keyword arguments passed to ``model_results``.

        Returns
        -------
        difference : `np.ndarray`, (nMeasurements*nCols)
            The model minus the data, flattened to one dimension.
        """
        model_results = self.model_results(params, signal, *args, **kwargs)
        diff = (model_results-data).flatten()

        return diff

375 

376 

class SimpleModel(OverscanModel):
    """Simple analytic overscan model."""

    @staticmethod
    def model_results(params, signal, num_transfers, start=1, stop=10):
        """Evaluate the analytic overscan model for each input signal level.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters (``ctiexp``,
            ``trapsize``, ``scaling``, ``emissiontime``, ``driftscale``,
            ``decaytime``).
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        num_transfers : `int`
            Number of serial transfers that the charge undergoes.
        start : `int`, optional
            First overscan column to fit. This number includes the
            last imaging column, and needs to be adjusted by one when
            using the overscan bounding box.
        stop : `int`, optional
            Last overscan column to fit. This number includes the
            last imaging column, and needs to be adjusted by one when
            using the overscan bounding box.

        Returns
        -------
        res : `np.ndarray`, (nMeasurements, nCols)
            Model results.
        """
        pars = params.valuesdict()
        pars['cti'] = 10**pars['ctiexp']

        # The column range is shifted by one so that it matches the DM
        # overscan bounding box numbering.
        columns = np.arange(start + 1, stop + 2)
        results = np.zeros((signal.shape[0], columns.shape[0]))

        tau = pars['emissiontime']
        # These decay profiles do not depend on the signal level.
        trapDecay = np.exp(-columns/tau)
        driftDecay = np.exp(-columns/float(pars['decaytime']))

        for row, level in enumerate(signal):
            # This is largely equivalent to equation 2. The minimum
            # indicates that a trap cannot emit more charge than is
            # available, nor more than it can hold; it scales the
            # exponential release of charge from the trap.
            trapped = np.minimum(pars['trapsize'], level*pars['scaling'])
            trapTerm = trapped * (np.exp(1/tau) - 1.0) * trapDecay
            # Contribution from the global CTI at each pixel transfer.
            ctiTerm = level*num_transfers*pars['cti']**columns
            # Contribution from local (electronic drift) effects.
            driftTerm = pars['driftscale']*level*driftDecay
            results[row, :] = trapTerm + ctiTerm + driftTerm

        return results

432 

433 

class SimulatedModel(OverscanModel):
    """Simulated overscan model."""

    @staticmethod
    def model_results(params, signal, num_transfers, amp, start=1, stop=10, trap_type=None):
        """Evaluate the overscan model by simulating a serial readout.

        Parameters
        ----------
        params : `lmfit.Parameters`
            Object containing the model parameters.
        signal : `np.ndarray`, (nMeasurements)
            Array of image means.
        num_transfers : `int`
            Number of serial transfers that the charge undergoes.
        amp : `lsst.afw.cameraGeom.Amplifier`
            Amplifier to use for geometry information.
        start : `int`, optional
            First overscan column to fit. This number includes the
            last imaging column, and needs to be adjusted by one when
            using the overscan bounding box.
        stop : `int`, optional
            Last overscan column to fit. This number includes the
            last imaging column, and needs to be adjusted by one when
            using the overscan bounding box.
        trap_type : `str`, optional
            Type of trap model to use: ``linear``, ``logistic``, or
            `None` for no trap.

        Returns
        -------
        results : `np.ndarray`, (nMeasurements, nCols)
            Model results.

        Raises
        ------
        ValueError
            Raised if ``trap_type`` is not supported.
        """
        pars = params.valuesdict()

        # Shift the column range to match the DM overscan bbox numbering.
        first = start + 1
        last = stop + 1

        # Electronics (bias drift) part of the model.
        amplifier = FloatingOutputAmplifier(1.0, pars['driftscale'], pars['decaytime'])

        # The global CTI is fit as a base-10 exponent.
        pars['cti'] = 10**pars['ctiexp']

        # Optional serial trap, placed at serial pixel 1.
        if trap_type is not None and trap_type not in ('linear', 'logistic'):
            raise ValueError('Trap type must be linear or logistic or None')
        trap = None
        if trap_type == 'linear':
            trap = SerialTrap(pars['trapsize'], pars['emissiontime'], 1, 'linear',
                              [pars['scaling']])
        elif trap_type == 'logistic':
            trap = SerialTrap(pars['trapsize'], pars['emissiontime'], 1, 'logistic',
                              [pars['f0'], pars['k']])

        # Simulate the readout of a ramp with one signal level per row.
        dataWidth = amp.getRawDataBBox().getWidth()
        prescanWidth = amp.getRawSerialPrescanBBox().getWidth()
        simulator = SegmentSimulator(np.zeros((signal.shape[0], dataWidth)), prescanWidth,
                                     amplifier, cti=pars['cti'], traps=trap)
        simulator.ramp_exp(signal)
        simulated = simulator.readout(
            serial_overscan_width=amp.getRawSerialOverscanBBox().getWidth(),
            parallel_overscan_width=0,
        )

        # Number of columns preceding the serial overscan.
        ncols = prescanWidth + dataWidth
        return simulated[:, ncols + first - 1:ncols + last]

503 

504 

class SegmentSimulator:
    """Controls the creation of simulated segment images.

    Parameters
    ----------
    imarr : `np.ndarray` (ny, nx)
        Image data array; unpacked as ``ny, nx = imarr.shape``.
    prescan_width : `int`
        Number of serial prescan columns.
    output_amplifier : `FloatingOutputAmplifier`
        Readout amplifier model. It must provide ``global_offset``,
        ``noise``, and a ``local_offset(old, signal)`` method.
    cti : `float`
        Global CTI value. Must be a single value, not an array.
    traps : `SerialTrap` or `list` [`SerialTrap`], optional
        Serial traps to simulate. A single trap is wrapped in a list.

    Raises
    ------
    ValueError
        Raised if ``cti`` is an array.
    """

    def __init__(self, imarr, prescan_width, output_amplifier, cti=0.0, traps=None):
        # Image array geometry; the prescan columns are prepended.
        self.prescan_width = prescan_width
        self.ny, self.nx = imarr.shape

        self.segarr = np.zeros((self.ny, self.nx+prescan_width))
        self.segarr[:, prescan_width:] = imarr

        # Serial readout information
        self.output_amplifier = output_amplifier
        if isinstance(cti, np.ndarray):
            raise ValueError("cti must be single value, not an array.")
        self.cti = cti

        self.serial_traps = None
        self.do_trapping = False
        if traps is not None:
            if not isinstance(traps, list):
                traps = [traps]
            for trap in traps:
                self.add_trap(trap)

    def add_trap(self, serial_trap):
        """Add a trap to the serial register and enable trapping.

        Parameters
        ----------
        serial_trap : `SerialTrap`
            The trap to add.
        """
        if self.serial_traps is None:
            self.serial_traps = []
        self.serial_traps.append(serial_trap)
        self.do_trapping = True

    def ramp_exp(self, signal_list):
        """Add a row-by-row flux ramp to the segment image.

        Every imaging pixel in row ``i`` has ``signal_list[i]`` added to
        it. The signal therefore varies from row to row and is constant
        along each row (the serial direction). The prescan columns are
        left untouched.

        Parameters
        ----------
        signal_list : `list` [`float`]
            List of signal levels, one per row.

        Raises
        ------
        ValueError
            Raised if the length of the signal list does not equal the
            number of rows.
        """
        if len(signal_list) != self.ny:
            raise ValueError("Signal list does not match row count.")

        ramp = np.tile(signal_list, (self.nx, 1)).T
        self.segarr[:, self.prescan_width:] += ramp

    def readout(self, serial_overscan_width=10, parallel_overscan_width=0):
        """Simulate serial readout of the segment image.

        At every serial transfer, each trap first captures charge. Then a
        fraction ``cti`` of the free charge stays behind and the rest
        shifts one column towards the output. The charge that reaches the
        output, plus the amplifier's local offset, is recorded in the
        image. Finally the traps release part of their charge. The
        segment image itself is not modified.

        Parameters
        ----------
        serial_overscan_width : `int`, optional
            Number of serial overscan columns.
        parallel_overscan_width : `int`, optional
            Number of parallel overscan rows.

        Returns
        -------
        result : `np.ndarray` (ny + parallel_overscan_width, prescan_width + nx + serial_overscan_width)
            Simulated image, including serial prescan, serial overscan,
            and parallel overscan regions. The result is in electrons;
            no gain is applied.
        """
        # Create output array
        iy = int(self.ny + parallel_overscan_width)
        ix = int(self.nx + self.prescan_width + serial_overscan_width)

        image = np.random.default_rng().normal(
            loc=self.output_amplifier.global_offset,
            scale=self.output_amplifier.noise,
            size=(iy, ix),
        )

        # Work on a copy so the segment image can be read out again.
        free_charge = self.segarr.copy()

        # Set flow control parameters
        do_trapping = self.do_trapping
        cti = self.cti

        offset = np.zeros(self.ny)
        cte = 1 - cti
        if do_trapping:
            for trap in self.serial_traps:
                trap.initialize(self.ny, self.nx, self.prescan_width)

        for i in range(ix):
            # Trap capture
            if do_trapping:
                for trap in self.serial_traps:
                    captured_charge = trap.trap_charge(free_charge)
                    free_charge -= captured_charge

            # Pixel-to-pixel proportional loss
            transferred_charge = free_charge*cte
            deferred_charge = free_charge*cti

            # Pixel transfer and readout; only the real (non parallel
            # overscan) rows receive charge.
            offset = self.output_amplifier.local_offset(offset,
                                                        transferred_charge[:, 0])
            image[:self.ny, i] += transferred_charge[:, 0] + offset

            # Shift everything one column towards the output, filling
            # the far end with zeros, and add back the deferred charge.
            free_charge = np.pad(transferred_charge, ((0, 0), (0, 1)),
                                 mode='constant')[:, 1:] + deferred_charge

            # Trap emission
            if do_trapping:
                for trap in self.serial_traps:
                    released_charge = trap.release_charge()
                    free_charge += released_charge

        return image

652 

653 

class FloatingOutputAmplifier:
    """Object representing the readout amplifier of a single channel.

    Parameters
    ----------
    gain : `float`
        Gain of the amplifier. Stored but currently not used.
    scale : `float`
        Drift scale for the amplifier.
    decay_time : `float`
        Decay time for the bias drift.
    noise : `float`, optional
        Amplifier read noise.
    offset : `float`, optional
        Global (bias) offset of the amplifier output; stored as
        ``global_offset``.

    Raises
    ------
    ValueError
        Raised if ``scale`` or ``decay_time`` is out of range.
    """

    def __init__(self, gain, scale, decay_time, noise=0.0, offset=0.0):

        self.gain = gain
        self.noise = noise
        self.global_offset = offset

        self.update_parameters(scale, decay_time)

    def local_offset(self, old, signal):
        """Calculate local offset hysteresis.

        The new offset is the element-wise maximum of ``scale * signal``
        and the previous offset decayed by ``exp(-1 / decay_time)``.

        Parameters
        ----------
        old : `np.ndarray`
            Offset from the previous iteration.
        signal : `np.ndarray`
            Current column measurements.

        Returns
        -------
        offset : `np.ndarray`
            Local offset.
        """
        new = self.scale*signal

        return np.maximum(new, old*np.exp(-1/self.decay_time))

    def update_parameters(self, scale, decay_time):
        """Update parameter values, if within acceptable values.

        Both values are validated before either is stored, so a failed
        update leaves the amplifier unchanged.

        Parameters
        ----------
        scale : `float`
            Drift scale for the amplifier.
        decay_time : `float`
            Decay time for the bias drift.

        Raises
        ------
        ValueError
            Raised if the input parameters are out of range.
        """
        if scale < 0.0:
            raise ValueError("Scale must be greater than or equal to 0.")
        if np.isnan(scale):
            raise ValueError("Scale must be real-valued number, not NaN.")
        if decay_time <= 0.0:
            raise ValueError("Decay time must be greater than 0.")
        if np.isnan(decay_time):
            raise ValueError("Decay time must be real-valued number, not NaN.")

        self.scale = scale
        self.decay_time = decay_time

722 

723 

724class DeferredChargeCalib(IsrCalib): 

725 r"""Calibration containing deferred charge/CTI parameters. 

726 

727 This includes, parameters from Snyder+2021 and exstimates of 

728 the serial and parallel CTI using the extended pixel edge 

729 response (EPER) method (also defined in Snyder+2021). 

730 

731 Parameters 

732 ---------- 

733 **kwargs : 

734 Additional parameters to pass to parent constructor. 

735 

736 Notes 

737 ----- 

738 The charge transfer inefficiency attributes stored are: 

739 

740 driftScale : `dict` [`str`, `float`] 

741 A dictionary, keyed by amplifier name, of the local electronic 

742 offset drift scale parameter, A_L in Snyder+2021. 

743 decayTime : `dict` [`str`, `float`] 

744 A dictionary, keyed by amplifier name, of the local electronic 

745 offset decay time, \tau_L in Snyder+2021. 

746 globalCti : `dict` [`str`, `float`] 

747 A dictionary, keyed by amplifier name, of the mean global CTI 

748 paramter, b in Snyder+2021. 

749 serialTraps : `dict` [`str`, `lsst.ip.isr.SerialTrap`] 

750 A dictionary, keyed by amplifier name, containing a single 

751 serial trap for each amplifier. 

752 signals : `dict` [`str`, `np.ndarray`] 

753 A dictionary, keyed by amplifier name, of the mean signal 

754 level for each input measurement. 

755 inputGain : `dict` [`str`, `float`] 

756 A dictionary, keyed by amplifier name of the input gain used 

757 to calculate the overscan statistics and produce this calib. 

758 serialEper : `dict` [`str`, `np.ndarray`, `float`] 

759 A dictionary, keyed by amplifier name, of the serial EPER 

760 estimator of serial CTI, given in a list for each input 

761 measurement. 

762 parallelEper : `dict` [`str`, `np.ndarray`, `float`] 

763 A dictionary, keyed by amplifier name, of the parallel 

764 EPER estimator of parallel CTI, given in a list for each 

765 input measurement. 

766 serialCtiTurnoff : `dict` [`str`, `float`] 

767 A dictionary, keyed by amplifier name, of the serial CTI 

768 turnoff (unit: electrons). 

769 parallelCtiTurnoff : `dict` [`str`, `float`] 

770 A dictionary, keyed by amplifier name, of the parallel CTI 

771 turnoff (unit: electrons). 

772 serialCtiTurnoffSamplingErr : `dict` [`str`, `float`] 

773 A dictionary, keyed by amplifier name, of the serial CTI 

774 turnoff sampling error (unit: electrons). 

775 parallelCtiTurnoffSamplingErr : `dict` [`str`, `float`] 

776 A dictionary, keyed by amplifier name, of the parallel CTI 

777 turnoff sampling error (unit: electrons). 

778 

779 Also, the values contained in this calibration are all derived 

780 from and image and overscan in units of electron as these are 

781 the most natural units in which to compute deferred charge. 

782 However, this means the the user should supply a reliable set 

783 of gains when computing the CTI statistics during ISR. 

784 

785 Version 1.1 deprecates the USEGAINS attribute and standardizes 

786 everything to electron units. 

787 Version 1.2 adds the ``signal``, ``serialEper``, ``parallelEper``, 

788 ``serialCtiTurnoff``, ``parallelCtiTurnoff``, 

789 ``serialCtiTurnoffSamplingErr``, ``parallelCtiTurnoffSamplingErr`` 

790 attributes. 

791 Version 1.3 adds the `inputGain` attribute. 

792 """ 

793 _OBSTYPE = 'CTI' 

794 _SCHEMA = 'Deferred Charge' 

795 _VERSION = 1.3 

796 

797 def __init__(self, **kwargs): 

798 self.driftScale = {} 

799 self.decayTime = {} 

800 self.globalCti = {} 

801 self.serialTraps = {} 

802 self.signals = {} 

803 self.inputGain = {} 

804 self.serialEper = {} 

805 self.parallelEper = {} 

806 self.serialCtiTurnoff = {} 

807 self.parallelCtiTurnoff = {} 

808 self.serialCtiTurnoffSamplingErr = {} 

809 self.parallelCtiTurnoffSamplingErr = {} 

810 

811 # Check for deprecated kwargs 

812 if kwargs.pop("useGains", None) is not None: 

813 warnings.warn("useGains is deprecated, and will be removed " 

814 "after v28.", FutureWarning) 

815 

816 super().__init__(**kwargs) 

817 

818 # Units are always in electron. 

819 self.updateMetadata(UNITS='electron') 

820 

821 self.requiredAttributes.update(['driftScale', 'decayTime', 'globalCti', 'serialTraps', 

822 'inputGain', 'signals', 'serialEper', 'parallelEper', 

823 'serialCtiTurnoff', 'parallelCtiTurnoff', 

824 'serialCtiTurnoffSamplingErr', 

825 'parallelCtiTurnoffSamplingErr']) 

826 

827 def fromDetector(self, detector): 

828 """Read metadata parameters from a detector. 

829 

830 Parameters 

831 ---------- 

832 detector : `lsst.afw.cameraGeom.detector` 

833 Input detector with parameters to use. 

834 

835 Returns 

836 ------- 

837 calib : `lsst.ip.isr.Linearizer` 

838 The calibration constructed from the detector. 

839 """ 

840 

841 pass 

842 

843 @classmethod 

844 def fromDict(cls, dictionary): 

845 """Construct a calibration from a dictionary of properties. 

846 

847 Parameters 

848 ---------- 

849 dictionary : `dict` 

850 Dictionary of properties. 

851 

852 Returns 

853 ------- 

854 calib : `lsst.ip.isr.CalibType` 

855 Constructed calibration. 

856 

857 Raises 

858 ------ 

859 RuntimeError 

860 Raised if the supplied dictionary is for a different 

861 calibration. 

862 """ 

863 calib = cls() 

864 

865 if calib._OBSTYPE != dictionary['metadata']['OBSTYPE']: 

866 raise RuntimeError(f"Incorrect CTI supplied. Expected {calib._OBSTYPE}, " 

867 f"found {dictionary['metadata']['OBSTYPE']}") 

868 

869 calib.setMetadata(dictionary['metadata']) 

870 

871 calib.inputGain = dictionary['inputGain'] 

872 calib.driftScale = dictionary['driftScale'] 

873 calib.decayTime = dictionary['decayTime'] 

874 calib.globalCti = dictionary['globalCti'] 

875 calib.serialCtiTurnoff = dictionary['serialCtiTurnoff'] 

876 calib.parallelCtiTurnoff = dictionary['parallelCtiTurnoff'] 

877 calib.serialCtiTurnoffSamplingErr = dictionary['serialCtiTurnoffSamplingErr'] 

878 calib.parallelCtiTurnoffSamplingErr = dictionary['parallelCtiTurnoffSamplingErr'] 

879 

880 allAmpNames = dictionary['driftScale'].keys() 

881 

882 # Some amps might not have a serial trap solution, so 

883 # dictionary['serialTraps'].keys() might not be equal 

884 # to dictionary['driftScale'].keys() 

885 for ampName in dictionary['serialTraps']: 

886 ampTraps = dictionary['serialTraps'][ampName] 

887 calib.serialTraps[ampName] = SerialTrap(ampTraps['size'], ampTraps['emissionTime'], 

888 ampTraps['pixel'], ampTraps['trap_type'], 

889 ampTraps['coeffs']) 

890 

891 for ampName in allAmpNames: 

892 calib.signals[ampName] = np.array(dictionary['signals'][ampName], dtype=np.float64) 

893 calib.serialEper[ampName] = np.array(dictionary['serialEper'][ampName], dtype=np.float64) 

894 calib.parallelEper[ampName] = np.array(dictionary['parallelEper'][ampName], dtype=np.float64) 

895 

896 calib.updateMetadata() 

897 return calib 

898 

899 def toDict(self): 

900 """Return a dictionary containing the calibration properties. 

901 The dictionary should be able to be round-tripped through 

902 ``fromDict``. 

903 

904 Returns 

905 ------- 

906 dictionary : `dict` 

907 Dictionary of properties. 

908 """ 

909 self.updateMetadata() 

910 outDict = {} 

911 outDict['metadata'] = self.getMetadata() 

912 

913 outDict['driftScale'] = self.driftScale 

914 outDict['decayTime'] = self.decayTime 

915 outDict['globalCti'] = self.globalCti 

916 outDict['signals'] = self.signals 

917 outDict['inputGain'] = self.inputGain 

918 outDict['serialEper'] = self.serialEper 

919 outDict['parallelEper'] = self.parallelEper 

920 outDict['serialCtiTurnoff'] = self.serialCtiTurnoff 

921 outDict['parallelCtiTurnoff'] = self.parallelCtiTurnoff 

922 outDict['serialCtiTurnoffSamplingErr'] = self.serialCtiTurnoffSamplingErr 

923 outDict['parallelCtiTurnoffSamplingErr'] = self.parallelCtiTurnoffSamplingErr 

924 

925 outDict['serialTraps'] = {} 

926 for ampName in self.serialTraps: 

927 ampTrap = {'size': self.serialTraps[ampName].size, 

928 'emissionTime': self.serialTraps[ampName].emission_time, 

929 'pixel': self.serialTraps[ampName].pixel, 

930 'trap_type': self.serialTraps[ampName].trap_type, 

931 'coeffs': self.serialTraps[ampName].coeffs} 

932 outDict['serialTraps'][ampName] = ampTrap 

933 

934 return outDict 

935 

@classmethod
def fromTable(cls, tableList):
    """Construct calibration from a list of tables.

    This method uses the ``fromDict`` method to create the
    calibration, after constructing an appropriate dictionary from
    the input tables.

    Parameters
    ----------
    tableList : `list` [`astropy.table.Table`]
        List of tables to use to construct the CTI
        calibration. Two tables are expected in this list, the
        first containing the per-amplifier CTI parameters, and the
        second containing the parameters for serial traps.

    Returns
    -------
    calib : `lsst.ip.isr.DeferredChargeCalib`
        The calibration defined in the tables.

    Raises
    ------
    RuntimeError
        Raised if the ``CTI_VERSION`` metadata of the first table is
        older than 1.1, which is no longer supported.
    ValueError
        Raised if the trap type or trap coefficients are not
        defined properly.
    """
    ampTable = tableList[0]

    inDict = {}
    inDict['metadata'] = ampTable.meta
    calibVersion = inDict['metadata']['CTI_VERSION']

    amps = ampTable['AMPLIFIER']

    def byAmp(column):
        # Map each amplifier name to the matching row of ``column``.
        return {amp: value for amp, value in zip(amps, column)}

    inDict['driftScale'] = byAmp(ampTable['DRIFT_SCALE'])
    inDict['decayTime'] = byAmp(ampTable['DECAY_TIME'])
    inDict['globalCti'] = byAmp(ampTable['GLOBAL_CTI'])

    # Version check
    if calibVersion < 1.1:
        # This version might be in the wrong units (not electron),
        # and does not contain the gain information to convert
        # into a new calibration version.
        raise RuntimeError(f"Using old version of CTI calibration (ver. {calibVersion} < 1.1), "
                           "which is no longer supported.")
    elif calibVersion < 1.2:
        # Version 1.1 tables predate the EPER / turnoff columns; fill
        # them with NaN placeholders.
        inDict['signals'] = {amp: np.array([np.nan]) for amp in amps}
        inDict['serialEper'] = {amp: np.array([np.nan]) for amp in amps}
        inDict['parallelEper'] = {amp: np.array([np.nan]) for amp in amps}
        inDict['serialCtiTurnoff'] = {amp: np.nan for amp in amps}
        inDict['parallelCtiTurnoff'] = {amp: np.nan for amp in amps}
        inDict['serialCtiTurnoffSamplingErr'] = {amp: np.nan for amp in amps}
        inDict['parallelCtiTurnoffSamplingErr'] = {amp: np.nan for amp in amps}
    else:
        inDict['signals'] = byAmp(ampTable['SIGNALS'])
        inDict['serialEper'] = byAmp(ampTable['SERIAL_EPER'])
        inDict['parallelEper'] = byAmp(ampTable['PARALLEL_EPER'])
        inDict['serialCtiTurnoff'] = byAmp(ampTable['SERIAL_CTI_TURNOFF'])
        inDict['parallelCtiTurnoff'] = byAmp(ampTable['PARALLEL_CTI_TURNOFF'])
        inDict['serialCtiTurnoffSamplingErr'] = byAmp(ampTable['SERIAL_CTI_TURNOFF_SAMPLING_ERR'])
        inDict['parallelCtiTurnoffSamplingErr'] = byAmp(ampTable['PARALLEL_CTI_TURNOFF_SAMPLING_ERR'])

    # The input gain column was introduced in version 1.3.
    if calibVersion < 1.3:
        inDict['inputGain'] = {amp: np.nan for amp in amps}
    else:
        inDict['inputGain'] = byAmp(ampTable['INPUT_GAIN'])

    inDict['serialTraps'] = {}
    trapTable = tableList[1]

    trapAmps = trapTable['AMPLIFIER']
    sizes = trapTable['SIZE']
    emissionTimes = trapTable['EMISSION_TIME']
    pixels = trapTable['PIXEL']
    trapTypes = trapTable['TYPE']
    coeffs = trapTable['COEFFS']

    for index, amp in enumerate(trapAmps):
        ampTrap = {}
        ampTrap['size'] = sizes[index]
        ampTrap['emissionTime'] = emissionTimes[index]
        ampTrap['pixel'] = pixels[index]
        ampTrap['trap_type'] = trapTypes[index]

        # Remove the trailing NaN padding that ``toTable`` adds so that
        # coefficient lists of different lengths share one column.
        # The first coefficient is always kept.
        inCoeffs = coeffs[index]
        isNan = np.isnan(inCoeffs)
        nKeep = len(inCoeffs)
        while nKeep > 1 and isNan[nKeep - 1]:
            nKeep -= 1
        ampTrap['coeffs'] = inCoeffs[:nKeep].tolist()

        # Validate the coefficient count for each supported trap type.
        nCoeffs = len(ampTrap['coeffs'])
        trapType = ampTrap['trap_type']
        if trapType == 'linear':
            validLength = nCoeffs >= 1
        elif trapType == 'logistic':
            validLength = nCoeffs >= 2
        elif trapType == 'spline':
            # Spline coefficients come in pairs.
            validLength = nCoeffs % 2 == 0
        else:
            raise ValueError(f"Unknown trap type: {trapType}")
        if not validLength:
            raise ValueError(f"CTI Amplifier {amp} coefficients for trap has illegal length {nCoeffs}.")

        inDict['serialTraps'][amp] = ampTrap

    return cls.fromDict(inDict)

1072 

def toTable(self):
    """Construct a list of tables containing the information in this
    calibration.

    The list of tables should create an identical calibration
    after being passed to this class's ``fromTable`` method.

    Returns
    -------
    tableList : `list` [`astropy.table.Table`]
        List of tables containing the CTI calibration
        information. Two tables are generated for this list, the
        first containing the per-amplifier CTI parameters, and the
        second containing the parameters for serial traps. Trap
        coefficient lists shorter than the longest one are padded
        with trailing NaN values so they fit a single fixed-width
        ``COEFFS`` column.
    """
    tableList = []
    self.updateMetadata()

    # The amplifier list is taken from ``driftScale``; every other
    # per-amplifier attribute is looked up with those names.
    ampList = list(self.driftScale.keys())
    perAmpColumns = (
        ('DRIFT_SCALE', self.driftScale),
        ('DECAY_TIME', self.decayTime),
        ('GLOBAL_CTI', self.globalCti),
        ('SIGNALS', self.signals),
        ('INPUT_GAIN', self.inputGain),
        ('SERIAL_EPER', self.serialEper),
        ('PARALLEL_EPER', self.parallelEper),
        ('SERIAL_CTI_TURNOFF', self.serialCtiTurnoff),
        ('PARALLEL_CTI_TURNOFF', self.parallelCtiTurnoff),
        ('SERIAL_CTI_TURNOFF_SAMPLING_ERR', self.serialCtiTurnoffSamplingErr),
        ('PARALLEL_CTI_TURNOFF_SAMPLING_ERR', self.parallelCtiTurnoffSamplingErr),
    )
    ampColumns = {'AMPLIFIER': ampList}
    for columnName, values in perAmpColumns:
        ampColumns[columnName] = [values[amp] for amp in ampList]

    ampTable = Table(ampColumns)
    ampTable.meta = self.getMetadata().toDict()
    tableList.append(ampTable)

    ampList = []
    sizeList = []
    timeList = []
    pixelList = []
    typeList = []
    coeffList = []

    # Width of the COEFFS column: the longest coefficient list.
    maxCoeffLength = max((len(trap.coeffs) for trap in self.serialTraps.values()), default=0)

    # Pack and pad the end of the coefficients with NaN values.
    for amp, trap in self.serialTraps.items():
        ampList.append(amp)
        sizeList.append(trap.size)
        timeList.append(trap.emission_time)
        pixelList.append(trap.pixel)
        typeList.append(trap.trap_type)

        coeffs = trap.coeffs
        if len(coeffs) != maxCoeffLength:
            # Cast to float first: padding an integer array with NaN
            # would silently convert the NaN to a garbage integer.
            coeffs = np.pad(np.asarray(coeffs, dtype=np.float64),
                            (0, maxCoeffLength - len(coeffs)),
                            constant_values=np.nan).tolist()
        coeffList.append(coeffs)

    trapTable = Table({'AMPLIFIER': ampList,
                       'SIZE': sizeList,
                       'EMISSION_TIME': timeList,
                       'PIXEL': pixelList,
                       'TYPE': typeList,
                       'COEFFS': coeffList})

    tableList.append(trapTable)

    return tableList

1176 

1177 

class DeferredChargeConfig(Config):
    """Configuration for `DeferredChargeTask`.

    Sets the length of the pixel history used by each stage of the
    deferred-charge correction, and whether unused pixels are blanked
    before correcting.
    """

    # History length passed to the local-offset (drift) stage.
    nPixelOffsetCorrection = Field(
        doc="Number of prior pixels to use for local offset correction.",
        dtype=int,
        default=15,
    )

    # History length passed to the serial-trap stage.
    nPixelTrapCorrection = Field(
        doc="Number of prior pixels to use for trap correction.",
        dtype=int,
        default=6,
    )

    # When enabled, the raw serial prescan and parallel overscan of
    # each amplifier are overwritten with zeros before correcting.
    zeroUnusedPixels = Field(
        doc="If true, set serial prescan and parallel overscan to zero before correction.",
        dtype=bool,
        default=True,
    )

1196 

1197 

def _max_decayed_history(values, decay, num_terms, first_shift=0):
    """Return the element-wise maximum of shifted, decayed image copies.

    For each ``n`` in ``range(num_terms)``, ``values`` is shifted by
    ``n + first_shift`` columns towards higher column index and scaled
    by ``decay**n``. Columns left uncovered by the shift contribute
    zero to that term. The element-wise maximum over all terms is
    returned.

    This gives the same result as filling a zero-initialised
    ``(num_terms, ny, nx)`` stack with the terms and reducing it with
    `numpy.amax` along the first axis, but only two ``(ny, nx)``
    buffers are held in memory.

    Parameters
    ----------
    values : `numpy.ndarray`, (ny, nx)
        Image to shift and decay.
    decay : `float`
        Factor applied once per additional pixel of shift.
    num_terms : `int`
        Number of terms in the maximum. If zero, zeros are returned.
    first_shift : `int`, optional
        Column shift applied to the ``n = 0`` term.

    Returns
    -------
    history : `numpy.ndarray`, (ny, nx)
        Element-wise maximum over all terms (float64).
    """
    ny, nx = values.shape
    history = np.zeros((ny, nx))
    term = np.empty((ny, nx))
    for n in range(num_terms):
        shift = n + first_shift
        # Source columns that still land inside the image after the
        # shift; clamp at zero so an over-wide shift gives a zero term.
        n_keep = max(nx - shift, 0)
        term[:, :nx - n_keep] = 0.0
        term[:, nx - n_keep:] = values[:, :n_keep]*(decay**n)
        if n == 0:
            history[:, :] = term
        else:
            np.maximum(history, term, out=history)
    return history


class DeferredChargeTask(Task):
    """Task to correct an exposure for charge transfer inefficiency.

    This uses the methods described by Snyder et al. 2021, Journal of
    Astronomical Telescopes, Instruments, and Systems, 7,
    048002. doi:10.1117/1.JATIS.7.4.048002 (Snyder+21).
    """
    ConfigClass = DeferredChargeConfig
    _DefaultName = 'isrDeferredCharge'

    def run(self, exposure, ctiCalib, gains=None):
        """Correct deferred charge/CTI issues.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to correct the deferred charge on. Its image is
            modified in place.
        ctiCalib : `lsst.ip.isr.DeferredChargeCalib`
            Calibration object containing the charge transfer
            inefficiency model.
        gains : `dict` [`str`, `float`], optional
            A dictionary, keyed by amplifier name, of the gains to
            use. Required when the exposure is in adu; ignored
            otherwise.

        Returns
        -------
        exposure : `lsst.afw.image.Exposure`
            The corrected exposure (the same object that was passed in).

        Raises
        ------
        RuntimeError
            Raised if the exposure is in adu and ``gains`` is `None`.

        Notes
        -----
        The correction is carried out in electrons. The input units
        are read from the ``LSST ISR UNITS`` metadata key. If they are
        ``"adu"``, ``gains`` are applied through `gainContext` for the
        duration of the correction. Otherwise the image is assumed to
        already be in electrons and no gain is applied. When
        bootstrapping, callers presumably supply unit gains, so values
        remain numerically in adu. The units of the returned image are
        governed by `gainContext` (NOTE(review): confirm whether it
        restores the input units on exit).
        """
        image = exposure.getMaskedImage().image
        detector = exposure.getDetector()

        # Get the image units.
        imageUnits = exposure.getMetadata().get("LSST ISR UNITS")

        # The deferred charge correction assumes that everything is in
        # electron units. Make it so:
        applyGains = False
        if imageUnits == "adu":
            applyGains = True

        # If we need to convert the image to electrons, check that gains
        # were supplied. CTI should not be solved or corrected without
        # supplied gains.
        if applyGains and gains is None:
            raise RuntimeError("No gains supplied for deferred charge correction.")

        with gainContext(exposure, image, apply=applyGains, gains=gains, isTrimmed=False):
            # Both the image and the overscan are in electron units.
            for amp in detector.getAmplifiers():
                ampName = amp.getName()

                ampImage = image[amp.getRawBBox()]
                if self.config.zeroUnusedPixels:
                    # We don't apply overscan subtraction, so zero these
                    # out for now.
                    ampImage[amp.getRawParallelOverscanBBox()].array[:, :] = 0.0
                    ampImage[amp.getRawSerialPrescanBBox()].array[:, :] = 0.0

                # The algorithm expects that the readout corner is in
                # the lower left corner. Flip it to be so:
                ampData = self.flipData(ampImage.array, amp)

                # The local-offset stage is skipped when the drift
                # scale is not positive.
                if ctiCalib.driftScale[ampName] > 0.0:
                    correctedAmpData = self.local_offset_inverse(ampData,
                                                                 ctiCalib.driftScale[ampName],
                                                                 ctiCalib.decayTime[ampName],
                                                                 self.config.nPixelOffsetCorrection)
                else:
                    correctedAmpData = ampData.copy()

                correctedAmpData = self.local_trap_inverse(correctedAmpData,
                                                           ctiCalib.serialTraps[ampName],
                                                           ctiCalib.globalCti[ampName],
                                                           self.config.nPixelTrapCorrection)

                # Undo flips here. The method is symmetric.
                correctedAmpData = self.flipData(correctedAmpData, amp)
                image[amp.getRawBBox()].array[:, :] = correctedAmpData[:, :]

        return exposure

    @staticmethod
    def flipData(ampData, amp):
        """Flip data array such that readout corner is at lower-left.

        Parameters
        ----------
        ampData : `numpy.ndarray`, (ny, nx)
            Image data to flip.
        amp : `lsst.afw.cameraGeom.Amplifier`
            Amplifier to get readout corner information.

        Returns
        -------
        ampData : `numpy.ndarray`, (ny, nx)
            Flipped image data. Applying the same flip twice restores
            the original orientation.
        """
        X_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: True,
                  ReadoutCorner.UL: False,
                  ReadoutCorner.UR: True}
        Y_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: False,
                  ReadoutCorner.UL: True,
                  ReadoutCorner.UR: True}

        if X_FLIP[amp.getReadoutCorner()]:
            ampData = np.fliplr(ampData)
        if Y_FLIP[amp.getReadoutCorner()]:
            ampData = np.flipud(ampData)

        return ampData

    @staticmethod
    def local_offset_inverse(inputArr, drift_scale, decay_time, num_previous_pixels=15):
        r"""Remove CTI effects from local offsets.

        This implements equation 10 of Snyder+21. For an image with
        CTI, s'(m, n), the correction factor is equal to the maximum
        value of the set of:

        .. code-block::

            {A_L s'(m, n - j) exp(-j t / \tau_L)}_j=0^jmax

        Negative pixel values are clipped to zero before forming the
        set, and each step in ``j`` is one pixel transfer.

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (ny, nx)
            Input image data to correct.
        drift_scale : `float`
            Drift scale (Snyder+21 A_L value) to use in correction.
        decay_time : `float`
            Decay time (Snyder+21 \tau_L) of the correction.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction. As the
            CTI has an exponential decay, this essentially truncates
            the correction where that decay scales the input charge to
            near zero.

        Returns
        -------
        outputArr : `numpy.ndarray`, (ny, nx)
            Corrected image data.
        """
        r = np.exp(-1/decay_time)

        # Term j is A_L * max(0, s') shifted j pixels along the serial
        # direction and decayed by r**j; take the maximum over j.
        scaled = drift_scale*np.maximum(0, inputArr)
        Linv = _max_decayed_history(scaled, r, num_previous_pixels)

        outputArr = inputArr - Linv

        return outputArr

    @staticmethod
    def local_trap_inverse(inputArr, trap, global_cti=0.0, num_previous_pixels=6):
        r"""Apply localized trapping inverse operator to pixel signals.

        This implements equation 13 of Snyder+21. With
        ``c = trap.capture(max(0, s'))`` and ``r = exp(-1/emission_time)``,
        the trap occupancy seen by pixel n is estimated as

        .. code-block::

            occ(m, n) = max_j {c(m, n - j - 1) r^j}_j=0^jmax

        with ``jmax = num_previous_pixels - 1``. The captured charge is
        ``C = max(0, c - r occ)`` and the released charge is
        ``R = (1 - r) occ``. The corrected image is
        ``s' - (1 - b) (R - C)``.

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (ny, nx)
            Input image data to correct.
        trap : `lsst.ip.isr.SerialTrap`
            Serial trap describing the capture and release of charge.
        global_cti: `float`
            Mean charge transfer inefficiency, b from Snyder+21.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction.

        Returns
        -------
        outputArr : `numpy.ndarray`, (ny, nx)
            Corrected image data.
        """
        a = 1 - global_cti
        r = np.exp(-1/trap.emission_time)

        # The capture estimate depends only on the input image, so it is
        # evaluated once and reused for the occupancy and capture terms.
        captured = trap.capture(np.maximum(0, inputArr))

        # Estimate trap occupancies during readout: charge captured from
        # the preceding pixels (starting one pixel back), decayed by the
        # emission rate.
        trap_occupancy = _max_decayed_history(captured, r, num_previous_pixels, first_shift=1)

        # Estimate captured charge
        C = captured - trap_occupancy*r
        C[C < 0] = 0.

        # Estimate released charge
        R = np.zeros(inputArr.shape)
        R[:, 1:] = trap_occupancy[:, 1:]*(1-r)
        T = R - C

        outputArr = inputArr - a*T

        return outputArr