Coverage for tests/test_variance_plane.py: 8%
329 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 02:10 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 02:10 -0700
1# This file is part of ip_isr.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21#
22"""The utility function under test is part of the `meas_algorithms` package.
23This unit test has been relocated here to avoid circular dependencies.
24"""
26import re
27import unittest
28from contextlib import nullcontext
30import galsim
31import lsst.utils.tests
32import matplotlib.colors as mcolors
33import matplotlib.pyplot as plt
34import numpy as np
35from lsst.ip.isr.isrMock import IsrMock
36from lsst.meas.algorithms import remove_signal_from_variance
37from lsst.utils.tests import methodParametersProduct
38from matplotlib.legend_handler import HandlerTuple
39from matplotlib.patheffects import withStroke
40from matplotlib.ticker import FixedLocator, FuncFormatter
# Flip this flag to True to write out a figure comparing the variance plane
# before and after the signal correction, for one representative parameter
# combination of the test below.
SAVE_PLOT: bool = False
def outline_effect(lw, alpha=0.8):
    """Build a white halo path effect that keeps text legible on busy
    backgrounds.

    Parameters
    ----------
    lw : `float`
        Width of the white stroke drawn behind the artist.
    alpha : `float`, optional
        Opacity of that stroke.

    Returns
    -------
    `list` of `matplotlib.patheffects.withStroke`
        A one-element list, ready to pass as ``path_effects=``.
    """
    halo = withStroke(linewidth=lw, foreground="white", alpha=alpha)
    return [halo]
class CustomHandler(HandlerTuple):
    """Legend handler for grouped (tuple) entries that pins every artist it
    creates to the transform supplied by the legend."""

    def create_artists(self, *args):
        # The transform for the legend's handle box is the final positional
        # argument Matplotlib passes to ``create_artists``.
        handlebox_transform = args[-1]
        created = super().create_artists(*args)
        for artist in created:
            artist.set_transform(handlebox_transform)
        return created
def get_valid_color(handle):
    """Extract a usable (visible) color from a Matplotlib legend handle.

    The getters ``get_facecolor``, ``get_edgecolor`` and ``get_color`` are
    tried in that order. The first one that exists and yields a visible color
    wins. Empty color arrays, fully transparent RGBA values and the string
    ``"none"`` are skipped.

    Parameters
    ----------
    handle : `matplotlib.artist.Artist`
        The handle from which to extract the color.

    Returns
    -------
    color : `str`, `tuple` or `numpy.ndarray`
        The color extracted from the handle, or 'default' if no valid color is
        found.
    """
    for attr in ("get_facecolor", "get_edgecolor", "get_color"):
        getter = getattr(handle, attr, None)
        if getter is None:
            continue
        color = getter()
        if isinstance(color, np.ndarray):
            if color.size == 0:
                # No color at all (e.g. an empty (0, 4) color array); try the
                # next getter instead of returning an empty array.
                continue
            if color.ndim > 1:
                # Collections return one RGBA row per element; use the first.
                # A flat RGBA array must NOT be indexed, or only R is kept.
                color = color[0]
        if isinstance(color, str):
            if color.lower() == "none":
                # Explicitly invisible color; keep searching.
                continue
        elif len(color) == 4 and color[3] == 0:
            # Fully transparent RGBA; keep searching.
            continue
        return color
    return "default"  # No valid color found.
def get_emptier_side(ax):
    """Analyze a matplotlib Axes object to determine which side (left or right)
    has more whitespace, considering cases where artists' bounding boxes span
    both sides of the midpoint.

    All geometry is measured in figure inches. The midpoint is the horizontal
    centre of the Axes' own bounding box. Each visible child artist's bounding
    box is clipped to that box, and its area is then split between the two
    halves in proportion to how much of it lies on each side.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to analyze.

    Returns
    -------
    more_whitespace_side : `str`
        'left' if the left side has more whitespace (or the sides tie), or
        'right' if the right side does.
    """
    # Use one coordinate system for both the midpoint and the artists. The
    # previous version took the midpoint from ``get_xlim()`` (data units) but
    # compared it with bounding boxes expressed in inches.
    to_inches = ax.figure.dpi_scale_trans.inverted()
    axes_bbox = ax.get_window_extent().transformed(to_inches)
    midpoint = (axes_bbox.xmin + axes_bbox.xmax) / 2

    left_area, right_area = 0.0, 0.0
    for artist in ax.get_children():
        # Skip if artist is invisible or lacks a bounding box.
        if not artist.get_visible() or not hasattr(artist, "get_window_extent"):
            continue
        bbox = artist.get_window_extent().transformed(to_inches)

        # Only the part of the artist inside the plotting area occupies it.
        x0 = max(bbox.xmin, axes_bbox.xmin)
        x1 = min(bbox.xmax, axes_bbox.xmax)
        y0 = max(bbox.ymin, axes_bbox.ymin)
        y1 = min(bbox.ymax, axes_bbox.ymax)
        if x1 <= x0 or y1 <= y0:
            continue

        # Split the clipped area at the midpoint. An artist lying entirely on
        # one side contributes all of its area to that side.
        height = y1 - y0
        left_area += max(0.0, min(x1, midpoint) - x0) * height
        right_area += max(0.0, x1 - max(x0, midpoint)) * height

    # The side with less occupied area has more whitespace.
    return "left" if left_area <= right_area else "right"
