Coverage for python/lsst/ip/isr/deferredCharge.py: 16%
270 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 12:54 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 12:54 +0000
1# This file is part of ip_isr.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ('DeferredChargeConfig', 'DeferredChargeTask', 'SerialTrap', 'DeferredChargeCalib')
24import numpy as np
25from astropy.table import Table
27from lsst.afw.cameraGeom import ReadoutCorner
28from lsst.pex.config import Config, Field
29from lsst.pipe.base import Task
30from .isrFunctions import gainContext
31from .calibType import IsrCalib
33import scipy.interpolate as interp
class SerialTrap():
    """Represents a serial register trap.

    Parameters
    ----------
    size : `float`
        Size of the charge trap, in electrons.
    emission_time : `float`
        Trap emission time constant, in number of transfers: the
        fraction of trapped charge released per transfer is
        ``1 - exp(-1/emission_time)``.
    pixel : `int`
        Serial pixel location of the trap, including the prescan.
    trap_type : `str`
        Type of trap capture to use. Should be one of ``linear``,
        ``logistic``, or ``spline``.
    coeffs : `list` [`float`]
        Coefficients for the capture process. Linear traps need one
        coefficient, logistic traps need two, and spline based traps
        need to have an even number of coefficients that can be split
        into their spline locations and values.

    Raises
    ------
    ValueError
        Raised if the specified parameters are out of expected range.
    """

    def __init__(self, size, emission_time, pixel, trap_type, coeffs):
        if size < 0.0:
            raise ValueError('Trap size must be greater than or equal to 0.')
        self.size = size

        if emission_time <= 0.0:
            raise ValueError('Emission time must be greater than 0.')
        # NaN compares False against everything, so it slips past the
        # check above and needs its own test.
        if np.isnan(emission_time):
            raise ValueError('Emission time must be real-valued, not NaN')
        self.emission_time = emission_time

        if int(pixel) != pixel:
            raise ValueError('Fraction value for pixel not allowed.')
        self.pixel = int(pixel)

        self.trap_type = trap_type
        self.coeffs = coeffs

        if self.trap_type not in ('linear', 'logistic', 'spline'):
            raise ValueError(f'Unknown trap type: {self.trap_type}')

        if self.trap_type == 'spline':
            # Note that ``spline`` is actually a piecewise linear interpolation
            # in the model and the application, and not a true spline.
            centers, values = np.split(np.array(self.coeffs, dtype=np.float64), 2)
            # Drop every (center, value) pair in which either entry is NaN
            # (e.g. padding introduced when the calibration was tabulated).
            good = ~(np.isnan(centers) | np.isnan(values))
            centers = centers[good]
            values = values[good]
            if centers.size == 0:
                raise ValueError('Spline trap requires at least one non-NaN (center, value) pair.')
            self.interp = interp.interp1d(
                centers,
                values,
                bounds_error=False,
                # Clamp to the end values outside the tabulated range.
                fill_value=(values[0], values[-1]),
            )

        self._trap_array = None
        self._trapped_charge = None

    def __eq__(self, other):
        # A trap is equal to another trap if all of the initialization
        # parameters are equal. All other properties are only filled
        # during use, and are not persisted into the calibration.
        if not isinstance(other, SerialTrap):
            return NotImplemented
        if self.size != other.size:
            return False
        if self.emission_time != other.emission_time:
            return False
        if self.pixel != other.pixel:
            return False
        if self.trap_type != other.trap_type:
            return False
        # Compare coefficients element-wise so lists and numpy arrays are
        # handled the same way, and matching NaN entries compare equal.
        return bool(np.array_equal(np.asarray(self.coeffs, dtype=np.float64),
                                   np.asarray(other.coeffs, dtype=np.float64),
                                   equal_nan=True))

    @property
    def trap_array(self):
        """Per-pixel trap capacity array, or None before `initialize`."""
        return self._trap_array

    @property
    def trapped_charge(self):
        """Per-pixel currently trapped charge, or None before `initialize`."""
        return self._trapped_charge

    def initialize(self, ny, nx, prescan_width):
        """Initialize trapping arrays for simulated readout.

        Parameters
        ----------
        ny : `int`
            Number of rows to simulate.
        nx : `int`
            Number of columns to simulate.
        prescan_width : `int`
            Additional transfers due to prescan.

        Raises
        ------
        ValueError
            Raised if the trap falls outside of the image, i.e. if
            ``pixel`` is not in ``[0, nx + prescan_width)``.
        """
        nTotal = nx + prescan_width
        if self.pixel < 0 or self.pixel >= nTotal:
            raise ValueError(f'Trap location {self.pixel} must be in the range [0, {nTotal}).')

        self._trap_array = np.zeros((ny, nTotal))
        self._trap_array[:, self.pixel] = self.size
        self._trapped_charge = np.zeros((ny, nTotal))

    def release_charge(self):
        """Release charge through exponential decay.

        Returns
        -------
        released_charge : `float`
            Charge released.
        """
        released_charge = self._trapped_charge*(1-np.exp(-1./self.emission_time))
        self._trapped_charge -= released_charge

        return released_charge

    def trap_charge(self, free_charge):
        """Perform charge capture using this trap's capture function.

        The charge held in each pixel is raised toward ``capture(free_charge)``
        but is never reduced here and never exceeds the trap capacity.

        Parameters
        ----------
        free_charge : `float`
            Charge available to be trapped.

        Returns
        -------
        captured_charge : `float`
            Amount of charge actually trapped.
        """
        captured_charge = (np.clip(self.capture(free_charge), self.trapped_charge, self._trap_array)
                           - self.trapped_charge)
        self._trapped_charge += captured_charge

        return captured_charge

    def capture(self, pixel_signals):
        """Trap capture function.

        Parameters
        ----------
        pixel_signals : `list` [`float`]
            Input pixel values.

        Returns
        -------
        captured_charge : `list` [`float`]
            Amount of charge captured from each pixel.

        Raises
        ------
        RuntimeError
            Raised if the trap type is invalid.
        """
        if self.trap_type == 'linear':
            scaling = self.coeffs[0]
            return np.minimum(self.size, np.asarray(pixel_signals)*scaling)
        elif self.trap_type == 'logistic':
            f0, k = (self.coeffs[0], self.coeffs[1])
            return self.size/(1.+np.exp(-k*(np.asarray(pixel_signals)-f0)))
        elif self.trap_type == 'spline':
            return self.interp(pixel_signals)
        else:
            raise RuntimeError(f"Invalid trap capture type: {self.trap_type}.")
