Coverage for python/lsst/meas/algorithms/accumulator_mean_stack.py: 8%
72 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 18:13 -0800
1# This file is part of meas_algorithms.
2#
3# LSST Data Management System
4# This product includes software developed by the
5# LSST Project (http://www.lsst.org/).
6# See COPYRIGHT file at the top of the source tree.
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <https://www.lsstcorp.org/LegalNotices/>.
21#
22import numpy as np
__all__ = ['AccumulatorMeanStack']


class AccumulatorMeanStack:
    """Stack masked images.

    Images are accumulated one at a time with `add_masked_image` into
    running weighted sums; `fill_stacked_masked_image` then writes the
    weighted mean, its variance and the combined mask into an output image.
    Filling does not modify the accumulated sums, so it may be called more
    than once and more images may be added in between.

    Parameters
    ----------
    shape : `tuple`
        Shape of the input and output images.
    bit_mask_value : `int`
        Bit mask to flag for "bad" inputs that should not be stacked.
    mask_threshold_dict : `dict` [`int`: `float`], optional
        Dictionary of mapping from bit number to threshold for flagging.
        Only bad bits (in bit_mask_value) which mask fractional weight
        greater than this threshold will be flagged in the output image.
    mask_map : `list` [`tuple`], optional
        Mapping from input image bits to aggregated coadd bits.
    no_good_pixels_mask : `int`, optional
        Bit mask to set when there are no good pixels in the stack.
        If not set then will set coadd masked image 'NO_DATA' bit.
    calc_error_from_input_variance : `bool`, optional
        Calculate the error from the input variance?
    compute_n_image : `bool`, optional
        Calculate the n_image map as well as stack?
    """
    def __init__(self, shape,
                 bit_mask_value, mask_threshold_dict=None,
                 mask_map=None, no_good_pixels_mask=None,
                 calc_error_from_input_variance=True,
                 compute_n_image=False):
        # None defaults avoid sharing one mutable dict/list across instances.
        if mask_threshold_dict is None:
            mask_threshold_dict = {}

        self.shape = shape
        self.bit_mask_value = bit_mask_value
        # Own copy, so later mutation by the caller cannot leak in.
        self.mask_map = list(mask_map) if mask_map is not None else []
        self.no_good_pixels_mask = no_good_pixels_mask
        self.calc_error_from_input_variance = calc_error_from_input_variance
        self.compute_n_image = compute_n_image

        # Only track threshold bits that are in the bad bit_mask_value.
        self.mask_threshold_dict = {
            bit: threshold
            for bit, threshold in mask_threshold_dict.items()
            if (self.bit_mask_value & 2**bit) > 0
        }

        # sum_weight holds the sum of weights for each pixel.
        self.sum_weight = np.zeros(shape, dtype=np.float64)
        # sum_wdata holds the sum of weight*data for each pixel.
        self.sum_wdata = np.zeros(shape, dtype=np.float64)

        if calc_error_from_input_variance:
            # sum_w2var holds the sum of weight**2 * variance for each pixel.
            self.sum_w2var = np.zeros(shape, dtype=np.float64)
        else:
            # sum_weight2 holds the sum of weight**2 for each pixel.
            self.sum_weight2 = np.zeros(shape, dtype=np.float64)
            # sum_wdata2 holds the sum of weight * data**2 for each pixel.
            self.sum_wdata2 = np.zeros(shape, dtype=np.float64)

        # OR of the mask bits of all good (stacked) input pixels.
        self.or_mask = np.zeros(shape, dtype=np.int64)
        # Per threshold bit: total weight of inputs rejected because of it.
        self.rejected_weights_by_bit = {
            bit: np.zeros(shape, dtype=np.float64)
            for bit in self.mask_threshold_dict
        }

        # OR of threshold bits seen on rejected input pixels (used by mask_map).
        self.masked_pixels_mask = np.zeros(shape, dtype=np.int64)

        if self.compute_n_image:
            self.n_image = np.zeros(shape, dtype=np.int32)

    def add_masked_image(self, masked_image, weight=1.0):
        """Add a masked image to the stack.

        Parameters
        ----------
        masked_image : `lsst.afw.image.MaskedImage`
            Masked image to add to the stack.
        weight : `float`, optional
            Weight applied to every good pixel of this image in the
            weighted mean.
        """
        mask_array = masked_image.mask.array
        image_array = masked_image.image.array

        # A pixel is stacked only if it carries none of the bad bits and its
        # image value is finite; a NaN/inf would otherwise poison the sums.
        good_pixels = (((mask_array & self.bit_mask_value) == 0)
                       & np.isfinite(image_array))
        good_data = image_array[good_pixels]

        self.sum_weight[good_pixels] += weight
        self.sum_wdata[good_pixels] += weight*good_data

        if self.compute_n_image:
            self.n_image[good_pixels] += 1

        if self.calc_error_from_input_variance:
            self.sum_w2var[good_pixels] += (weight**2.)*masked_image.variance.array[good_pixels]
        else:
            self.sum_weight2[good_pixels] += weight**2.
            self.sum_wdata2[good_pixels] += weight*(good_data**2.)

        # Mask bits are propagated for good pixels
        self.or_mask[good_pixels] |= mask_array[good_pixels]

        # Bad pixels are only tracked if they cross a threshold
        for bit in self.mask_threshold_dict:
            bad_pixels = ((mask_array & 2**bit) > 0)
            self.rejected_weights_by_bit[bit][bad_pixels] += weight
            self.masked_pixels_mask[bad_pixels] |= 2**bit

    def fill_stacked_masked_image(self, stacked_masked_image):
        """Fill the stacked mask image after accumulation.

        The accumulator state is left untouched, so this can be called
        repeatedly (e.g. after adding further images).

        Parameters
        ----------
        stacked_masked_image : `lsst.afw.image.MaskedImage`
            Total masked image.
        """
        # Work on a copy so the accumulated mask is not altered by filling.
        or_mask = self.or_mask.copy()

        # Let the NaNs/infs from zero-weight pixels through silently; those
        # pixels are flagged with the no-good-pixels bit below.
        with np.errstate(divide="ignore", invalid="ignore"):
            # The image plane is sum(weight*data)/sum(weight); keep the
            # float64 mean for the variance computation below.
            mean = self.sum_wdata/self.sum_weight
            stacked_masked_image.image.array[:, :] = mean

            if self.calc_error_from_input_variance:
                mean_var = self.sum_w2var/(self.sum_weight**2.)
            else:
                sum_weight_sq = self.sum_weight**2.
                # Compute the biased estimator
                variance = self.sum_wdata2/self.sum_weight - mean**2.
                # De-bias
                variance *= sum_weight_sq/(sum_weight_sq - self.sum_weight2)

                # Compute the mean variance
                mean_var = variance*self.sum_weight2/sum_weight_sq

            stacked_masked_image.variance.array[:, :] = mean_var

            # Propagate bits when the rejected fraction of the total weight
            # crosses the threshold (NaN fractions compare False).
            for bit, threshold in self.mask_threshold_dict.items():
                rejected_weight = self.rejected_weights_by_bit[bit]
                rejected_fraction = rejected_weight/(self.sum_weight + rejected_weight)
                or_mask[rejected_fraction > threshold] |= 2**bit

        # Map mask planes to new bits for pixels that had at least one
        # bad input rejected and are in the mask_map.
        for mask_tuple in self.mask_map:
            input_bits, output_bits = mask_tuple[0], mask_tuple[1]
            or_mask[(self.masked_pixels_mask & input_bits) > 0] |= output_bits

        stacked_masked_image.mask.array[:, :] = or_mask

        if self.no_good_pixels_mask is None:
            # The mask plane dictionary lives on the MaskedImage's Mask.
            mask_dict = stacked_masked_image.mask.getMaskPlaneDict()
            no_good_pixels_mask = 2**(mask_dict['NO_DATA'])
        else:
            no_good_pixels_mask = self.no_good_pixels_mask

        bad_pixels = (self.sum_weight <= 0.0)
        stacked_masked_image.mask.array[bad_pixels] |= no_good_pixels_mask

    @staticmethod
    def stats_ctrl_to_threshold_dict(stats_ctrl):
        """Convert stats control to threshold dict.

        Parameters
        ----------
        stats_ctrl : `lsst.afw.math.StatisticsControl`

        Returns
        -------
        threshold_dict : `dict`
            Dict mapping from bit to propagation threshold, for each of
            the 64 bit positions 0-63.
        """
        return {bit: stats_ctrl.getMaskPropagationThreshold(bit) for bit in range(64)}