def adjust_legend_with_groups(ax, combine_groups, colors="default", yloc="upper", **kwargs):
    """Adjusts the legend of a given Axes object by combining specified handles
    based on provided groups, setting the marker location and text alignment
    based on the inferable emptier side, and optionally setting the text color
    of legend entries to a provided list or inferring colors from the handles.
    Additionally, allows specifying the vertical location of the legend within
    the plot.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object for which to adjust the legend.
    combine_groups : `list` of `int` or iterable of `int`
        A list that can contain a mix of individual integers and/or iterables
        (lists, tuples, or sets) of integers. An individual integer (Python or
        NumPy) specifies the index of a single legend entry. A non-empty
        iterable of integers specifies a group of indices to be combined into
        a single legend entry.
    colors : `list` of `str` or `tuple`, or `str`, optional
        Specifies the colors for the legend entries. This parameter can be:
        - A list of color specifications, where each element is a string (for
          named colors or hex values) or a tuple (for RGB or RGBA values). This
          list explicitly assigns colors to each legend entry post-combination.
          Elements equal to "default" leave that entry's color unchanged.
        - A single string value:
          - "match": Colors are inferred from the properties of the first
            handle in each group that corresponds to a non-white-space label.
            This aims to match the legend text color with the color of the
            plotted data.
          - "default": The function does not alter the default colors
            assigned by Matplotlib, preserving the automatic color assignment
            for all legend entries.
          - Any other string is used as a single color for every entry.
    yloc : `str`, optional
        The vertical location of the legend within the Axes. Valid options are
        'upper', 'lower', or 'middle' ('middle' maps to Matplotlib's 'center').
        This parameter is combined with the inferable emptier side ('left' or
        'right') to determine the legend's placement, e.g. 'upper right' or
        'center left'.
    **kwargs :
        Keyword arguments forwarded to the ``ax.legend`` function. They take
        precedence over the defaults chosen here (``handler_map``, ``loc``,
        ``fontsize``, ``frameon``, ``markerfirst``).

    Raises
    ------
    ValueError
        If an element of ``combine_groups`` is neither an integer nor a
        non-empty list, tuple, or set of integers.
    """
    handles, labels = ax.get_legend_handles_labels()
    new_handles = []
    new_labels = []

    # Only the literal string "match" triggers inference. The isinstance guard
    # avoids element-wise comparison when an array of colors is passed.
    infer_colors = isinstance(colors, str) and colors == "match"
    if infer_colors:
        colors = []

    for group in combine_groups:
        if isinstance(group, (list, tuple, set)):
            group = list(group)  # Sets are not indexable.
            if not group:
                raise ValueError("Empty group in 'combine_groups'")
            # The first non-white-space label represents the group. If every
            # label is blank, fall back to the first (empty) one.
            label_index = next((i for i in group if labels[i].strip()), group[0])
            combined_handle = tuple(handles[i] for i in group)
        elif isinstance(group, (int, np.integer)):
            label_index = group
            combined_handle = handles[group]
        else:
            raise ValueError("Invalid value in 'combine_groups'")
        new_handles.append(combined_handle)
        new_labels.append(labels[label_index])
        if infer_colors:
            # Infer the text color from the group's representative handle.
            colors.append(get_valid_color(handles[label_index]))

    # Determine the emptier side to decide legend placement and alignment.
    emptier_side = get_emptier_side(ax)
    # Matplotlib's legend ``loc`` strings use "center", not "middle".
    vertical = "center" if yloc == "middle" else yloc

    legend_kwargs = {
        "handler_map": {tuple: CustomHandler()},
        "loc": f"{vertical} {emptier_side}",
        "fontsize": 8,
        "frameon": False,
        # Markers go first (on the left) when the legend sits on the left.
        "markerfirst": emptier_side != "right",
    }
    # Caller-supplied options override the defaults instead of colliding.
    legend_kwargs.update(kwargs)
    legend = ax.legend(new_handles, new_labels, **legend_kwargs)

    # Right- or left-align the legend text based on the emptier side.
    texts = legend.get_texts()
    for text in texts:
        text.set_ha(emptier_side)

    # Set legend text colors if necessary.
    if isinstance(colors, str):
        if colors == "default":
            return
        # A single color string applies to every entry (previously it was
        # iterated character by character).
        colors = [colors] * len(texts)
    for text, color in zip(texts, colors):
        if not (isinstance(color, str) and color == "default"):
            text.set_color(color)