class DeferredChargeCalib(IsrCalib):
    r"""Calibration containing deferred charge/CTI parameters.

    Parameters
    ----------
    **kwargs :
        Additional parameters to pass to parent constructor.

    Notes
    -----
    The charge transfer inefficiency attributes stored are:

    driftScale : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the local electronic
        offset drift scale parameter, A_L in Snyder+2021.
    decayTime : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the local electronic
        offset decay time, \tau_L in Snyder+2021.
    globalCti : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the mean global CTI
        parameter, b in Snyder+2021.
    serialTraps : `dict` [`str`, `lsst.ip.isr.SerialTrap`]
        A dictionary, keyed by amplifier name, containing a single
        serial trap for each amplifier.
    """
    _OBSTYPE = 'CTI'
    _SCHEMA = 'Deferred Charge'
    _VERSION = 1.0

    def __init__(self, **kwargs):
        self.driftScale = {}
        self.decayTime = {}
        self.globalCti = {}
        self.serialTraps = {}

        super().__init__(**kwargs)
        self.requiredAttributes.update(['driftScale', 'decayTime', 'globalCti', 'serialTraps'])

    def fromDetector(self, detector):
        """Read metadata parameters from a detector.

        No detector parameters are currently used by this calibration,
        so the calibration is returned unchanged.

        Parameters
        ----------
        detector : `lsst.afw.cameraGeom.detector`
            Input detector with parameters to use.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            This calibration, unchanged.
        """
        return self

    @classmethod
    def fromDict(cls, dictionary):
        """Construct a calibration from a dictionary of properties.

        Parameters
        ----------
        dictionary : `dict`
            Dictionary of properties, in the layout produced by `toDict`.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            Constructed calibration.

        Raises
        ------
        RuntimeError
            Raised if the supplied dictionary is for a different
            calibration.
        """
        calib = cls()

        if calib._OBSTYPE != dictionary['metadata']['OBSTYPE']:
            raise RuntimeError(f"Incorrect CTI supplied. Expected {calib._OBSTYPE}, "
                               f"found {dictionary['metadata']['OBSTYPE']}")

        calib.setMetadata(dictionary['metadata'])

        calib.driftScale = dictionary['driftScale']
        calib.decayTime = dictionary['decayTime']
        calib.globalCti = dictionary['globalCti']

        for ampName, ampTraps in dictionary['serialTraps'].items():
            calib.serialTraps[ampName] = SerialTrap(ampTraps['size'], ampTraps['emissionTime'],
                                                    ampTraps['pixel'], ampTraps['trap_type'],
                                                    ampTraps['coeffs'])
        calib.updateMetadata()
        return calib

    def toDict(self):
        """Return a dictionary containing the calibration properties.

        The dictionary should be able to be round-tripped through
        ``fromDict``.

        Returns
        -------
        dictionary : `dict`
            Dictionary of properties.
        """
        self.updateMetadata()
        outDict = {}
        outDict['metadata'] = self.getMetadata()

        outDict['driftScale'] = self.driftScale
        outDict['decayTime'] = self.decayTime
        outDict['globalCti'] = self.globalCti

        outDict['serialTraps'] = {}
        for ampName, trap in self.serialTraps.items():
            outDict['serialTraps'][ampName] = {'size': trap.size,
                                               'emissionTime': trap.emission_time,
                                               'pixel': trap.pixel,
                                               'trap_type': trap.trap_type,
                                               'coeffs': trap.coeffs}

        return outDict

    @staticmethod
    def _stripTrailingNans(coeffs):
        """Return ``coeffs`` as a list with its trailing NaN padding removed.

        Only the run of NaN values at the end is removed; NaN values
        followed by a non-NaN entry are kept. An all-NaN input yields an
        empty list.
        """
        coeffs = np.asarray(coeffs)
        keepIndices = np.flatnonzero(~np.isnan(coeffs))
        nKeep = keepIndices[-1] + 1 if keepIndices.size > 0 else 0
        return coeffs[:nKeep].tolist()

    @staticmethod
    def _checkTrapCoeffs(amp, trapType, coeffs):
        """Check that the number of coefficients suits the trap type.

        Raises
        ------
        ValueError
            Raised if the trap type is unknown, or the number of
            coefficients is invalid for it.
        """
        nCoeffs = len(coeffs)
        if trapType == 'linear':
            valid = nCoeffs >= 1
        elif trapType == 'logistic':
            valid = nCoeffs >= 2
        elif trapType == 'spline':
            # Spline coefficients are split into equal halves of
            # locations and values, so at least one pair is required.
            valid = nCoeffs >= 2 and nCoeffs % 2 == 0
        else:
            raise ValueError(f'Unknown trap type: {trapType}')
        if not valid:
            raise ValueError(f"CTI Amplifier {amp} coefficients for trap has illegal length {nCoeffs}.")

    @classmethod
    def fromTable(cls, tableList):
        """Construct calibration from a list of tables.

        This method uses the ``fromDict`` method to create the
        calibration, after constructing an appropriate dictionary from
        the input tables.

        Parameters
        ----------
        tableList : `list` [`astropy.table.Table`]
            List of tables to use to construct the CTI
            calibration. Two tables are expected in this list, the
            first containing the per-amplifier CTI parameters, and the
            second containing the parameters for serial traps.

        Returns
        -------
        calib : `lsst.ip.isr.DeferredChargeCalib`
            The calibration defined in the tables.

        Raises
        ------
        ValueError
            Raised if the trap type or trap coefficients are not
            defined properly.
        """
        ampTable = tableList[0]

        inDict = {}
        inDict['metadata'] = ampTable.meta

        amps = ampTable['AMPLIFIER']
        inDict['driftScale'] = dict(zip(amps, ampTable['DRIFT_SCALE']))
        inDict['decayTime'] = dict(zip(amps, ampTable['DECAY_TIME']))
        inDict['globalCti'] = dict(zip(amps, ampTable['GLOBAL_CTI']))

        inDict['serialTraps'] = {}
        trapTable = tableList[1]

        for amp, size, emissionTime, pixel, trapType, inCoeffs in zip(
                trapTable['AMPLIFIER'], trapTable['SIZE'], trapTable['EMISSION_TIME'],
                trapTable['PIXEL'], trapTable['TYPE'], trapTable['COEFFS']):
            # ``toTable`` pads every coefficient row with NaN up to a
            # common length; undo that padding here.
            coeffs = cls._stripTrailingNans(inCoeffs)
            cls._checkTrapCoeffs(amp, trapType, coeffs)

            inDict['serialTraps'][amp] = {'size': size,
                                          'emissionTime': emissionTime,
                                          'pixel': pixel,
                                          'trap_type': trapType,
                                          'coeffs': coeffs}

        return cls.fromDict(inDict)

    def toTable(self):
        """Construct a list of tables containing the information in this
        calibration.

        The list of tables should create an identical calibration
        after being passed to this class's fromTable method.

        Returns
        -------
        tableList : `list` [`astropy.table.Table`]
            List of tables containing the CTI calibration
            information. Two tables are generated for this list, the
            first containing the per-amplifier CTI parameters, and the
            second containing the parameters for serial traps.
        """
        tableList = []
        self.updateMetadata()

        # The amplifier list is taken from driftScale; decayTime and
        # globalCti must have entries for the same amplifiers.
        ampList = list(self.driftScale.keys())
        ampTable = Table({'AMPLIFIER': ampList,
                          'DRIFT_SCALE': [self.driftScale[amp] for amp in ampList],
                          'DECAY_TIME': [self.decayTime[amp] for amp in ampList],
                          'GLOBAL_CTI': [self.globalCti[amp] for amp in ampList],
                          })

        ampTable.meta = self.getMetadata().toDict()
        tableList.append(ampTable)

        ampList = []
        sizeList = []
        timeList = []
        pixelList = []
        typeList = []
        coeffList = []

        # All coefficient rows must share one length to form a table column.
        maxCoeffLength = max((len(trap.coeffs) for trap in self.serialTraps.values()), default=0)

        # Pack and pad the end of the coefficients with NaN values.
        for amp, trap in self.serialTraps.items():
            ampList.append(amp)
            sizeList.append(trap.size)
            timeList.append(trap.emission_time)
            pixelList.append(trap.pixel)
            typeList.append(trap.trap_type)

            # Convert to float first: NaN cannot be written into an
            # integer array, so integer coefficients would fail to pad.
            coeffs = np.asarray(trap.coeffs, dtype=np.float64)
            coeffs = np.pad(coeffs, (0, maxCoeffLength - len(coeffs)),
                            constant_values=np.nan)
            coeffList.append(coeffs.tolist())

        trapTable = Table({'AMPLIFIER': ampList,
                           'SIZE': sizeList,
                           'EMISSION_TIME': timeList,
                           'PIXEL': pixelList,
                           'TYPE': typeList,
                           'COEFFS': coeffList})

        tableList.append(trapTable)

        return tableList
