Coverage for tests / test_variance_plane.py: 8%
328 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:00 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:00 +0000
1# This file is part of ip_isr.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21#
22"""The utility function under test is part of the `meas_algorithms` package.
23This unit test has been relocated here to avoid circular dependencies.
24"""
26import re
27import unittest
28from contextlib import nullcontext
30import galsim
31import lsst.utils.tests
32import matplotlib.colors as mcolors
33import matplotlib.pyplot as plt
34import numpy as np
35from lsst.ip.isr.isrMock import IsrMock
36from lsst.meas.algorithms import remove_signal_from_variance
37from lsst.utils.tests import methodParametersProduct
38from matplotlib.legend_handler import HandlerTuple
39from matplotlib.patheffects import withStroke
40from matplotlib.ticker import FixedLocator, FuncFormatter
# Flip to True to write out a figure comparing the variance plane before and
# after signal removal for one representative parameter combination of the
# parametrized test.
SAVE_PLOT = False
def outline_effect(lw, alpha=0.8):
    """Build a white outline path effect that keeps text legible on busy
    backgrounds.

    Parameters
    ----------
    lw : `float`
        Width of the white stroke drawn around the text.
    alpha : `float`, optional
        Opacity of the stroke.

    Returns
    -------
    `list` of `matplotlib.patheffects.withStroke`
        A single-element list, ready to pass as ``path_effects``.
    """
    stroke = withStroke(linewidth=lw, foreground="white", alpha=alpha)
    return [stroke]
class CustomHandler(HandlerTuple):
    """Legend handler for grouped (tuple) handles that re-applies the legend
    entry's transform to every artist the base handler creates."""

    def create_artists(self, *args):
        # The final positional argument is used as the transform for the
        # created artists (in matplotlib's handler API this is ``trans``).
        legend_transform = args[-1]
        created = super().create_artists(*args)
        for artist in created:
            artist.set_transform(legend_transform)
        return created
def get_valid_color(handle):
    """Extracts a valid color from a Matplotlib handle.

    The getters ``get_facecolor``, ``get_edgecolor`` and ``get_color`` are
    tried in that order; the first one that yields a visible color wins.
    Empty color arrays, the string ``"none"`` and RGBA colors with zero alpha
    are treated as invisible and skipped.

    Parameters
    ----------
    handle : `matplotlib.artist.Artist`
        The handle from which to extract the color.

    Returns
    -------
    color : `str`, `tuple` or `numpy.ndarray`
        The color extracted from the handle, or 'default' if no valid color is
        found.
    """
    for attr in ["get_facecolor", "get_edgecolor", "get_color"]:
        if not hasattr(handle, attr):
            continue
        color = getattr(handle, attr)()
        if isinstance(color, np.ndarray):
            # An empty array means the artist has no color of this kind
            # (e.g. a collection drawn with facecolor='none').
            if color.size == 0:
                continue
            # Collections return an (N, 4) array; use the first color. A 1-D
            # array is already a single color and must not be indexed.
            if color.ndim > 1:
                color = color[0]
        # 'none' is matplotlib's spelling of a fully transparent color.
        if isinstance(color, str) and color.lower() == "none":
            continue
        # If the color is RGBA with alpha = 0, continue the search.
        if not isinstance(color, str) and len(color) == 4 and color[3] == 0:
            continue
        return color
    return "default"  # If no valid color is found
def get_emptier_side(ax):
    """Analyze a matplotlib Axes object to determine which side (left or right)
    has more whitespace, considering cases where artists' bounding boxes span
    both sides of the midpoint.

    All bounding boxes are converted to axes-fraction coordinates, so the
    vertical midline of the Axes sits at x = 0.5 regardless of data limits
    or figure size.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to analyze.

    Returns
    -------
    more_whitespace_side : `str`
        'left' if the left side has more whitespace (ties included), or
        'right' if the right side does.
    """
    # The midpoint and the artists' extents must share one coordinate
    # system. Display-space extents are mapped into axes fractions, where
    # the horizontal midpoint of the Axes is 0.5.
    to_axes = ax.transAxes.inverted()
    midpoint = 0.5

    # Accumulated occupied area on each side (in axes-fraction units).
    left_area, right_area = 0.0, 0.0

    for artist in ax.get_children():
        # Skip if artist is invisible or lacks a bounding box.
        if not artist.get_visible() or not hasattr(artist, "get_window_extent"):
            continue
        bbox = artist.get_window_extent().transformed(to_axes)
        # Degenerate/unset extents can be non-finite; they carry no area.
        if not np.all(np.isfinite([bbox.x0, bbox.x1, bbox.y0, bbox.y1])):
            continue
        area = bbox.width * bbox.height
        if bbox.x0 < midpoint < bbox.x1:
            # Split the area proportionally between the two sides.
            left_proportion = (midpoint - bbox.x0) / bbox.width
            left_area += area * left_proportion
            right_area += area * (1 - left_proportion)
        elif bbox.x0 + bbox.width / 2 < midpoint:
            # Entirely on the left.
            left_area += area
        else:
            # Entirely on the right.
            right_area += area

    # The side with less occupied area has more whitespace.
    return "left" if left_area <= right_area else "right"