def adjust_tick_scale(ax, axis_label_templates):
    """Scales down tick labels to make them more readable and updates axis
    labels accordingly.

    For each requested axis, the current numeric tick labels are inspected.
    The largest power of 10 that divides every non-zero integer-valued tick is
    used as a common divisor; any non-integer tick forces a divisor of 1. Tick
    labels are then re-rendered divided by that factor, and the axis label is
    built from a template in which ``{scale}`` is replaced by the divisor.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to modify.
    axis_label_templates : `dict`
        Templates for axis labels, including '{scale}' for scale factor
        insertion. Keys should be one or more of axes names ('x', 'y', 'z') and
        values should be the corresponding label templates. When no scaling is
        applied, whitespace around '{scale}' and any '/{scale}' are removed.
    """

    def trailing_zeros(value):
        """Return the number of trailing zeros of an integer-valued number,
        or 0 for non-integer values (zeros after a decimal point such as in
        "1.50" are not powers of ten)."""
        if not value.is_integer():
            return 0
        digits = str(int(abs(value)))
        return len(digits) - len(digits.rstrip("0"))

    def format_tick(val, pos, divisor):
        """Formats tick labels using the determined divisor."""
        scaled = val / divisor
        return str(int(scaled)) if scaled.is_integer() else str(scaled)

    def tick_label_values(axis):
        """Return the numeric values of the axis' tick labels, skipping labels
        that do not parse as numbers (e.g. empty ones)."""
        values = []
        for label in getattr(ax, f"get_{axis}ticklabels")():
            # Matplotlib renders negatives with U+2212 MINUS SIGN, which
            # float() rejects.
            text = label.get_text().replace("\u2212", "-")
            try:
                values.append(float(text))
            except ValueError:
                continue
        return values

    for axis, label_template in axis_label_templates.items():
        # Power-of-10 divisor from the minimum number of trailing zeros among
        # the non-zero ticks; fall back to 1 when there are none.
        nonzero_values = [v for v in tick_label_values(axis) if v != 0]
        divisor = 10 ** min((trailing_zeros(v) for v in nonzero_values), default=0)

        axis_obj = getattr(ax, f"{axis}axis")
        # Bind ``divisor`` now. Formatters are called at draw time, after this
        # loop has finished, so a plain closure would make every axis use the
        # last axis' divisor.
        axis_obj.set_major_formatter(
            FuncFormatter(lambda val, pos, divisor=divisor: format_tick(val, pos, divisor))
        )

        # Ensure the tick positions remain unchanged despite the new
        # formatting.
        axis_obj.set_major_locator(FixedLocator(getattr(ax, f"get_{axis}ticks")()))

        # Prepare 'scale', empty if divisor <= 1.
        scale = f"{int(divisor)}" if divisor > 1 else ""

        # If 'scale' is empty, remove whitespace around "{scale}" in the
        # template. Also remove any trailing "/{scale}".
        if scale == "":
            label_template = re.sub(r"\s*{\s*scale\s*}\s*", "{scale}", label_template)
            label_template = label_template.replace("/{scale}", "")

        # Always strip remaining whitespace from the template.
        label_template = label_template.strip()

        # Set the formatted axis label.
        getattr(ax, f"set_{axis}label")(label_template.format(scale=scale), labelpad=8)