class DeferredChargeConfig(Config):
    """Settings for deferred charge correction (`DeferredChargeTask`).
    """

    # Look-back window for the local electronic offset correction; the
    # task passes it as ``num_previous_pixels`` to ``local_offset_inverse``.
    nPixelOffsetCorrection = Field(
        doc="Number of prior pixels to use for local offset correction.",
        dtype=int,
        default=15,
    )
    # Look-back window for the serial trap correction; the task passes it
    # as ``num_previous_pixels`` to ``local_trap_inverse``.
    nPixelTrapCorrection = Field(
        doc="Number of prior pixels to use for trap correction.",
        dtype=int,
        default=6,
    )
    # Forwarded to ``gainContext`` by the task.
    useGains = Field(
        doc="If true, scale by the gain.",
        dtype=bool,
        default=False,
    )
    # When enabled, the raw parallel overscan and serial prescan regions
    # are overwritten with 0.0 in the exposure itself.
    zeroUnusedPixels = Field(
        doc="If true, set serial prescan and parallel overscan to zero before correction.",
        dtype=bool,
        default=False,
    )
class DeferredChargeTask(Task):
    """Task to correct an exposure for charge transfer inefficiency.

    This uses the methods described by Snyder et al. 2021, Journal of
    Astronomical Telescopes, Instruments, and Systems, 7,
    048002. doi:10.1117/1.JATIS.7.4.048002 (Snyder+21).
    """
    ConfigClass = DeferredChargeConfig
    _DefaultName = 'isrDeferredCharge'

    def run(self, exposure, ctiCalib, gains=None):
        """Correct deferred charge/CTI issues.

        The exposure image is corrected in place, one amplifier at a
        time. For each amplifier, the raw data are flipped so the readout
        corner is at the lower left. The local offset correction is
        applied only if the amplifier's ``driftScale`` is positive. The
        serial trap correction is then applied, and the data are flipped
        back.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to correct the deferred charge on.
        ctiCalib : `lsst.ip.isr.DeferredChargeCalib`
            Calibration object containing the charge transfer
            inefficiency model.
        gains : `dict` [`str`, `float`], optional
            A dictionary, keyed by amplifier name, of the gains to
            use. If gains is None and ``config.useGains`` is true, the
            nominal gains in the amplifier objects are used.

        Returns
        -------
        exposure : `lsst.afw.image.Exposure`
            The corrected exposure (the input object, modified in place).

        Raises
        ------
        KeyError
            Raised if ``ctiCalib`` has no entry for one of the detector's
            amplifiers.
        """
        image = exposure.getMaskedImage().image
        detector = exposure.getDetector()

        # If gains were supplied, they are used. If useGains is true but
        # no gains were supplied, fall back to the nominal gains listed in
        # the detector. If useGains is false, ``gains`` is passed on as-is
        # (possibly None) together with the false flag.
        # NOTE(review): assumes gainContext leaves the image unscaled in
        # that case -- confirm against isrFunctions.gainContext.
        if self.config.useGains and gains is None:
            gains = {amp.getName(): amp.getGain() for amp in detector.getAmplifiers()}

        with gainContext(exposure, image, self.config.useGains, gains):
            for amp in detector.getAmplifiers():
                ampName = amp.getName()

                ampImage = image[amp.getRawBBox()]
                if self.config.zeroUnusedPixels:
                    # We don't apply overscan subtraction, so zero these
                    # out for now.
                    ampImage[amp.getRawParallelOverscanBBox()].array[:, :] = 0.0
                    ampImage[amp.getRawSerialPrescanBBox()].array[:, :] = 0.0

                # The algorithm expects that the readout corner is in
                # the lower left corner. Flip it to be so:
                ampData = self.flipData(ampImage.array, amp)

                if ctiCalib.driftScale[ampName] > 0.0:
                    correctedAmpData = self.local_offset_inverse(ampData,
                                                                 ctiCalib.driftScale[ampName],
                                                                 ctiCalib.decayTime[ampName],
                                                                 self.config.nPixelOffsetCorrection)
                else:
                    correctedAmpData = ampData.copy()

                correctedAmpData = self.local_trap_inverse(correctedAmpData,
                                                           ctiCalib.serialTraps[ampName],
                                                           ctiCalib.globalCti[ampName],
                                                           self.config.nPixelTrapCorrection)

                # Undo flips here. The method is symmetric.
                correctedAmpData = self.flipData(correctedAmpData, amp)
                image[amp.getRawBBox()].array[:, :] = correctedAmpData[:, :]

        return exposure

    @staticmethod
    def flipData(ampData, amp):
        """Flip data array such that readout corner is at lower-left.

        Each flip is its own inverse, so applying this twice with the
        same amplifier restores the original orientation.

        Parameters
        ----------
        ampData : `numpy.ndarray`, (ny, nx)
            Image data to flip.
        amp : `lsst.afw.cameraGeom.Amplifier`
            Amplifier to get readout corner information.

        Returns
        -------
        ampData : `numpy.ndarray`, (ny, nx)
            Flipped image data. This may be a view of the input, since
            ``numpy.fliplr``/``numpy.flipud`` return views.
        """
        X_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: True,
                  ReadoutCorner.UL: False,
                  ReadoutCorner.UR: True}
        Y_FLIP = {ReadoutCorner.LL: False,
                  ReadoutCorner.LR: False,
                  ReadoutCorner.UL: True,
                  ReadoutCorner.UR: True}

        readoutCorner = amp.getReadoutCorner()
        if X_FLIP[readoutCorner]:
            ampData = np.fliplr(ampData)
        if Y_FLIP[readoutCorner]:
            ampData = np.flipud(ampData)

        return ampData

    @staticmethod
    def local_offset_inverse(inputArr, drift_scale, decay_time, num_previous_pixels=15):
        r"""Remove CTI effects from local offsets.

        This implements equation 10 of Snyder+21. For an image with
        CTI, s'(m, n), the correction factor is equal to the maximum
        value of the set of:

        .. code-block::

            {A_L s'(m, n - j) exp(-j t / \tau_L)}_j=0^jmax

        Negative input values are treated as zero when forming the
        correction.

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (ny, nx)
            Input image data to correct, with serial transfers along
            the last axis.
        drift_scale : `float`
            Drift scale (Snyder+21 A_L value) to use in correction.
        decay_time : `float`
            Decay time (Snyder+21 \tau_L) of the correction.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction. As the
            CTI has an exponential decay, this essentially truncates
            the correction where that decay scales the input charge to
            near zero.

        Returns
        -------
        outputArr : `numpy.ndarray`, (ny, nx)
            Corrected image data.
        """
        r = np.exp(-1/decay_time)
        Ny, Nx = inputArr.shape

        # j = 0 term:
        offset = np.zeros((num_previous_pixels, Ny, Nx))
        offset[0, :, :] = drift_scale*np.maximum(0, inputArr)

        # j = 1..jmax terms: pixel n contributes to pixel n + j.
        for n in range(1, num_previous_pixels):
            offset[n, :, n:] = drift_scale*np.maximum(0, inputArr[:, :-n])*(r**n)

        Linv = np.amax(offset, axis=0)
        outputArr = inputArr - Linv

        return outputArr

    @staticmethod
    def local_trap_inverse(inputArr, trap, global_cti=0.0, num_previous_pixels=6):
        r"""Apply localized trapping inverse operator to pixel signals.

        This implements equation 13 of Snyder+21. With ``c(s)`` the trap
        capture function applied to the non-negative signal ``s'``, and
        ``r = exp(-1 / emission_time)``, the steps are:

        - The trap occupancy at each pixel is estimated as the maximum
          over ``j = 0..num_previous_pixels-1`` of
          ``c(s'(m, n - j - 1)) r^j``.
        - The captured charge is
          ``C = max(0, c(s'(m, n)) - occupancy * r)``.
        - The released charge is ``R = occupancy * (1 - r)``, which is zero
          for the first serial pixel.
        - The output is ``s' - (1 - b) (R - C)``.

        Parameters
        ----------
        inputArr : `numpy.ndarray`, (ny, nx)
            Input image data to correct, with serial transfers along
            the last axis.
        trap : `lsst.ip.isr.SerialTrap`
            Serial trap describing the capture and release of charge.
        global_cti : `float`, optional
            Mean charge transfer inefficiency, b from Snyder+21.
        num_previous_pixels : `int`, optional
            Number of previous pixels to use for correction.

        Returns
        -------
        outputArr : `numpy.ndarray`, (ny, nx)
            Corrected image data.
        """
        Ny, Nx = inputArr.shape
        a = 1 - global_cti
        r = np.exp(-1/trap.emission_time)

        # The capture function depends only on the input, so evaluate it
        # once rather than once per look-back step.
        captured = trap.capture(np.maximum(0, inputArr))

        # Estimate trap occupancies during readout
        trap_occupancy = np.zeros((num_previous_pixels, Ny, Nx))
        for n in range(num_previous_pixels):
            trap_occupancy[n, :, n+1:] = captured[:, :-(n+1)]*(r**n)
        trap_occupancy = np.amax(trap_occupancy, axis=0)

        # Estimate captured charge
        C = captured - trap_occupancy*r
        C[C < 0] = 0.

        # Estimate released charge
        R = np.zeros(inputArr.shape)
        R[:, 1:] = trap_occupancy[:, 1:]*(1-r)
        T = R - C

        outputArr = inputArr - a*T

        return outputArr