def adjust_legend_with_groups(ax, combine_groups, colors="default", yloc="upper", **kwargs):
    """Adjusts the legend of a given Axes object by combining specified handles
    based on provided groups, setting the marker location and text alignment
    based on the inferable emptier side, and optionally setting the text color
    of legend entries to a provided list or inferring colors from the handles.
    Additionally, allows specifying the vertical location of the legend within
    the plot.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object for which to adjust the legend.
    combine_groups : `list` of `int` or iterable of `int`
        A list that can contain a mix of individual integers and/or iterables
        (lists, tuples, or sets) of integers. An individual integer specifies
        the index of a single legend entry. An iterable of integers specifies
        a group of indices to be combined into a single legend entry.
    colors : `list` of `str` or `tuple`, or `str`, optional
        Specifies the colors for the legend entries. This parameter can be:
        - A list of color specifications, where each element is a string (for
          named colors or hex values) or a tuple (for RGB or RGBA values). This
          list explicitly assigns colors to each legend entry post-combination.
        - A single string value:
          - "match": Colors are inferred from the properties of the first
            handle in each group that corresponds to a non-white-space label.
            This aims to match the legend text color with the color of the
            plotted data.
          - "default": The function does not alter the default colors
            assigned by Matplotlib, preserving the automatic color assignment
            for all legend entries.
    yloc : `str`, optional
        The vertical location of the legend within the Axes. Valid options are
        'upper', 'lower', or 'middle' ('middle' is passed to matplotlib as
        'center'). This parameter is combined with the inferable emptier side
        ('left' or 'right') to determine the legend's placement. For example,
        'upper right' or 'lower left'.
    **kwargs :
        Keyword arguments forwarded to the ``ax.legend`` function.

    Raises
    ------
    ValueError
        If an entry of ``combine_groups`` is neither an integer nor a
        list/tuple/set of integers, or is an empty group.
    """
    handles, labels = ax.get_legend_handles_labels()
    new_handles = []
    new_labels = []

    # Compare against the string only when a string was given, so that an
    # array of colors does not trigger element-wise comparison.
    infer_colors = isinstance(colors, str) and colors == "match"
    if infer_colors:
        colors = []

    for group in combine_groups:
        # Assume the first non-white-space label represents the group. If no
        # such label is found, just use the first label in the group which is
        # in fact empty.
        if isinstance(group, (list, tuple, set)):
            group = list(group)  # Sets are not indexable.
            if not group:
                raise ValueError("Empty group in 'combine_groups'")
            label_index = next((i for i in group if labels[i].strip()), group[0])
            combined_handle = tuple(handles[i] for i in group)
        elif isinstance(group, (int, np.integer)):
            label_index = group
            combined_handle = handles[group]
        else:
            raise ValueError("Invalid value in 'combine_groups'")
        new_handles.append(combined_handle)
        new_labels.append(labels[label_index])
        if infer_colors:
            # Attempt to infer color from the representative handle in the
            # group.
            colors.append(get_valid_color(handles[label_index]))

    # Determine the emptier side to decide legend and text alignment.
    emptier_side = get_emptier_side(ax)
    markerfirst = emptier_side != "right"

    # Matplotlib names the vertical middle "center" (e.g. "center left");
    # "middle left" is not a valid legend location.
    vertical = "center" if yloc == "middle" else yloc

    # Create the legend with custom adjustments.
    legend = ax.legend(
        new_handles,
        new_labels,
        handler_map={tuple: CustomHandler()},
        loc=f"{vertical} {emptier_side}",
        fontsize=8,
        frameon=False,
        markerfirst=markerfirst,
        **kwargs,
    )

    # Right- or left-align the legend text based on the emptier side.
    for text in legend.get_texts():
        text.set_ha(emptier_side)

    # Set legend text colors if necessary.
    if not (isinstance(colors, str) and colors == "default"):
        for text, color in zip(legend.get_texts(), colors):
            if not (isinstance(color, str) and color == "default"):
                text.set_color(color)
def adjust_tick_scale(ax, axis_label_templates):
    """Scales down tick labels to make them more readable and updates axis
    labels accordingly.

    Calculates a power of 10 scale factor (common divisor) to reduce the
    magnitude of tick labels. It automatically determines which axes to adjust
    based on the provided axis label templates, which should include `{scale}`
    for inserting the scale factor dynamically.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to modify.
    axis_label_templates : `dict`
        Templates for axis labels, including '{scale}' for scale factor
        insertion. Keys should be one or more of axes names ('x', 'y', 'z') and
        values should be the corresponding label templates.
    """

    def trailing_zeros(n):
        """Return the number of trailing zeros of a numeric string.

        Integral values are normalized first (e.g. ``"200.0"`` -> ``"200"``)
        so that only zeros of the integer part are counted.
        """
        value = float(n)
        text = str(int(value)) if value.is_integer() else n
        return len(text) - len(text.rstrip("0"))

    def format_tick(val, pos, divisor):
        """Formats tick labels using the determined divisor."""
        scaled = val / divisor
        return str(int(scaled)) if scaled.is_integer() else str(scaled)

    # Iterate through the specified axes and adjust their tick labels and axis
    # labels.
    for axis, label_template in axis_label_templates.items():
        # Gather current tick labels. Matplotlib may render negative values
        # with a Unicode minus sign (U+2212), which float() rejects, and may
        # leave labels empty; normalize the former and drop the latter.
        labels = [
            label.get_text().replace("\u2212", "-")
            for label in getattr(ax, f"get_{axis}ticklabels")()
        ]
        labels = [label for label in labels if label.strip()]

        # Calculate the power of 10 divisor based on the minimum number of
        # trailing zeros in the non-zero tick labels (no scaling if none).
        divisor = 10 ** min(
            (trailing_zeros(label) for label in labels if float(label) != 0), default=0
        )

        # Set a formatter for the axis ticks that scales them according to the
        # common divisor. ``divisor`` is bound as a default argument; a plain
        # closure would see only the last loop iteration's value when several
        # axes are adjusted.
        getattr(ax, f"{axis}axis").set_major_formatter(
            FuncFormatter(lambda val, pos, divisor=divisor: format_tick(val, pos, divisor))
        )

        # Ensure the tick positions remain unchanged despite the new
        # formatting.
        getattr(ax, f"{axis}axis").set_major_locator(FixedLocator(getattr(ax, f"get_{axis}ticks")()))

        # Prepare 'scale', empty if divisor <= 1.
        scale = f"{int(divisor)}" if divisor > 1 else ""

        # If 'scale' is empty, remove whitespace around "{scale}" in the
        # template. Also remove any trailing "/{scale}".
        if scale == "":
            label_template = re.sub(r"\s*{\s*scale\s*}\s*", "{scale}", label_template)
            label_template = label_template.replace("/{scale}", "")

        # Always strip remaining whitespace from the template.
        label_template = label_template.strip()

        # Set the formatted axis label.
        label_text = label_template.format(scale=scale)
        getattr(ax, f"set_{axis}label")(label_text, labelpad=8)