318class VariancePlaneTestCase(lsst.utils.tests.TestCase):
319 def setUp(self):
320 # Testing with a single detector that has 8 amplifiers in a 4x2
321 # configuration. Each amplifier measures 100x51 in dimensions.
322 config = IsrMock.ConfigClass()
323 config.isLsstLike = True
324 config.doAddBias = False
325 config.doAddDark = False
326 config.doAddFlat = False
327 config.doAddFringe = False
328 config.doGenerateImage = True
329 config.doGenerateData = True
330 config.doGenerateAmpDict = True
331 self.mock = IsrMock(config=config)
332 self.pixel_scale = 0.2 # arcsec/pixel
334 # Set the random seed for reproducibility.
335 random_seed = galsim.BaseDeviate(1905).raw() + 1
336 np.random.seed(random_seed)
337 self.rng = galsim.BaseDeviate(random_seed)
339 # Get the exposure, detector, and amps from the mock.
340 exposure = self.mock.getExposure()
341 detector = exposure.getDetector()
342 amps = detector.getAmplifiers()
344 # Set amp-related attributes for use in the test cases.
345 self.num_amps = len(amps)
346 self.amp_name_list = [amp.getName() for amp in amps]
347 table = str.maketrans("", "", ":,") # Remove ':' and ',' from names
348 self.amp_name_list_simplified = [name.translate(table) for name in self.amp_name_list]
350 # Get the bounding boxes for the exposure and amplifiers and convert
351 # them to galsim bounds.
352 exp_bbox = exposure.getBBox()
353 image_bounds = galsim.BoundsI(exp_bbox.minX, exp_bbox.maxX, exp_bbox.minY, exp_bbox.maxY)
354 self.amp_bbox_list = [amp.getBBox() for amp in amps]
355 self.amp_bounds_list = [galsim.BoundsI(b.minX, b.maxX, b.minY, b.maxY) for b in self.amp_bbox_list]
357 # Create a raw galsim image to potentially draw the sources onto. The
358 # exposure image that is passed to this method will be modified in
359 # place but it won't be used.
360 self.signal_free_raw_image = galsim.ImageF(exposure.image.array, bounds=image_bounds)
361 self.raw_image = self.signal_free_raw_image.copy()
363 # Define parameters for a mix of source types, including extended
364 # sources with assorted profiles as well as point sources simulated
365 # with minimal half-light radii to resemble hot pixels
366 # post-deconvolution. All flux values are given in electrons and
367 # half-light radii in pixels. The goal is for each amplifier to
368 # predominantly contain at least one source, enhancing the
369 # representativeness of test conditions.
370 source_params = [
371 {"type": "Sersic", "n": 3, "flux": 1.6e5, "half_light_radius": 3.5, "g1": -0.3, "g2": 0.2},
372 {"type": "Sersic", "n": 1, "flux": 9.3e5, "half_light_radius": 2.1, "g1": 0.25, "g2": 0.12},
373 {"type": "Sersic", "n": 4, "flux": 1.0e5, "half_light_radius": 1.1, "g1": 0.0, "g2": 0.0},
374 {"type": "Sersic", "n": 3, "flux": 1.1e6, "half_light_radius": 4.2, "g1": 0.0, "g2": 0.2},
375 {"type": "Sersic", "n": 5, "flux": 1.1e5, "half_light_radius": 3.6, "g1": 0.22, "g2": -0.05},
376 {"type": "Sersic", "n": 2, "flux": 4.3e5, "half_light_radius": 2.0, "g1": 0.0, "g2": 0.0},
377 {"type": "Sersic", "n": 6, "flux": 1.2e6, "half_light_radius": 11.0, "g1": -0.16, "g2": 0.7},
378 {"type": "Exponential", "flux": 1.3e6, "half_light_radius": 1.9, "g1": 0.3, "g2": -0.1},
379 {"type": "Exponential", "flux": 1.8e6, "half_light_radius": 5.0, "g1": 0.0, "g2": 0.14},
380 {"type": "Exponential", "flux": 6.6e6, "half_light_radius": 4.8, "g1": 0.26, "g2": 0.5},
381 {"type": "Exponential", "flux": 7.0e5, "half_light_radius": 3.1, "g1": -0.3, "g2": 0.0},
382 {"type": "DeVaucouleurs", "flux": 1.6e5, "half_light_radius": 3.5, "g1": 0.2, "g2": 0.4},
383 {"type": "DeVaucouleurs", "flux": 2.0e5, "half_light_radius": 1.6, "g1": -0.06, "g2": -0.2},
384 {"type": "DeVaucouleurs", "flux": 8.3e5, "half_light_radius": 5.1, "g1": 0.29, "g2": 0.0},
385 {"type": "DeVaucouleurs", "flux": 4.5e5, "half_light_radius": 2.5, "g1": 0.4, "g2": 0.3},
386 {"type": "DeVaucouleurs", "flux": 6.2e5, "half_light_radius": 4.9, "g1": -0.08, "g2": -0.01},
387 {"type": "Gaussian", "flux": 4.7e6, "half_light_radius": 2.5, "g1": 0.07, "g2": -0.35},
388 {"type": "Gaussian", "flux": 5.8e6, "half_light_radius": 3.1, "g1": 0.03, "g2": 0.4},
389 {"type": "Gaussian", "flux": 2.3e5, "half_light_radius": 0.5, "g1": 0.0, "g2": 0.0},
390 {"type": "Gaussian", "flux": 1.6e6, "half_light_radius": 3.0, "g1": 0.18, "g2": -0.29},
391 {"type": "Gaussian", "flux": 3.5e5, "half_light_radius": 4.6, "g1": 0.5, "g2": 0.35},
392 {"type": "Gaussian", "flux": 5.9e5, "half_light_radius": 9.5, "g1": 0.1, "g2": 0.55},
393 {"type": "Gaussian", "flux": 4.0e5, "half_light_radius": 1.0, "g1": 0.0, "g2": 0.0},
394 ]
396 # Mapping of profile types to their galsim constructors.
397 profile_constructors = {
398 "Sersic": galsim.Sersic,
399 "Exponential": galsim.Exponential,
400 "DeVaucouleurs": galsim.DeVaucouleurs,
401 "Gaussian": galsim.Gaussian,
402 }
404 # Generate random positions within exposure bounds, avoiding edges by a
405 # margin.
406 margin_x, margin_y = 0.05 * exp_bbox.width, 0.05 * exp_bbox.height
407 self.positions = np.random.uniform(
408 [exp_bbox.minX + margin_x, exp_bbox.minY + margin_y],
409 [exp_bbox.maxX - margin_x, exp_bbox.maxY - margin_y],
410 (len(source_params), 2),
411 ).tolist()
413 # Loop over the sources and draw them onto the image cutout by cutout.
414 for i, params in enumerate(source_params):
415 # Dynamically get constructor and remove type from params.
416 constructor = profile_constructors[params.pop("type")]
418 # Get shear parameters and remove them from params.
419 g1, g2 = params.pop("g1"), params.pop("g2")
421 # The extent of the cutout should be large enough to contain the
422 # entire object above the background level. Some empirical factor
423 # is used to mitigate artifacts.
424 half_extent = 10 * params["half_light_radius"] * (1 + 2 * np.sqrt(g1**2 + g2**2))
426 # Pass the remaining params to the constructor and apply shear.
427 galsim_object = constructor(**params).shear(galsim.Shear(g1=g1, g2=g2))
429 # Retrieve the position of the object.
430 x, y = self.positions[i]
431 pos = galsim.PositionD(x, y)
433 # Get the bounds of the sub-image based on the object position.
434 sub_image_bounds = galsim.BoundsI(
435 *map(int, [x - half_extent, x + half_extent, y - half_extent, y + half_extent])
436 )
438 # Identify the overlap region, which could be partially outside
439 # the image bounds.
440 sub_image_bounds = sub_image_bounds & self.raw_image.bounds
442 # Check that there is some overlap.
443 assert sub_image_bounds.isDefined(), "No overlap with image bounds"
445 # Get the sub-image cutout.
446 sub_image = self.raw_image[sub_image_bounds]
448 # Draw the object onto the image within the the sub-image bounds.
449 galsim_object.drawImage(
450 image=sub_image,
451 offset=pos - sub_image.true_center,
452 method="real_space", # Saves memory, usable w/o convolution
453 add_to_image=True, # Add flux to existing image
454 scale=self.pixel_scale,
455 )
457 def tearDown(self):
458 del self.mock
459 del self.raw_image
460 del self.signal_free_raw_image
462 def buildExposure(
463 self,
464 average_gain,
465 gain_sigma_factor,
466 sky_level,
467 add_signal=True,
468 ):
469 """Build and return an exposure with different types of simulated
470 source profiles and a background sky level. It's intended for testing
471 and analysis, providing a way to generate exposures with controlled
472 conditions.
474 Parameters
475 ----------
476 average_gain : `float`
477 The average gain value of amplifiers in e-/ADU.
478 gain_sigma_factor : float
479 The standard deviation of the gain values as a factor of the
480 ``average_gain``.
481 sky_level : `float`
482 The background sky level in e-/arcsec^2.
483 add_signal : `bool`, optional
484 Whether to add sources to the exposure. If set to False, the
485 exposure will only contain background noise.
487 Returns
488 -------
489 exposure : `~lsst.afw.image.Exposure`
490 An exposure object with simulated sources and background. The units
491 are in detector counts (ADU).
492 """
494 # Get the exposure from the mock.
495 exposure = self.mock.getExposure()
497 # Convert the background sky level from e-/arcsec^2 to e-/pixel.
498 self.background = sky_level * self.pixel_scale**2
500 # Generate random deviations from the average gain across amplifiers
501 # and adjust them to ensure their sum equals zero. This reflects
502 # real-world detectors, with amplifier gains normally distributed due
503 # to manufacturing and operational variations.
504 deviations = np.random.normal(average_gain, gain_sigma_factor * average_gain, size=self.num_amps)
505 deviations -= np.mean(deviations)
507 # Set the gain for amplifiers to be slightly different from each other
508 # while averaging to `average_gain`. This is to test the
509 # `average_across_amps` option in the `remove_signal_from_variance`
510 # function.
511 self.amp_gain_list = [average_gain + deviation for deviation in deviations]
513 # Add a constant background to the entire image (both in e-/pixel).
514 if add_signal:
515 image = self.raw_image + self.background
516 else:
517 image = self.signal_free_raw_image + self.background
519 # Add noise to the image which is in electrons. Note that we won't
520 # specify a `sky_level` here to avoid double-counting it, as it's
521 # already included in the image as the background.
522 image.addNoise(galsim.PoissonNoise(self.rng))
524 # Subtract off the background to get the sky-subtracted image.
525 image -= self.background
527 # Adjust each amplifier's image segment by its respective gain. After
528 # this step, the image will be in ADUs.
529 for bounds, gain in zip(self.amp_bounds_list, self.amp_gain_list):
530 image[bounds] /= gain
532 # We know that the exposure has already been modified in place, but
533 # just to be extra sure, we'll set the exposure image explicitly.
534 exposure.image.array = image.array
536 # Create a variance plane for the exposure while including signal as a
537 # pollutant. Note that the exposure image is pre-adjusted for gain,
538 # unlike 'self.background'. Thus, we divide the background by the
539 # corresponding gain before adding it to the image. This leads to the
540 # variance plane being in units of ADU^2.
541 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
542 exposure.variance[bbox].array = (exposure.image[bbox].array + self.background / gain) / gain
544 return exposure
546 def test_no_signal_handling(self):
547 """Test that the function does nearly nothing when given an image with
548 no signal.
549 """
550 # Create an exposure with no signal.
551 exposure = self.buildExposure(
552 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=False
553 )
554 # Remove the signal from the variance plane, if any.
555 updated_variance = remove_signal_from_variance(exposure, in_place=False)
556 # Check that the variance plane is nearly the same as the original.
557 self.assertFloatsAlmostEqual(exposure.variance.array, updated_variance.array, rtol=0.013)
559 def test_in_place_handling(self):
560 """Make sure the function is tested to handle in-place operations."""
561 # Create an exposure with signal.
562 exposure = self.buildExposure(
563 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=True
564 )
565 # Remove the signal from the variance plane.
566 updated_variance = remove_signal_from_variance(exposure, in_place=True)
567 # Retrieve the variance plane from the exposure and check that it is
568 # identical to the returned variance plane.
569 self.assertFloatsEqual(exposure.variance.array, updated_variance.array)
571 @methodParametersProduct(
572 average_gain=[1.4, 1.7],
573 predefined_gain_type=["average", "per-amp", None],
574 gain_sigma_factor=[0, 0.008],
575 sky_level=[2e6, 4e6],
576 average_across_amps=[False, True],
577 )
578 def test_variance_signal_removal(
579 self, average_gain, predefined_gain_type, gain_sigma_factor, sky_level, average_across_amps
580 ):
581 exposure = self.buildExposure(
582 average_gain=average_gain,
583 gain_sigma_factor=gain_sigma_factor,
584 sky_level=sky_level,
585 add_signal=True,
586 )
588 # Save the original variance plane for comparison, assuming it has
589 # Poisson contribution from the source signal.
590 signal_polluted_variance = exposure.variance.clone()
592 # Check that the variance plane has no negative values.
593 self.assertTrue(
594 np.all(signal_polluted_variance.array >= 0),
595 "Variance plane has negative values (pre correction)",
596 )
598 if predefined_gain_type == "average":
599 predefined_gain = average_gain
600 predefined_gains = None
601 elif predefined_gain_type == "per-amp":
602 predefined_gain = None
603 predefined_gains = {name: gain for name, gain in zip(self.amp_name_list, self.amp_gain_list)}
604 elif predefined_gain_type is None:
605 # Allow the 'remove_signal_from_variance' function to estimate the
606 # gain itself before it attempts to remove the signal from the
607 # variance plane.
608 predefined_gain = None
609 predefined_gains = None
611 # Set the relative tolerance for the variance plane checks.
612 if predefined_gain_type == "average" or (predefined_gain_type is None and average_across_amps):
613 # Relax the tolerance if we are simply averaging across amps to
614 # roughly estimate the overall gain.
615 rtol = 0.015
616 estimate_average_gain = True
617 else:
618 # Tighten tolerance for the 'predefined_gain_type' of 'per-amp' or
619 # for a more accurate per-amp gain estimation strategy.
620 rtol = 3e-7
621 estimate_average_gain = False
623 # Remove the signal from the variance plane.
624 signal_free_variance = remove_signal_from_variance(
625 exposure,
626 gain=predefined_gain,
627 gains=predefined_gains,
628 average_across_amps=average_across_amps,
629 in_place=False,
630 )
632 # Check that the variance plane has been modified.
633 self.assertFloatsNotEqual(signal_polluted_variance.array, signal_free_variance.array)
635 # Check that the corrected variance plane has no negative values.
636 self.assertTrue(
637 np.all(signal_free_variance.array >= 0), "Variance plane has negative values (post correction)"
638 )
640 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
641 # Calculate the true variance in theoretical terms.
642 true_var_amp = self.background / gain**2
643 # Pair each variance with the appropriate context manager before
644 # looping through them.
645 var_context_pairs = [
646 # For the signal-free variance, directly execute the checks.
647 (signal_free_variance, nullcontext()),
648 # For the signal-polluted variance, expect AssertionError
649 # unless we are averaging across amps.
650 (
651 signal_polluted_variance,
652 nullcontext() if estimate_average_gain else self.assertRaises(AssertionError),
653 ),
654 ]
655 for var, context_manager in var_context_pairs:
656 # Extract the segment of the variance plane for the amplifier.
657 var_amp = var[bbox]
658 with context_manager:
659 if var is signal_polluted_variance and estimate_average_gain:
660 # Skip rigorous checks on the signal-polluted variance,
661 # if we are averaging across amps.
662 pass
663 else:
664 # Get the variance value at the first pixel of the
665 # segment to compare with the rest of the pixels and
666 # the true variance.
667 v00 = var_amp.array[0, 0]
668 # Assert that the variance plane is almost uniform
669 # across the segment because the signal has been
670 # removed from it and the background is constant.
671 self.assertFloatsAlmostEqual(var_amp.array, v00, rtol=rtol)
672 # Assert that the variance plane is almost equal to the
673 # true variance across the segment.
674 self.assertFloatsAlmostEqual(v00, true_var_amp, rtol=rtol)
676 if (
677 SAVE_PLOT
678 and not average_across_amps
679 and gain_sigma_factor in (0, 0.008)
680 and sky_level == 4e6
681 and average_gain == 1.7
682 and predefined_gain_type is None
683 ):
684 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8.5))
685 plt.subplots_adjust(wspace=0.17, hspace=0.17)
686 colorbar_aspect = 12
688 amp_background_variance_ADU_list = [self.background / gain**2 for gain in self.amp_gain_list]
689 amp_background_image_ADU_list = [self.background / gain for gain in self.amp_gain_list]
690 # Calculate the mean value that corresponds to the background for
691 # the variance plane, adjusting for the gain.
692 background_mean_variance_ADU = np.mean(
693 [self.background / gain**2 for gain in self.amp_gain_list]
694 )
696 # Extract the variance planes and the image from the exposure.
697 arr1 = signal_polluted_variance.array # Variance with signal
698 arr2 = signal_free_variance.array # Variance without signal
699 exp_im = exposure.image.clone() # Clone of the image plane
701 # Incorporate the gain-adjusted background into the image plane to
702 # enable combined visualization of sources with the background.
703 for gain, bbox in zip(self.amp_gain_list, self.amp_bbox_list):
704 exp_im[bbox].array += self.background / gain
705 arr3 = exp_im.array
707 # Define colors visually distinct from each other for the subplots.
708 original_variance_color = "#8A2BE2" # Periwinkle
709 corrected_variance_color = "#618B3C" # Lush Forest Green
710 sky_variance_color = "#c3423f" # Crimson Red
711 amp_colors = [
712 "#1f77b4", # Muted Blue
713 "#ff7f0e", # Vivid Orange
714 "#2ca02c", # Kelly Green
715 "#d62728", # Brick Red
716 "#9467bd", # Soft Purple
717 "#8B4513", # Saddle Brown
718 "#e377c2", # Pale Violet Red
719 "#202020", # Onyx
720 ]
721 arrowheads_lr = ["$\u25C0$", "$\u25B6$"] # Left- & right-pointing
722 arrowheads_ud = ["$\u25B2$", "$\u25BC$"] # Up- & down-pointing
724 # Set titles for the subplots.
725 ax1.set_title("Original variance plane", color=original_variance_color)
726 ax2.set_title("Corrected variance plane", color=corrected_variance_color)
727 ax3.set_title("Image + background ($\\mathit{uniform}$)")
728 ax4.set_title("Histogram of variances")
730 # Collect all vertical and horizontal line positions to find the
731 # amp boundaries.
732 vlines, hlines = set(), set()
733 for bbox in self.amp_bbox_list:
734 # Adjst by 0.5 for merging of lines at the boundaries.
735 vlines.update({bbox.minX - 0.5, bbox.maxX + 0.5})
736 hlines.update({bbox.minY - 0.5, bbox.maxY + 0.5})
738 # Filter lines at the edges of the overall image bbox.
739 image_bbox = exposure.getBBox()
740 vlines = {x for x in vlines if image_bbox.minX < x < image_bbox.maxX}
741 hlines = {y for y in hlines if image_bbox.minY < y < image_bbox.maxY}
743 # Plot image and variance planes.
744 for plane, arr, ax in zip(
745 ("variance", "variance_corrected", "image"), (arr1, arr2, arr3), (ax1, ax2, ax3)
746 ):
747 # We skip 'variance_corrected' in the loop below because we use
748 # the same normalization and colormap as 'variance' for it.
749 if plane in ["variance", "image"]:
750 # Get the normalization.
751 vmin, vmax = arr.min(), arr.max()
752 norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
754 # Get the thresholds corresponding to per-amp backgrounds
755 # and their positions in the normalized color scale.
756 thresholds = (
757 amp_background_variance_ADU_list
758 if plane.startswith("variance")
759 else amp_background_image_ADU_list
760 )
761 threshold_positions = [norm(t) for t in thresholds]
762 threshold = np.mean(thresholds)
763 threshold_position = np.mean(threshold_positions)
765 # Create a custom colormap with two distinct colors for the
766 # sky and source contributions.
767 border = (threshold - vmin) / (vmax - vmin)
768 colors1 = plt.cm.Purples_r(np.linspace(0, 1, int(border * 256)))
769 colors2 = plt.cm.Greens(np.linspace(0, 1, int((1 - border) * 256)))
770 colors = np.vstack((colors1, colors2))
771 cmap = mcolors.LinearSegmentedColormap.from_list("cmap", colors)
773 # Plot the array with the custom colormap and normalization.
774 im = ax.imshow(arr, cmap=cmap, norm=norm)
776 # Add colorbars to the plot.
777 cbar = fig.colorbar(im, aspect=colorbar_aspect, pad=0)
779 # Change the number of ticks on the colorbar for better
780 # spacing. Needs to be done before modifying the tick labels.
781 cbar.ax.locator_params(nbins=7)
783 # Enhance readability by scaling down colorbar tick labels.
784 unit = "ADU$^2$" if plane.startswith("variance") else "ADU"
785 adjust_tick_scale(cbar.ax, {"y": f"Value [{{scale}} {unit}]"})
787 # Mark min and max per-amp thresholds with dotted lines on the
788 # colorbar.
789 for tp in [min(thresholds), max(thresholds)]:
790 cbar.ax.axhline(tp, color="white", linestyle="-", linewidth=1.5, alpha=0.4)
791 cbar.ax.axhline(tp, color=sky_variance_color, linestyle=":", linewidth=1.5, alpha=0.9)
793 # Mark mean threshold with facing arrowheads on the colorbar.
794 cbar.ax.annotate(
795 arrowheads_lr[1], # Right-pointing arrowhead
796 xy=(0, threshold_position),
797 xycoords="axes fraction",
798 textcoords="offset points",
799 xytext=(0, 0),
800 ha="left",
801 va="center",
802 fontsize=6,
803 color=sky_variance_color,
804 clip_on=False,
805 alpha=0.9,
806 )
807 cbar.ax.annotate(
808 arrowheads_lr[0], # Left-pointing arrowhead
809 xy=(1, threshold_position),
810 xycoords="axes fraction",
811 textcoords="offset points",
812 xytext=(0, 0),
813 ha="right",
814 va="center",
815 fontsize=6,
816 color=sky_variance_color,
817 clip_on=False,
818 alpha=0.9,
819 )
821 # Add text inside the colorbar to label the average threshold
822 # position.
823 sky_level_text = "$\u27E8$" + "Sky" + "$\u27E9$" # <Sky>
824 sky_level_text_artist = cbar.ax.text(
825 0.5,
826 threshold_position,
827 sky_level_text,
828 va="center",
829 ha="center",
830 transform=cbar.ax.transAxes,
831 fontsize=8,
832 color=sky_variance_color,
833 rotation="vertical",
834 alpha=0.9,
835 path_effects=outline_effect(2),
836 )
838 # Setup renderer and transformation.
839 renderer = fig.canvas.get_renderer()
840 transform = cbar.ax.transAxes.inverted()
842 # Transform the bounding box and calculate adjustment for
843 # 'sky_level_text_artist' for when it goes beyond the colorbar.
844 sky_level_text_bbox = sky_level_text_artist.get_window_extent(renderer).transformed(transform)
845 adjustment = 1.4 * sky_level_text_bbox.height / 2
847 if sky_level_text_bbox.ymin < 0:
848 sky_level_text_artist.set_y(adjustment)
849 elif sky_level_text_bbox.ymax > 1:
850 sky_level_text_artist.set_y(1 - adjustment)
852 # Draw amp boundaries as vertical and/or horizontal lines.
853 line_color = "white" if np.mean(norm(arr)) > 0.5 else "#808080"
854 for x in vlines:
855 ax.axvline(x=x, color=line_color, linestyle="--", linewidth=1, alpha=0.7)
856 for y in hlines:
857 ax.axhline(y=y, color=line_color, linestyle="--", linewidth=1, alpha=0.7)
858 # Hide all x and y tick marks.
859 ax.tick_params(axis="both", which="both", bottom=False, top=False, left=False, right=False)
860 # Hide all x and y tick labels.
861 ax.set_xticklabels([])
862 ax.set_yticklabels([])
864 # Additional ax2 annotations:
865 # Labels amplifiers with their respective gains for a visual check.
866 for bbox, name, gain, color in zip(
867 self.amp_bbox_list, self.amp_name_list_simplified, self.amp_gain_list, amp_colors
868 ):
869 # Get the center of the bbox to label the gain value.
870 bbox_center = (bbox.minX + bbox.maxX) / 2, (bbox.minY + bbox.maxY) / 2
871 # Label the gain value at the center of each amplifier segment.
872 ax2.text(
873 *bbox_center,
874 f"gain$_{{\\rm \\, {name} \\,}}$: {gain:.3f}",
875 fontsize=9,
876 color=color,
877 alpha=0.95,
878 ha="center",
879 va="center",
880 path_effects=outline_effect(2),
881 )
883 # Additional ax3 annotations:
884 # Label sources with numbers on the image plane.
885 for i, pos in enumerate(self.positions, start=1):
886 ax3.text(
887 *pos,
888 f"{i}",
889 fontsize=7,
890 color=sky_variance_color,
891 path_effects=outline_effect(1.5),
892 alpha=0.9,
893 )
895 # Now we use ax4 to plot the histograms of the variance planes for
896 # comparison.
897 # Plot the histogram of the original variance plane.
898 hist_values, bins, _ = ax4.hist(
899 arr1.flatten(),
900 bins=80,
901 histtype="step",
902 color=original_variance_color,
903 alpha=0.9,
904 label="Original variance",
905 )
906 # Fill the area under the step.
907 ax4.fill_between(
908 bins[:-1],
909 hist_values,
910 step="post",
911 color=original_variance_color,
912 alpha=0.09,
913 hatch="/////",
914 label=" ",
915 )
916 # Plot the histogram of the corrected variance plane.
917 ax4.hist(
918 arr2.flatten(),
919 bins=80,
920 histtype="bar",
921 color=corrected_variance_color,
922 alpha=0.9,
923 label="Corrected variance",
924 )
925 adjust_tick_scale(ax4, {"x": "Variance [{scale} ADU$^2$]", "y": "Number of pixels / {scale}"})
926 ax4.yaxis.set_label_position("right")
927 ax4.yaxis.tick_right()
928 ax4.axvline(
929 background_mean_variance_ADU,
930 color=sky_variance_color,
931 linestyle="--",
932 linewidth=1,
933 alpha=0.9,
934 label="Average sky variance\nacross all amps",
935 )
937 # Use colored arrowheads to mark true amp variances.
938 sorted_vars = sorted(amp_background_variance_ADU_list)
939 count = {v: 0 for v in sorted_vars}
940 for i, (x, name, gain, color) in enumerate(
941 zip(
942 amp_background_variance_ADU_list,
943 self.amp_name_list_simplified,
944 self.amp_gain_list,
945 amp_colors,
946 )
947 ):
948 arrowhead = arrowheads_ud[int(gain < average_gain)]
949 arrowhead_text = ax4.annotate(
950 arrowhead,
951 xy=(x, 0),
952 xycoords=("data", "axes fraction"),
953 textcoords="offset points",
954 xytext=(0, 0),
955 ha="center",
956 va="bottom",
957 fontsize=6.5,
958 color=color,
959 clip_on=False,
960 alpha=0.85,
961 path_effects=outline_effect(1.5),
962 )
963 if i == 0:
964 # Draw the canvas once to make sure the renderer is active.
965 fig.canvas.draw()
966 # Get the bounding box of the text annotation in axes
967 # fraction.
968 bbox_axes = arrowhead_text.get_window_extent().transformed(ax4.transAxes.inverted())
969 # Get the height of the text annotation in axes fraction.
970 height = bbox_axes.height
971 # Increment the arrowhead y positions to avoid overlap.
972 var_idxs = np.where(sorted_vars == x)[0]
973 if len(var_idxs) > 1:
974 q = count[x]
975 count[x] += 1
976 else:
977 q = 0
978 arrowhead_text.xy = (x, var_idxs[q] * height)
980 # Create a proxy artist for the legend since annotations are
981 # not shown in the legend.
982 label = "True variance of" if i == 0 else "$\u21AA$"
983 ax4.scatter(
984 [],
985 [],
986 color=color,
987 marker=arrowhead,
988 s=15,
989 label=f"{label} {name}",
990 alpha=0.85,
991 path_effects=outline_effect(1.5),
992 )
994 # Group the legend handles and label them.
995 adjust_legend_with_groups(
996 ax4,
997 [(0, 1), 2, 3, *range(4, 4 + len(self.amp_name_list_simplified))],
998 colors="match",
999 handlelength=1.9,
1000 )
1002 # Align the histogram (bottom right panel) with the colorbar of the
1003 # corrected variance plane (top right panel) for aesthetic reasons.
1004 pos2 = ax2.get_position()
1005 pos4 = ax4.get_position()
1006 fig.canvas.draw() # Render to ensure accurate colorbar width
1007 cbar_width = cbar.ax.get_position().width
1008 ax4.set_position([pos2.x0, pos4.y0, pos2.width + cbar_width, pos4.height])
1010 # Increase all axes spines' linewidth by 20% for a bolder look.
1011 for ax in fig.get_axes():
1012 for spine in ax.spines.values():
1013 spine.set_linewidth(spine.get_linewidth() * 1.2)
1015 # Save the figure.
1016 filename = f"variance_plane_gain{average_gain}_sigma{gain_sigma_factor}_sky{sky_level}.png"
1017 fig.savefig(filename, dpi=300)
1018 print(f"Saved plot of variance plane before and after correction in {filename}")
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Empty subclass of `lsst.utils.tests.MemoryTestCase`.

    It defines no tests of its own and relies entirely on whatever the base
    class provides (presumably resource/leak checks run after the other
    tests in this module — confirm against the lsst.utils documentation).
    """
def setup_module(module):
    """Module-level setup hook (pytest xunit-style ``setup_module``).

    Initializes the LSST test utilities before the tests in this module run.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up. Required by the hook signature but
        not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Allow running this file directly as a script: initialize the LSST test
    # utilities first, then hand control to unittest's command-line runner.
    # (The coverage-report branch annotation and fused line numbers that had
    # been pasted into these lines made them invalid Python; removed.)
    lsst.utils.tests.init()
    unittest.main()