318class VariancePlaneTestCase(lsst.utils.tests.TestCase):
319 def setUp(self):
320 # Testing with a single detector that has 8 amplifiers in a 4x2
321 # configuration. Each amplifier measures 100x51 in dimensions.
322 config = IsrMock.ConfigClass()
323 config.isLsstLike = True
324 config.doAddBias = False
325 config.doAddDark = False
326 config.doAddFlat = False
327 config.doAddFringe = False
328 config.doGenerateImage = True
329 config.doGenerateData = True
330 config.doGenerateAmpDict = True
331 self.mock = IsrMock(config=config)
332 self.pixel_scale = 0.2 # arcsec/pixel
334 # Set the random seed for reproducibility.
335 random_seed = galsim.BaseDeviate(1905).raw() + 1
336 self.np_rng = np.random.Generator(np.random.MT19937(4))
337 self.rng = galsim.BaseDeviate(random_seed)
339 # Get the exposure, detector, and amps from the mock.
340 exposure = self.mock.getExposure()
341 detector = exposure.getDetector()
342 amps = detector.getAmplifiers()
344 # Set amp-related attributes for use in the test cases.
345 self.num_amps = len(amps)
346 self.amp_name_list = [amp.getName() for amp in amps]
347 table = str.maketrans("", "", ":,") # Remove ':' and ',' from names
348 self.amp_name_list_simplified = [name.translate(table) for name in self.amp_name_list]
350 # Get the bounding boxes for the exposure and amplifiers and convert
351 # them to galsim bounds.
352 exp_bbox = exposure.getBBox()
353 image_bounds = galsim.BoundsI(exp_bbox.minX, exp_bbox.maxX, exp_bbox.minY, exp_bbox.maxY)
354 self.amp_bbox_list = [amp.getBBox() for amp in amps]
355 self.amp_bounds_list = [galsim.BoundsI(b.minX, b.maxX, b.minY, b.maxY) for b in self.amp_bbox_list]
357 # Create a raw galsim image to potentially draw the sources onto. The
358 # exposure image that is passed to this method will be modified in
359 # place but it won't be used.
360 self.signal_free_raw_image = galsim.ImageF(exposure.image.array, bounds=image_bounds)
361 self.raw_image = self.signal_free_raw_image.copy()
363 # Define parameters for a mix of source types, including extended
364 # sources with assorted profiles as well as point sources simulated
365 # with minimal half-light radii to resemble hot pixels
366 # post-deconvolution. All flux values are given in electrons and
367 # half-light radii in pixels. The goal is for each amplifier to
368 # predominantly contain at least one source, enhancing the
369 # representativeness of test conditions.
370 source_params = [
371 {"type": "Sersic", "n": 3, "flux": 1.6e5, "half_light_radius": 3.5, "g1": -0.3, "g2": 0.2},
372 {"type": "Sersic", "n": 1, "flux": 9.3e5, "half_light_radius": 2.1, "g1": 0.25, "g2": 0.12},
373 {"type": "Sersic", "n": 4, "flux": 1.0e5, "half_light_radius": 1.1, "g1": 0.0, "g2": 0.0},
374 {"type": "Sersic", "n": 3, "flux": 1.1e6, "half_light_radius": 4.2, "g1": 0.0, "g2": 0.2},
375 {"type": "Sersic", "n": 5, "flux": 1.1e5, "half_light_radius": 3.6, "g1": 0.22, "g2": -0.05},
376 {"type": "Sersic", "n": 2, "flux": 4.3e5, "half_light_radius": 2.0, "g1": 0.0, "g2": 0.0},
377 {"type": "Sersic", "n": 6, "flux": 1.2e6, "half_light_radius": 11.0, "g1": -0.16, "g2": 0.7},
378 {"type": "Exponential", "flux": 1.3e6, "half_light_radius": 1.9, "g1": 0.3, "g2": -0.1},
379 {"type": "Exponential", "flux": 1.8e6, "half_light_radius": 5.0, "g1": 0.0, "g2": 0.14},
380 {"type": "Exponential", "flux": 6.6e6, "half_light_radius": 4.8, "g1": 0.26, "g2": 0.5},
381 {"type": "Exponential", "flux": 7.0e5, "half_light_radius": 3.1, "g1": -0.3, "g2": 0.0},
382 {"type": "DeVaucouleurs", "flux": 1.6e5, "half_light_radius": 3.5, "g1": 0.2, "g2": 0.4},
383 {"type": "DeVaucouleurs", "flux": 2.0e5, "half_light_radius": 1.6, "g1": -0.06, "g2": -0.2},
384 {"type": "DeVaucouleurs", "flux": 8.3e5, "half_light_radius": 5.1, "g1": 0.29, "g2": 0.0},
385 {"type": "DeVaucouleurs", "flux": 4.5e5, "half_light_radius": 2.5, "g1": 0.4, "g2": 0.3},
386 {"type": "DeVaucouleurs", "flux": 6.2e5, "half_light_radius": 4.9, "g1": -0.08, "g2": -0.01},
387 {"type": "Gaussian", "flux": 4.7e6, "half_light_radius": 2.5, "g1": 0.07, "g2": -0.35},
388 {"type": "Gaussian", "flux": 5.8e6, "half_light_radius": 3.1, "g1": 0.03, "g2": 0.4},
389 {"type": "Gaussian", "flux": 2.3e5, "half_light_radius": 0.5, "g1": 0.0, "g2": 0.0},
390 {"type": "Gaussian", "flux": 1.6e6, "half_light_radius": 3.0, "g1": 0.18, "g2": -0.29},
391 {"type": "Gaussian", "flux": 3.5e5, "half_light_radius": 4.6, "g1": 0.5, "g2": 0.35},
392 {"type": "Gaussian", "flux": 5.9e5, "half_light_radius": 9.5, "g1": 0.1, "g2": 0.55},
393 {"type": "Gaussian", "flux": 4.0e5, "half_light_radius": 1.0, "g1": 0.0, "g2": 0.0},
394 ]
396 # Mapping of profile types to their galsim constructors.
397 profile_constructors = {
398 "Sersic": galsim.Sersic,
399 "Exponential": galsim.Exponential,
400 "DeVaucouleurs": galsim.DeVaucouleurs,
401 "Gaussian": galsim.Gaussian,
402 }
404 # Generate random positions within exposure bounds, avoiding edges by a
405 # margin.
406 margin_x, margin_y = 0.05 * exp_bbox.width, 0.05 * exp_bbox.height
407 self.positions = self.np_rng.uniform(
408 [exp_bbox.minX + margin_x, exp_bbox.minY + margin_y],
409 [exp_bbox.maxX - margin_x, exp_bbox.maxY - margin_y],
410 (len(source_params), 2),
411 ).tolist()
413 # Loop over the sources and draw them onto the image cutout by cutout.
414 for i, params in enumerate(source_params):
415 # Dynamically get constructor and remove type from params.
416 constructor = profile_constructors[params.pop("type")]
418 # Get shear parameters and remove them from params.
419 g1, g2 = params.pop("g1"), params.pop("g2")
421 # The extent of the cutout should be large enough to contain the
422 # entire object above the background level. Some empirical factor
423 # is used to mitigate artifacts.
424 half_extent = 10 * params["half_light_radius"] * (1 + 2 * np.sqrt(g1**2 + g2**2))
426 # Pass the remaining params to the constructor and apply shear.
427 galsim_object = constructor(**params).shear(galsim.Shear(g1=g1, g2=g2))
429 # Retrieve the position of the object.
430 x, y = self.positions[i]
431 pos = galsim.PositionD(x, y)
433 # Get the bounds of the sub-image based on the object position.
434 sub_image_bounds = galsim.BoundsI(
435 *map(int, [x - half_extent, x + half_extent, y - half_extent, y + half_extent])
436 )
438 # Identify the overlap region, which could be partially outside
439 # the image bounds.
440 sub_image_bounds = sub_image_bounds & self.raw_image.bounds
442 # Check that there is some overlap.
443 assert sub_image_bounds.isDefined(), "No overlap with image bounds"
445 # Get the sub-image cutout.
446 sub_image = self.raw_image[sub_image_bounds]
448 # Draw the object onto the image within the the sub-image bounds.
449 galsim_object.drawImage(
450 image=sub_image,
451 offset=pos - sub_image.true_center,
452 method="real_space", # Saves memory, usable w/o convolution
453 add_to_image=True, # Add flux to existing image
454 scale=self.pixel_scale,
455 )
457 def tearDown(self):
458 del self.mock
459 del self.raw_image
460 del self.signal_free_raw_image
462 def buildExposure(
463 self,
464 average_gain,
465 gain_sigma_factor,
466 sky_level,
467 add_signal=True,
468 ):
469 """Build and return an exposure with different types of simulated
470 source profiles and a background sky level. It's intended for testing
471 and analysis, providing a way to generate exposures with controlled
472 conditions.
474 Parameters
475 ----------
476 average_gain : `float`
477 The average gain value of amplifiers in e-/ADU.
478 gain_sigma_factor : float
479 The standard deviation of the gain values as a factor of the
480 ``average_gain``.
481 sky_level : `float`
482 The background sky level in e-/arcsec^2.
483 add_signal : `bool`, optional
484 Whether to add sources to the exposure. If set to False, the
485 exposure will only contain background noise.
487 Returns
488 -------
489 exposure : `~lsst.afw.image.Exposure`
490 An exposure object with simulated sources and background. The units
491 are in detector counts (ADU).
492 """
494 # Get the exposure from the mock.
495 exposure = self.mock.getExposure()
497 # Convert the background sky level from e-/arcsec^2 to e-/pixel.
498 self.background = sky_level * self.pixel_scale**2
500 # Generate random deviations from the average gain across amplifiers
501 # and adjust them to ensure their sum equals zero. This reflects
502 # real-world detectors, with amplifier gains normally distributed due
503 # to manufacturing and operational variations.
504 deviations = self.np_rng.normal(average_gain, gain_sigma_factor * average_gain, size=self.num_amps)
505 deviations -= np.mean(deviations)
507 # Set the gain for amplifiers to be slightly different from each other
508 # while averaging to `average_gain`. This is to test the
509 # `average_across_amps` option in the `remove_signal_from_variance`
510 # function.
511 self.amp_gain_list = [average_gain + deviation for deviation in deviations]
513 # Add a constant background to the entire image (both in e-/pixel).
514 if add_signal:
515 image = self.raw_image + self.background
516 else:
517 image = self.signal_free_raw_image + self.background
519 # Add noise to the image which is in electrons. Note that we won't
520 # specify a `sky_level` here to avoid double-counting it, as it's
521 # already included in the image as the background.
522 image.addNoise(galsim.PoissonNoise(self.rng))
524 # Subtract off the background to get the sky-subtracted image.
525 image -= self.background
527 # Adjust each amplifier's image segment by its respective gain. After
528 # this step, the image will be in ADUs.
529 for bounds, gain in zip(self.amp_bounds_list, self.amp_gain_list):
530 image[bounds] /= gain
532 # We know that the exposure has already been modified in place, but
533 # just to be extra sure, we'll set the exposure image explicitly.
534 exposure.image.array = image.array
536 # Create a variance plane for the exposure while including signal as a
537 # pollutant. Note that the exposure image is pre-adjusted for gain,
538 # unlike 'self.background'. Thus, we divide the background by the
539 # corresponding gain before adding it to the image. This leads to the
540 # variance plane being in units of ADU^2.
541 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
542 exposure.variance[bbox].array = (
543 (exposure.image[bbox].array + self.background / gain) / gain
544 ).astype(np.float32)
546 return exposure
548 def test_no_signal_handling(self):
549 """Test that the function does nearly nothing when given an image with
550 no signal.
551 """
552 # Create an exposure with no signal.
553 exposure = self.buildExposure(
554 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=False
555 )
556 # Remove the signal from the variance plane, if any.
557 updated_variance = remove_signal_from_variance(exposure, in_place=False)
558 # Check that the variance plane is nearly the same as the original.
559 self.assertFloatsAlmostEqual(exposure.variance.array, updated_variance.array, rtol=0.013)
561 def test_in_place_handling(self):
562 """Make sure the function is tested to handle in-place operations."""
563 # Create an exposure with signal.
564 exposure = self.buildExposure(
565 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=True
566 )
567 # Remove the signal from the variance plane.
568 updated_variance = remove_signal_from_variance(exposure, in_place=True)
569 # Retrieve the variance plane from the exposure and check that it is
570 # identical to the returned variance plane.
571 self.assertFloatsEqual(exposure.variance.array, updated_variance.array)
573 @methodParametersProduct(
574 average_gain=[1.4, 1.7],
575 predefined_gain_type=["average", "per-amp", None],
576 gain_sigma_factor=[0, 0.008],
577 sky_level=[2e6, 4e6],
578 average_across_amps=[False, True],
579 )
580 def test_variance_signal_removal(
581 self, average_gain, predefined_gain_type, gain_sigma_factor, sky_level, average_across_amps
582 ):
583 exposure = self.buildExposure(
584 average_gain=average_gain,
585 gain_sigma_factor=gain_sigma_factor,
586 sky_level=sky_level,
587 add_signal=True,
588 )
590 # Save the original variance plane for comparison, assuming it has
591 # Poisson contribution from the source signal.
592 signal_polluted_variance = exposure.variance.clone()
594 # Check that the variance plane has no negative values.
595 self.assertTrue(
596 np.all(signal_polluted_variance.array >= 0),
597 "Variance plane has negative values (pre correction)",
598 )
600 if predefined_gain_type == "average":
601 predefined_gain = average_gain
602 predefined_gains = None
603 elif predefined_gain_type == "per-amp":
604 predefined_gain = None
605 predefined_gains = {name: gain for name, gain in zip(self.amp_name_list, self.amp_gain_list)}
606 elif predefined_gain_type is None:
607 # Allow the 'remove_signal_from_variance' function to estimate the
608 # gain itself before it attempts to remove the signal from the
609 # variance plane.
610 predefined_gain = None
611 predefined_gains = None
613 # Set the relative tolerance for the variance plane checks.
614 if predefined_gain_type == "average" or (predefined_gain_type is None and average_across_amps):
615 # Relax the tolerance if we are simply averaging across amps to
616 # roughly estimate the overall gain.
617 rtol = 0.018
618 estimate_average_gain = True
619 else:
620 # Tighten tolerance for the 'predefined_gain_type' of 'per-amp' or
621 # for a more accurate per-amp gain estimation strategy.
622 rtol = 3e-7
623 estimate_average_gain = False
625 # Remove the signal from the variance plane.
626 signal_free_variance = remove_signal_from_variance(
627 exposure,
628 gain=predefined_gain,
629 gains=predefined_gains,
630 average_across_amps=average_across_amps,
631 in_place=False,
632 )
634 # Check that the variance plane has been modified.
635 self.assertFloatsNotEqual(signal_polluted_variance.array, signal_free_variance.array)
637 # Check that the corrected variance plane has no negative values.
638 self.assertTrue(
639 np.all(signal_free_variance.array >= 0), "Variance plane has negative values (post correction)"
640 )
642 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
643 # Calculate the true variance in theoretical terms.
644 true_var_amp = self.background / gain**2
645 # Pair each variance with the appropriate context manager before
646 # looping through them.
647 var_context_pairs = [
648 # For the signal-free variance, directly execute the checks.
649 (signal_free_variance, nullcontext()),
650 # For the signal-polluted variance, expect AssertionError
651 # unless we are averaging across amps.
652 (
653 signal_polluted_variance,
654 nullcontext() if estimate_average_gain else self.assertRaises(AssertionError),
655 ),
656 ]
657 for var, context_manager in var_context_pairs:
658 # Extract the segment of the variance plane for the amplifier.
659 var_amp = var[bbox]
660 with context_manager:
661 if var is signal_polluted_variance and estimate_average_gain:
662 # Skip rigorous checks on the signal-polluted variance,
663 # if we are averaging across amps.
664 pass
665 else:
666 # Get the variance value at the first pixel of the
667 # segment to compare with the rest of the pixels and
668 # the true variance.
669 v00 = var_amp.array[0, 0]
670 # Assert that the variance plane is almost uniform
671 # across the segment because the signal has been
672 # removed from it and the background is constant.
673 self.assertFloatsAlmostEqual(var_amp.array, v00, rtol=rtol)
674 # Assert that the variance plane is almost equal to the
675 # true variance across the segment.
676 self.assertFloatsAlmostEqual(v00, true_var_amp, rtol=rtol)
678 if (
679 SAVE_PLOT
680 and not average_across_amps
681 and gain_sigma_factor in (0, 0.008)
682 and sky_level == 4e6
683 and average_gain == 1.7
684 and predefined_gain_type is None
685 ):
686 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8.5))
687 plt.subplots_adjust(wspace=0.17, hspace=0.17)
688 colorbar_aspect = 12
690 amp_background_variance_ADU_list = [self.background / gain**2 for gain in self.amp_gain_list]
691 amp_background_image_ADU_list = [self.background / gain for gain in self.amp_gain_list]
692 # Calculate the mean value that corresponds to the background for
693 # the variance plane, adjusting for the gain.
694 background_mean_variance_ADU = np.mean(
695 [self.background / gain**2 for gain in self.amp_gain_list]
696 )
698 # Extract the variance planes and the image from the exposure.
699 arr1 = signal_polluted_variance.array # Variance with signal
700 arr2 = signal_free_variance.array # Variance without signal
701 exp_im = exposure.image.clone() # Clone of the image plane
703 # Incorporate the gain-adjusted background into the image plane to
704 # enable combined visualization of sources with the background.
705 for gain, bbox in zip(self.amp_gain_list, self.amp_bbox_list):
706 exp_im[bbox].array += self.background / gain
707 arr3 = exp_im.array
709 # Define colors visually distinct from each other for the subplots.
710 original_variance_color = "#8A2BE2" # Periwinkle
711 corrected_variance_color = "#618B3C" # Lush Forest Green
712 sky_variance_color = "#c3423f" # Crimson Red
713 amp_colors = [
714 "#1f77b4", # Muted Blue
715 "#ff7f0e", # Vivid Orange
716 "#2ca02c", # Kelly Green
717 "#d62728", # Brick Red
718 "#9467bd", # Soft Purple
719 "#8B4513", # Saddle Brown
720 "#e377c2", # Pale Violet Red
721 "#202020", # Onyx
722 ]
723 arrowheads_lr = ["$\u25C0$", "$\u25B6$"] # Left- & right-pointing
724 arrowheads_ud = ["$\u25B2$", "$\u25BC$"] # Up- & down-pointing
726 # Set titles for the subplots.
727 ax1.set_title("Original variance plane", color=original_variance_color)
728 ax2.set_title("Corrected variance plane", color=corrected_variance_color)
729 ax3.set_title("Image + background ($\\mathit{uniform}$)")
730 ax4.set_title("Histogram of variances")
732 # Collect all vertical and horizontal line positions to find the
733 # amp boundaries.
734 vlines, hlines = set(), set()
735 for bbox in self.amp_bbox_list:
736 # Adjst by 0.5 for merging of lines at the boundaries.
737 vlines.update({bbox.minX - 0.5, bbox.maxX + 0.5})
738 hlines.update({bbox.minY - 0.5, bbox.maxY + 0.5})
740 # Filter lines at the edges of the overall image bbox.
741 image_bbox = exposure.getBBox()
742 vlines = {x for x in vlines if image_bbox.minX < x < image_bbox.maxX}
743 hlines = {y for y in hlines if image_bbox.minY < y < image_bbox.maxY}
745 # Plot image and variance planes.
746 for plane, arr, ax in zip(
747 ("variance", "variance_corrected", "image"), (arr1, arr2, arr3), (ax1, ax2, ax3)
748 ):
749 # We skip 'variance_corrected' in the loop below because we use
750 # the same normalization and colormap as 'variance' for it.
751 if plane in ["variance", "image"]:
752 # Get the normalization.
753 vmin, vmax = arr.min(), arr.max()
754 norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
756 # Get the thresholds corresponding to per-amp backgrounds
757 # and their positions in the normalized color scale.
758 thresholds = (
759 amp_background_variance_ADU_list
760 if plane.startswith("variance")
761 else amp_background_image_ADU_list
762 )
763 threshold_positions = [norm(t) for t in thresholds]
764 threshold = np.mean(thresholds)
765 threshold_position = np.mean(threshold_positions)
767 # Create a custom colormap with two distinct colors for the
768 # sky and source contributions.
769 border = (threshold - vmin) / (vmax - vmin)
770 colors1 = plt.cm.Purples_r(np.linspace(0, 1, int(border * 256)))
771 colors2 = plt.cm.Greens(np.linspace(0, 1, int((1 - border) * 256)))
772 colors = np.vstack((colors1, colors2))
773 cmap = mcolors.LinearSegmentedColormap.from_list("cmap", colors)
775 # Plot the array with the custom colormap and normalization.
776 im = ax.imshow(arr, cmap=cmap, norm=norm)
778 # Add colorbars to the plot.
779 cbar = fig.colorbar(im, aspect=colorbar_aspect, pad=0)
781 # Change the number of ticks on the colorbar for better
782 # spacing. Needs to be done before modifying the tick labels.
783 cbar.ax.locator_params(nbins=7)
785 # Enhance readability by scaling down colorbar tick labels.
786 unit = "ADU$^2$" if plane.startswith("variance") else "ADU"
787 adjust_tick_scale(cbar.ax, {"y": f"Value [{{scale}} {unit}]"})
789 # Mark min and max per-amp thresholds with dotted lines on the
790 # colorbar.
791 for tp in [min(thresholds), max(thresholds)]:
792 cbar.ax.axhline(tp, color="white", linestyle="-", linewidth=1.5, alpha=0.4)
793 cbar.ax.axhline(tp, color=sky_variance_color, linestyle=":", linewidth=1.5, alpha=0.9)
795 # Mark mean threshold with facing arrowheads on the colorbar.
796 cbar.ax.annotate(
797 arrowheads_lr[1], # Right-pointing arrowhead
798 xy=(0, threshold_position),
799 xycoords="axes fraction",
800 textcoords="offset points",
801 xytext=(0, 0),
802 ha="left",
803 va="center",
804 fontsize=6,
805 color=sky_variance_color,
806 clip_on=False,
807 alpha=0.9,
808 )
809 cbar.ax.annotate(
810 arrowheads_lr[0], # Left-pointing arrowhead
811 xy=(1, threshold_position),
812 xycoords="axes fraction",
813 textcoords="offset points",
814 xytext=(0, 0),
815 ha="right",
816 va="center",
817 fontsize=6,
818 color=sky_variance_color,
819 clip_on=False,
820 alpha=0.9,
821 )
823 # Add text inside the colorbar to label the average threshold
824 # position.
825 sky_level_text = "$\u27E8$" + "Sky" + "$\u27E9$" # <Sky>
826 sky_level_text_artist = cbar.ax.text(
827 0.5,
828 threshold_position,
829 sky_level_text,
830 va="center",
831 ha="center",
832 transform=cbar.ax.transAxes,
833 fontsize=8,
834 color=sky_variance_color,
835 rotation="vertical",
836 alpha=0.9,
837 path_effects=outline_effect(2),
838 )
840 # Setup renderer and transformation.
841 renderer = fig.canvas.get_renderer()
842 transform = cbar.ax.transAxes.inverted()
844 # Transform the bounding box and calculate adjustment for
845 # 'sky_level_text_artist' for when it goes beyond the colorbar.
846 sky_level_text_bbox = sky_level_text_artist.get_window_extent(renderer).transformed(transform)
847 adjustment = 1.4 * sky_level_text_bbox.height / 2
849 if sky_level_text_bbox.ymin < 0:
850 sky_level_text_artist.set_y(adjustment)
851 elif sky_level_text_bbox.ymax > 1:
852 sky_level_text_artist.set_y(1 - adjustment)
854 # Draw amp boundaries as vertical and/or horizontal lines.
855 line_color = "white" if np.mean(norm(arr)) > 0.5 else "#808080"
856 for x in vlines:
857 ax.axvline(x=x, color=line_color, linestyle="--", linewidth=1, alpha=0.7)
858 for y in hlines:
859 ax.axhline(y=y, color=line_color, linestyle="--", linewidth=1, alpha=0.7)
860 # Hide all x and y tick marks.
861 ax.tick_params(axis="both", which="both", bottom=False, top=False, left=False, right=False)
862 # Hide all x and y tick labels.
863 ax.set_xticklabels([])
864 ax.set_yticklabels([])
866 # Additional ax2 annotations:
867 # Labels amplifiers with their respective gains for a visual check.
868 for bbox, name, gain, color in zip(
869 self.amp_bbox_list, self.amp_name_list_simplified, self.amp_gain_list, amp_colors
870 ):
871 # Get the center of the bbox to label the gain value.
872 bbox_center = (bbox.minX + bbox.maxX) / 2, (bbox.minY + bbox.maxY) / 2
873 # Label the gain value at the center of each amplifier segment.
874 ax2.text(
875 *bbox_center,
876 f"gain$_{{\\rm \\, {name} \\,}}$: {gain:.3f}",
877 fontsize=9,
878 color=color,
879 alpha=0.95,
880 ha="center",
881 va="center",
882 path_effects=outline_effect(2),
883 )
885 # Additional ax3 annotations:
886 # Label sources with numbers on the image plane.
887 for i, pos in enumerate(self.positions, start=1):
888 ax3.text(
889 *pos,
890 f"{i}",
891 fontsize=7,
892 color=sky_variance_color,
893 path_effects=outline_effect(1.5),
894 alpha=0.9,
895 )
897 # Now we use ax4 to plot the histograms of the variance planes for
898 # comparison.
899 # Plot the histogram of the original variance plane.
900 hist_values, bins, _ = ax4.hist(
901 arr1.flatten(),
902 bins=80,
903 histtype="step",
904 color=original_variance_color,
905 alpha=0.9,
906 label="Original variance",
907 )
908 # Fill the area under the step.
909 ax4.fill_between(
910 bins[:-1],
911 hist_values,
912 step="post",
913 color=original_variance_color,
914 alpha=0.09,
915 hatch="/////",
916 label=" ",
917 )
918 # Plot the histogram of the corrected variance plane.
919 ax4.hist(
920 arr2.flatten(),
921 bins=80,
922 histtype="bar",
923 color=corrected_variance_color,
924 alpha=0.9,
925 label="Corrected variance",
926 )
927 adjust_tick_scale(ax4, {"x": "Variance [{scale} ADU$^2$]", "y": "Number of pixels / {scale}"})
928 ax4.yaxis.set_label_position("right")
929 ax4.yaxis.tick_right()
930 ax4.axvline(
931 background_mean_variance_ADU,
932 color=sky_variance_color,
933 linestyle="--",
934 linewidth=1,
935 alpha=0.9,
936 label="Average sky variance\nacross all amps",
937 )
939 # Use colored arrowheads to mark true amp variances.
940 sorted_vars = sorted(amp_background_variance_ADU_list)
941 count = {v: 0 for v in sorted_vars}
942 for i, (x, name, gain, color) in enumerate(
943 zip(
944 amp_background_variance_ADU_list,
945 self.amp_name_list_simplified,
946 self.amp_gain_list,
947 amp_colors,
948 )
949 ):
950 arrowhead = arrowheads_ud[int(gain < average_gain)]
951 arrowhead_text = ax4.annotate(
952 arrowhead,
953 xy=(x, 0),
954 xycoords=("data", "axes fraction"),
955 textcoords="offset points",
956 xytext=(0, 0),
957 ha="center",
958 va="bottom",
959 fontsize=6.5,
960 color=color,
961 clip_on=False,
962 alpha=0.85,
963 path_effects=outline_effect(1.5),
964 )
965 if i == 0:
966 # Draw the canvas once to make sure the renderer is active.
967 fig.canvas.draw()
968 # Get the bounding box of the text annotation in axes
969 # fraction.
970 bbox_axes = arrowhead_text.get_window_extent().transformed(ax4.transAxes.inverted())
971 # Get the height of the text annotation in axes fraction.
972 height = bbox_axes.height
973 # Increment the arrowhead y positions to avoid overlap.
974 var_idxs = np.where(sorted_vars == x)[0]
975 if len(var_idxs) > 1:
976 q = count[x]
977 count[x] += 1
978 else:
979 q = 0
980 arrowhead_text.xy = (x, var_idxs[q] * height)
982 # Create a proxy artist for the legend since annotations are
983 # not shown in the legend.
984 label = "True variance of" if i == 0 else "$\u21AA$"
985 ax4.scatter(
986 [],
987 [],
988 color=color,
989 marker=arrowhead,
990 s=15,
991 label=f"{label} {name}",
992 alpha=0.85,
993 path_effects=outline_effect(1.5),
994 )
996 # Group the legend handles and label them.
997 adjust_legend_with_groups(
998 ax4,
999 [(0, 1), 2, 3, *range(4, 4 + len(self.amp_name_list_simplified))],
1000 colors="match",
1001 handlelength=1.9,
1002 )
1004 # Align the histogram (bottom right panel) with the colorbar of the
1005 # corrected variance plane (top right panel) for aesthetic reasons.
1006 pos2 = ax2.get_position()
1007 pos4 = ax4.get_position()
1008 fig.canvas.draw() # Render to ensure accurate colorbar width
1009 cbar_width = cbar.ax.get_position().width
1010 ax4.set_position([pos2.x0, pos4.y0, pos2.width + cbar_width, pos4.height])
1012 # Increase all axes spines' linewidth by 20% for a bolder look.
1013 for ax in fig.get_axes():
1014 for spine in ax.spines.values():
1015 spine.set_linewidth(spine.get_linewidth() * 1.2)
1017 # Save the figure.
1018 filename = f"variance_plane_gain{average_gain}_sigma{gain_sigma_factor}_sky{sky_level}.png"
1019 fig.savefig(filename, dpi=300)
1020 print(f"Saved plot of variance plane before and after correction in {filename}")
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No test methods are added here; the class exists only so that the
    inherited checks are collected for this module.
    """
def setup_module(module):
    """Module-level setup hook: call `lsst.utils.tests.init()`.

    Parameters
    ----------
    module : `module`
        The module object passed in by the test runner; not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When executed directly as a script (rather than collected by pytest,
    # which calls ``setup_module``), perform the same initialisation and then
    # hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()