Coverage for tests/test_variance_plane.py: 8%
324 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 03:29 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 03:29 -0700
1# This file is part of ip_isr.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21#
22"""The utility function under test is part of the `meas_algorithms` package.
23This unit test has been relocated here to avoid circular dependencies.
24"""
26import re
27import unittest
28from contextlib import nullcontext
30import galsim
31import lsst.utils.tests
32import matplotlib.colors as mcolors
33import matplotlib.pyplot as plt
34import numpy as np
35from lsst.ip.isr.isrMock import IsrMock
36from lsst.meas.algorithms import remove_signal_from_variance
37from lsst.utils.tests import methodParametersProduct
38from matplotlib.legend_handler import HandlerTuple
39from matplotlib.patheffects import withStroke
40from matplotlib.ticker import FixedLocator, FuncFormatter
# Flip this flag to True to save a before/after plot of the variance plane
# for one representative parameter combination of the signal-removal test.
SAVE_PLOT = False
def outline_effect(lw, alpha=0.8):
    """Build a white-stroke path effect that outlines text so it stands out
    against the underlying plot.

    Parameters
    ----------
    lw : `float`
        Width of the white outline stroke.
    alpha : `float`, optional
        Opacity of the outline stroke.

    Returns
    -------
    effects : `list` of `matplotlib.patheffects.withStroke`
        A single-element list holding the stroke effect, suitable for the
        ``path_effects`` argument of text artists.
    """
    stroke = withStroke(linewidth=lw, foreground="white", alpha=alpha)
    return [stroke]
class CustomHandler(HandlerTuple):
    """Legend handler for tuples of artists (grouped legend entries).

    Artist creation is delegated to `HandlerTuple`; every resulting artist is
    then given the transform that is passed as the last positional argument
    of ``create_artists``.
    """

    def create_artists(self, *args):
        created = super().create_artists(*args)
        transform = args[-1]
        for artist in created:
            artist.set_transform(transform)
        return created
def get_valid_color(handle):
    """Extracts a valid color from a Matplotlib handle.

    The handle is probed with ``get_facecolor``, ``get_edgecolor`` and
    ``get_color`` (in that order); the first getter that yields a usable,
    non-transparent color wins.

    Parameters
    ----------
    handle : `matplotlib.artist.Artist`
        The handle from which to extract the color.

    Returns
    -------
    color : `str` or `tuple` or `numpy.ndarray`
        The color extracted from the handle, or 'default' if no valid color is
        found.
    """
    for attr in ("get_facecolor", "get_edgecolor", "get_color"):
        getter = getattr(handle, attr, None)
        if getter is None:
            continue
        color = getter()
        if color is None:
            continue
        if isinstance(color, np.ndarray):
            # An empty collection has no color to offer.
            if color.size == 0:
                continue
            # A collection returns one color per row; use the first. A 1-D
            # array is already a single color and must not be indexed, or it
            # would collapse to a scalar channel value.
            if color.ndim > 1:
                color = color[0]
        if isinstance(color, str):
            # "none" is matplotlib's fully transparent color.
            if color.lower() == "none":
                continue
        elif len(color) == 4 and color[3] == 0:
            # RGBA with alpha = 0 is invisible; keep searching.
            continue
        return color
    return "default"  # If no valid color is found
def get_emptier_side(ax):
    """Analyze a matplotlib Axes object to determine which side (left or right)
    has more whitespace, considering cases where artists' bounding boxes span
    both sides of the midpoint.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to analyze.

    Returns
    -------
    more_whitespace_side : `str`
        'left' if the left side has more whitespace, or 'right' if the right
        side does. Ties resolve to 'left'.

    Notes
    -----
    The Axes midpoint and all artist extents are measured in display (pixel)
    coordinates, the frame ``get_window_extent`` reports in, so they are
    directly comparable independent of the data limits or axis inversion.
    """
    # Midpoint of the Axes in the same display frame as the artist extents.
    ax_bbox = ax.get_window_extent()
    midpoint = (ax_bbox.x0 + ax_bbox.x1) / 2

    # Occupied area accumulated on each side of the midpoint.
    left_area, right_area = 0, 0

    for artist in ax.get_children():
        # Skip if artist is invisible or lacks a bounding box.
        if not artist.get_visible() or not hasattr(artist, "get_window_extent"):
            continue
        bbox = artist.get_window_extent()
        area = bbox.width * bbox.height
        if bbox.x0 < midpoint < bbox.x1:
            # Straddles the midpoint: split its area by the width on each
            # side. Width is > 0 here because x0 < x1.
            left_fraction = (midpoint - bbox.x0) / bbox.width
            left_area += area * left_fraction
            right_area += area * (1 - left_fraction)
        elif bbox.x0 + bbox.width / 2 < midpoint:
            # Entirely on the left.
            left_area += area
        else:
            # Entirely on the right.
            right_area += area

    # The side with less occupied area has more whitespace.
    return "left" if left_area <= right_area else "right"
def adjust_legend_with_groups(ax, combine_groups, colors="default", yloc="upper", **kwargs):
    """Adjusts the legend of a given Axes object by combining specified handles
    based on provided groups, setting the marker location and text alignment
    based on the inferable emptier side, and optionally setting the text color
    of legend entries to a provided list or inferring colors from the handles.
    Additionally, allows specifying the vertical location of the legend within
    the plot.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object for which to adjust the legend.
    combine_groups : `list` of `int` or iterable of `int`
        A list that can contain a mix of individual integers and/or iterables
        (lists, tuples, sets, frozensets, or ranges) of integers. An
        individual integer (Python or numpy) specifies the index of a single
        legend entry. A non-empty iterable of integers specifies a group of
        indices to be combined into a single legend entry.
    colors : `list` of `str` or `tuple`, or `str`, optional
        Specifies the colors for the legend entries. This parameter can be:
        - A list of color specifications, where each element is a string (for
        named colors or hex values) or a tuple (for RGB or RGBA values). This
        list explicitly assigns colors to each legend entry post-combination.
        - A single string value:
        - "match": Colors are inferred from the properties of the first
        handle in each group that corresponds to a non-white-space label.
        This aims to match the legend text color with the color of the
        plotted data.
        - "default": The function does not alter the default colors
        assigned by Matplotlib, preserving the automatic color assignment
        for all legend entries.
    yloc : `str`, optional
        The vertical location of the legend within the Axes. Valid options are
        'upper', 'lower', or 'middle' ('center' is accepted as an alias). This
        parameter is combined with the inferable emptier side ('left' or
        'right') to determine the legend's placement, e.g. 'upper right' or
        'lower left'. 'middle' maps to matplotlib's 'center left'/'center
        right', since matplotlib has no 'middle' location.
    **kwargs :
        Keyword arguments forwarded to the ``ax.legend`` function.

    Raises
    ------
    ValueError
        If ``colors`` is a string other than "match" or "default", if ``yloc``
        is not a recognized vertical location, or if an entry of
        ``combine_groups`` is empty or of an unsupported type.
    """
    # Translate the documented vertical locations to matplotlib's vocabulary.
    vertical_locs = {"upper": "upper", "lower": "lower", "middle": "center", "center": "center"}
    if yloc not in vertical_locs:
        raise ValueError(f"Invalid 'yloc': {yloc!r}; expected 'upper', 'lower', or 'middle'")
    if isinstance(colors, str) and colors not in ("match", "default"):
        raise ValueError(f"Invalid 'colors' string: {colors!r}; expected 'match' or 'default'")

    handles, labels = ax.get_legend_handles_labels()
    new_handles = []
    new_labels = []

    # The isinstance guard keeps array-like `colors` from being compared
    # element-wise against a string.
    infer_colors = isinstance(colors, str) and colors == "match"
    if infer_colors:
        colors = []

    for group in combine_groups:
        if isinstance(group, (list, tuple, set, frozenset, range)):
            group = list(group)
            if not group:
                raise ValueError("Empty group in 'combine_groups'")
            # Assume the first non-white-space label represents the group. If
            # no such label is found, fall back to the first label in the
            # group (which is then empty).
            label_index = next((i for i in group if labels[i].strip()), group[0])
            combined_handle = tuple(handles[i] for i in group)
            combined_label = labels[label_index]
        elif isinstance(group, (int, np.integer)):
            label_index = group
            combined_handle = handles[group]
            combined_label = labels[group]
        else:
            raise ValueError("Invalid value in 'combine_groups'")
        new_handles.append(combined_handle)
        new_labels.append(combined_label)
        if infer_colors:
            # Infer the text color from the group's representative handle.
            colors.append(get_valid_color(handles[label_index]))

    # Determine the emptier side to decide legend and text alignment.
    emptier_side = get_emptier_side(ax)
    markerfirst = emptier_side != "right"

    # Create the legend with custom adjustments.
    legend = ax.legend(
        new_handles,
        new_labels,
        handler_map={tuple: CustomHandler()},
        loc=f"{vertical_locs[yloc]} {emptier_side}",
        fontsize=8,
        frameon=False,
        markerfirst=markerfirst,
        **kwargs,
    )

    # Right- or left-align the legend text based on the emptier side.
    for text in legend.get_texts():
        text.set_ha(emptier_side)

    # At this point `colors` is either the string "default" (leave colors
    # alone) or a sequence of per-entry colors (inferred or user-supplied).
    if not isinstance(colors, str):
        for text, color in zip(legend.get_texts(), colors):
            if not (isinstance(color, str) and color == "default"):
                text.set_color(color)
def adjust_tick_scale(ax, axis_label_templates):
    """Scales down tick labels to make them more readable and updates axis
    labels accordingly.

    Calculates a power of 10 scale factor (common divisor) to reduce the
    magnitude of tick labels. It automatically determines which axes to adjust
    based on the provided axis label templates, which should include `{scale}`
    for inserting the scale factor dynamically.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to modify.
    axis_label_templates : `dict`
        Templates for axis labels, including '{scale}' for scale factor
        insertion. Keys should be one or more of axes names ('x', 'y', 'z') and
        values should be the corresponding label templates.

    Notes
    -----
    Only integer-valued, non-zero tick labels contribute trailing zeros. If
    there are none, the divisor is 1 and the ``{scale}`` placeholder is
    removed from the label template.
    """

    def parse_label(text):
        """Parse a tick label, accepting matplotlib's Unicode minus sign."""
        return float(text.replace("\u2212", "-"))

    def trailing_zeros(value):
        """Count trailing zeros of an integer value; 0 for non-integers."""
        if not float(value).is_integer():
            return 0
        digits = str(abs(int(value)))
        return len(digits) - len(digits.rstrip("0"))

    def format_tick(val, pos, divisor):
        """Formats tick labels using the determined divisor."""
        scaled = val / divisor
        return str(int(scaled)) if scaled.is_integer() else str(scaled)

    def make_formatter(divisor):
        # Bind `divisor` per axis. A lambda created directly in the loop would
        # look the name up at draw time and every axis would use the divisor
        # computed for the last one.
        return FuncFormatter(lambda val, pos: format_tick(val, pos, divisor))

    for axis, label_template in axis_label_templates.items():
        # Gather current tick label values, ignoring unpopulated labels.
        texts = [label.get_text() for label in getattr(ax, f"get_{axis}ticklabels")()]
        values = [parse_label(text) for text in texts if text.strip()]

        # Power of 10 divisor from the minimum number of trailing zeros among
        # the non-zero tick values.
        divisor = 10 ** min((trailing_zeros(v) for v in values if v != 0), default=0)

        axis_obj = getattr(ax, f"{axis}axis")
        axis_obj.set_major_formatter(make_formatter(divisor))

        # Ensure the tick positions remain unchanged despite the new
        # formatting.
        axis_obj.set_major_locator(FixedLocator(getattr(ax, f"get_{axis}ticks")()))

        # Prepare 'scale', empty if divisor <= 1.
        scale = f"{int(divisor)}" if divisor > 1 else ""

        # If 'scale' is empty, remove whitespace around "{scale}" in the
        # template. Also remove any trailing "/{scale}".
        if scale == "":
            label_template = re.sub(r"\s*{\s*scale\s*}\s*", "{scale}", label_template)
            label_template = label_template.replace("/{scale}", "")

        # Always strip remaining whitespace from the template.
        label_template = label_template.strip()

        # Set the formatted axis label.
        getattr(ax, f"set_{axis}label")(label_template.format(scale=scale), labelpad=8)
318class VariancePlaneTestCase(lsst.utils.tests.TestCase):
319 def setUp(self):
320 # Testing with a single detector that has 8 amplifiers in a 4x2
321 # configuration. Each amplifier measures 100x51 in dimensions.
322 config = IsrMock.ConfigClass()
323 config.isLsstLike = True
324 config.doAddBias = False
325 config.doAddDark = False
326 config.doAddFlat = False
327 config.doAddFringe = False
328 config.doGenerateImage = True
329 config.doGenerateData = True
330 config.doGenerateAmpDict = True
331 self.mock = IsrMock(config=config)
333 def tearDown(self):
334 del self.mock
336 def buildExposure(
337 self,
338 average_gain,
339 gain_sigma_factor,
340 sky_level,
341 add_signal=True,
342 ):
343 """Build and return an exposure with different types of simulated
344 source profiles and a background sky level. It's intended for testing
345 and analysis, providing a way to generate exposures with controlled
346 conditions.
348 Parameters
349 ----------
350 average_gain : `float`
351 The average gain value of amplifiers in e-/ADU.
352 gain_sigma_factor : float
353 The standard deviation of the gain values as a factor of the
354 ``average_gain``.
355 sky_level : `float`
356 The background sky level in e-/arcsec^2.
358 Returns
359 -------
360 exposure : `~lsst.afw.image.Exposure`
361 An exposure object with simulated sources and background. The units
362 are in detector counts (ADU).
363 """
365 # Set the random seed for reproducibility.
366 random_seed = galsim.BaseDeviate(1905).raw() + 1
367 np.random.seed(random_seed)
368 rng = galsim.BaseDeviate(random_seed)
370 # Get the exposure, detector, and amps from the mock.
371 exposure = self.mock.getExposure()
372 detector = exposure.getDetector()
373 amps = detector.getAmplifiers()
374 num_amps = len(amps)
375 self.amp_name_list = [amp.getName() for amp in amps]
376 table = str.maketrans("", "", ":,") # Remove ':' and ',' from names
377 self.amp_name_list_simplified = [name.translate(table) for name in self.amp_name_list]
379 # Adjust instrument and observation parameters to some nominal values.
380 pixel_scale = 0.2 # arcsec/pixel
381 self.background = sky_level * pixel_scale**2 # e-/pixel
383 # Get the bounding boxes for the exposure and amplifiers and convert
384 # them to galsim bounds.
385 exp_bbox = exposure.getBBox()
386 image_bounds = galsim.BoundsI(exp_bbox.minX, exp_bbox.maxX, exp_bbox.minY, exp_bbox.maxY)
387 self.amp_bbox_list = [amp.getBBox() for amp in amps]
388 amp_bounds_list = [galsim.BoundsI(b.minX, b.maxX, b.minY, b.maxY) for b in self.amp_bbox_list]
390 # Generate random deviations from the average gain across amplifiers
391 # and adjust them to ensure their sum equals zero. This reflects
392 # real-world detectors, with amplifier gains normally distributed due
393 # to manufacturing and operational variations.
394 deviations = np.random.normal(average_gain, gain_sigma_factor * average_gain, size=num_amps)
395 deviations -= np.mean(deviations)
397 # Set the gain for amplifiers to be slightly different from each other
398 # while averaging to `average_gain`. This is to test the
399 # `average_across_amps` option in the `remove_signal_from_variance`
400 # function.
401 self.amp_gain_list = [average_gain + deviation for deviation in deviations]
403 # Create a galsim image to potentially draw the sources onto. The
404 # exposure image that is passed to this method will be modified in
405 # place.
406 image = galsim.ImageF(exposure.image.array, bounds=image_bounds)
408 if add_signal:
409 # Define parameters for a mix of source types, including extended
410 # sources with assorted profiles as well as point sources simulated
411 # with minimal half-light radii to resemble hot pixels
412 # post-deconvolution. All flux values are given in electrons and
413 # half-light radii in pixels. The goal is for each amplifier to
414 # predominantly contain at least one source, enhancing the
415 # representativeness of test conditions.
416 source_params = [
417 {"type": "Sersic", "n": 3, "flux": 1.6e5, "half_light_radius": 3.5, "g1": -0.3, "g2": 0.2},
418 {"type": "Sersic", "n": 1, "flux": 9.3e5, "half_light_radius": 2.1, "g1": 0.25, "g2": 0.12},
419 {"type": "Sersic", "n": 4, "flux": 1.0e5, "half_light_radius": 1.1, "g1": 0.0, "g2": 0.0},
420 {"type": "Sersic", "n": 3, "flux": 1.1e6, "half_light_radius": 4.2, "g1": 0.0, "g2": 0.2},
421 {"type": "Sersic", "n": 5, "flux": 1.1e5, "half_light_radius": 3.6, "g1": 0.22, "g2": -0.05},
422 {"type": "Sersic", "n": 2, "flux": 4.3e5, "half_light_radius": 2.0, "g1": 0.0, "g2": 0.0},
423 {"type": "Sersic", "n": 6, "flux": 1.2e6, "half_light_radius": 11.0, "g1": -0.16, "g2": 0.7},
424 {"type": "Exponential", "flux": 1.3e6, "half_light_radius": 1.9, "g1": 0.3, "g2": -0.1},
425 {"type": "Exponential", "flux": 1.8e6, "half_light_radius": 5.0, "g1": 0.0, "g2": 0.14},
426 {"type": "Exponential", "flux": 6.6e6, "half_light_radius": 4.8, "g1": 0.26, "g2": 0.5},
427 {"type": "Exponential", "flux": 7.0e5, "half_light_radius": 3.1, "g1": -0.3, "g2": 0.0},
428 {"type": "DeVaucouleurs", "flux": 1.6e5, "half_light_radius": 3.5, "g1": 0.2, "g2": 0.4},
429 {"type": "DeVaucouleurs", "flux": 2.0e5, "half_light_radius": 1.6, "g1": -0.06, "g2": -0.2},
430 {"type": "DeVaucouleurs", "flux": 8.3e5, "half_light_radius": 5.1, "g1": 0.29, "g2": 0.0},
431 {"type": "DeVaucouleurs", "flux": 4.5e5, "half_light_radius": 2.5, "g1": 0.4, "g2": 0.3},
432 {"type": "DeVaucouleurs", "flux": 6.2e5, "half_light_radius": 4.9, "g1": -0.08, "g2": -0.01},
433 {"type": "Gaussian", "flux": 4.7e6, "half_light_radius": 2.5, "g1": 0.07, "g2": -0.35},
434 {"type": "Gaussian", "flux": 5.8e6, "half_light_radius": 3.1, "g1": 0.03, "g2": 0.4},
435 {"type": "Gaussian", "flux": 2.3e5, "half_light_radius": 0.5, "g1": 0.0, "g2": 0.0},
436 {"type": "Gaussian", "flux": 1.6e6, "half_light_radius": 3.0, "g1": 0.18, "g2": -0.29},
437 {"type": "Gaussian", "flux": 3.5e5, "half_light_radius": 4.6, "g1": 0.5, "g2": 0.35},
438 {"type": "Gaussian", "flux": 5.9e5, "half_light_radius": 9.5, "g1": 0.1, "g2": 0.55},
439 {"type": "Gaussian", "flux": 4.0e5, "half_light_radius": 1.0, "g1": 0.0, "g2": 0.0},
440 ]
442 # Mapping of profile types to their galsim constructors.
443 profile_constructors = {
444 "Sersic": galsim.Sersic,
445 "Exponential": galsim.Exponential,
446 "DeVaucouleurs": galsim.DeVaucouleurs,
447 "Gaussian": galsim.Gaussian,
448 }
450 # Generate random positions within exposure bounds, avoiding edges
451 # by a margin.
452 margin_x, margin_y = 0.05 * exp_bbox.width, 0.05 * exp_bbox.height
453 self.positions = np.random.uniform(
454 [exp_bbox.minX + margin_x, exp_bbox.minY + margin_y],
455 [exp_bbox.maxX - margin_x, exp_bbox.maxY - margin_y],
456 (len(source_params), 2),
457 ).tolist()
459 # Loop over the sources and draw them onto the image cutout by
460 # cutout.
461 for i, params in enumerate(source_params):
462 # Dynamically get constructor and remove type from params.
463 constructor = profile_constructors[params.pop("type")]
465 # Get shear parameters and remove them from params.
466 g1, g2 = params.pop("g1"), params.pop("g2")
468 # The extent of the cutout should be large enough to contain
469 # the entire object above the background level. Some empirical
470 # factor is used to mitigate artifacts.
471 half_extent = 10 * params["half_light_radius"] * (1 + 2 * np.sqrt(g1**2 + g2**2))
473 # Pass the remaining params to the constructor and apply shear.
474 galsim_object = constructor(**params).shear(galsim.Shear(g1=g1, g2=g2))
476 # Retrieve the position of the object.
477 x, y = self.positions[i]
478 pos = galsim.PositionD(x, y)
480 # Get the bounds of the sub-image based on the object position.
481 sub_image_bounds = galsim.BoundsI(
482 *map(int, [x - half_extent, x + half_extent, y - half_extent, y + half_extent])
483 )
485 # Identify the overlap region, which could be partially outside
486 # the image bounds.
487 sub_image_bounds = sub_image_bounds & image.bounds
489 # Check that there is some overlap.
490 assert sub_image_bounds.isDefined(), "No overlap with image bounds"
492 # Get the sub-image cutout.
493 sub_image = image[sub_image_bounds]
495 # Draw the object onto the image within the the sub-image
496 # bounds.
497 galsim_object.drawImage(
498 image=sub_image,
499 offset=pos - sub_image.true_center,
500 method="real_space", # Save memory, usable w/o convolution
501 add_to_image=True, # Add flux to existing image
502 scale=pixel_scale,
503 )
505 # Add a constant background to the entire image (both in e-/pixel).
506 image += self.background
508 # Add noise to the image which is in electrons. Note that we won't
509 # specify a `sky_level` here to avoid double-counting it, as it's
510 # already included as the background.
511 image.addNoise(galsim.PoissonNoise(rng))
513 # Subtract off the background to get the sky-subtracted image.
514 image -= self.background
516 # Adjust each amplifier's image segment by its respective gain. After
517 # this step, the image will be in ADUs.
518 for bounds, gain in zip(amp_bounds_list, self.amp_gain_list):
519 image[bounds] /= gain
521 # We know that the exposure has already been modified in place, but
522 # just to be extra sure, we'll set the exposure image explicitly.
523 exposure.image.array = image.array
525 # Create a variance plane for the exposure while including signal as a
526 # pollutant. Note that the exposure image is pre-adjusted for gain,
527 # unlike 'self.background'. Thus, we divide the background by the
528 # corresponding gain before adding it to the image. This leads to the
529 # variance plane being in units of ADU^2.
530 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
531 exposure.variance[bbox].array = (exposure.image[bbox].array + self.background / gain) / gain
533 return exposure
535 def test_no_signal_handling(self):
536 """Test that the function does nearly nothing when given an image with
537 no signal.
538 """
539 # Create an exposure with no signal.
540 exposure = self.buildExposure(
541 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=False
542 )
543 # Remove the signal from the variance plane, if any.
544 updated_variance = remove_signal_from_variance(exposure, in_place=False)
545 # Check that the variance plane is nearly the same as the original.
546 self.assertFloatsAlmostEqual(exposure.variance.array, updated_variance.array, rtol=0.013)
548 def test_in_place_handling(self):
549 """Make sure the function is tested to handle in-place operations."""
550 # Create an exposure with signal.
551 exposure = self.buildExposure(
552 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=True
553 )
554 # Remove the signal from the variance plane.
555 updated_variance = remove_signal_from_variance(exposure, in_place=True)
556 # Retrieve the variance plane from the exposure and check that it is
557 # identical to the returned variance plane.
558 self.assertFloatsEqual(exposure.variance.array, updated_variance.array)
560 @methodParametersProduct(
561 average_gain=[1.4, 1.7],
562 predefined_gain_type=["average", "per-amp", None],
563 gain_sigma_factor=[0, 0.008],
564 sky_level=[2e6, 4e6],
565 average_across_amps=[False, True],
566 )
567 def test_variance_signal_removal(
568 self, average_gain, predefined_gain_type, gain_sigma_factor, sky_level, average_across_amps
569 ):
570 exposure = self.buildExposure(
571 average_gain=average_gain,
572 gain_sigma_factor=gain_sigma_factor,
573 sky_level=sky_level,
574 add_signal=True,
575 )
577 # Save the original variance plane for comparison, assuming it has
578 # Poisson contribution from the source signal.
579 signal_polluted_variance = exposure.variance.clone()
581 # Check that the variance plane has no negative values.
582 self.assertTrue(
583 np.all(signal_polluted_variance.array >= 0),
584 "Variance plane has negative values (pre correction)",
585 )
587 if predefined_gain_type == "average":
588 predefined_gain = average_gain
589 predefined_gains = None
590 elif predefined_gain_type == "per-amp":
591 predefined_gain = None
592 predefined_gains = {name: gain for name, gain in zip(self.amp_name_list, self.amp_gain_list)}
593 elif predefined_gain_type is None:
594 # Allow the 'remove_signal_from_variance' function to estimate the
595 # gain itself before it attempts to remove the signal from the
596 # variance plane.
597 predefined_gain = None
598 predefined_gains = None
600 # Set the relative tolerance for the variance plane checks.
601 if predefined_gain_type == "average" or (predefined_gain_type is None and average_across_amps):
602 # Relax the tolerance if we are simply averaging across amps to
603 # roughly estimate the overall gain.
604 rtol = 0.015
605 estimate_average_gain = True
606 else:
607 # Tighten tolerance for the 'predefined_gain_type' of 'per-amp' or
608 # for a more accurate per-amp gain estimation strategy.
609 rtol = 2e-7
610 estimate_average_gain = False
612 # Remove the signal from the variance plane.
613 signal_free_variance = remove_signal_from_variance(
614 exposure,
615 gain=predefined_gain,
616 gains=predefined_gains,
617 average_across_amps=average_across_amps,
618 in_place=False,
619 )
621 # Check that the variance plane has been modified.
622 self.assertFloatsNotEqual(signal_polluted_variance.array, signal_free_variance.array)
624 # Check that the corrected variance plane has no negative values.
625 self.assertTrue(
626 np.all(signal_free_variance.array >= 0), "Variance plane has negative values (post correction)"
627 )
629 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
630 # Calculate the true variance in theoretical terms.
631 true_var_amp = self.background / gain**2
632 # Pair each variance with the appropriate context manager before
633 # looping through them.
634 var_context_pairs = [
635 # For the signal-free variance, directly execute the checks.
636 (signal_free_variance, nullcontext()),
637 # For the signal-polluted variance, expect AssertionError
638 # unless we are averaging across amps.
639 (
640 signal_polluted_variance,
641 nullcontext() if estimate_average_gain else self.assertRaises(AssertionError),
642 ),
643 ]
644 for var, context_manager in var_context_pairs:
645 # Extract the segment of the variance plane for the amplifier.
646 var_amp = var[bbox]
647 with context_manager:
648 if var is signal_polluted_variance and estimate_average_gain:
649 # Skip rigorous checks on the signal-polluted variance,
650 # if we are averaging across amps.
651 pass
652 else:
653 # Get the variance value at the first pixel of the
654 # segment to compare with the rest of the pixels and
655 # the true variance.
656 v00 = var_amp.array[0, 0]
657 # Assert that the variance plane is almost uniform
658 # across the segment because the signal has been
659 # removed from it and the background is constant.
660 self.assertFloatsAlmostEqual(var_amp.array, v00, rtol=rtol)
661 # Assert that the variance plane is almost equal to the
662 # true variance across the segment.
663 self.assertFloatsAlmostEqual(v00, true_var_amp, rtol=rtol)
665 if (
666 SAVE_PLOT
667 and not average_across_amps
668 and gain_sigma_factor in (0, 0.008)
669 and sky_level == 4e6
670 and average_gain == 1.7
671 and predefined_gain_type is None
672 ):
673 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8.5))
674 plt.subplots_adjust(wspace=0.17, hspace=0.17)
675 colorbar_aspect = 12
677 amp_background_variance_ADU_list = [self.background / gain**2 for gain in self.amp_gain_list]
678 amp_background_image_ADU_list = [self.background / gain for gain in self.amp_gain_list]
679 # Calculate the mean value that corresponds to the background for
680 # the variance plane, adjusting for the gain.
681 background_mean_variance_ADU = np.mean(
682 [self.background / gain**2 for gain in self.amp_gain_list]
683 )
685 # Extract the variance planes and the image from the exposure.
686 arr1 = signal_polluted_variance.array # Variance with signal
687 arr2 = signal_free_variance.array # Variance without signal
688 exp_im = exposure.image.clone() # Clone of the image plane
690 # Incorporate the gain-adjusted background into the image plane to
691 # enable combined visualization of sources with the background.
692 for gain, bbox in zip(self.amp_gain_list, self.amp_bbox_list):
693 exp_im[bbox].array += self.background / gain
694 arr3 = exp_im.array
696 # Define colors visually distinct from each other for the subplots.
697 original_variance_color = "#8A2BE2" # Periwinkle
698 corrected_variance_color = "#618B3C" # Lush Forest Green
699 sky_variance_color = "#c3423f" # Crimson Red
700 amp_colors = [
701 "#1f77b4", # Muted Blue
702 "#ff7f0e", # Vivid Orange
703 "#2ca02c", # Kelly Green
704 "#d62728", # Brick Red
705 "#9467bd", # Soft Purple
706 "#8B4513", # Saddle Brown
707 "#e377c2", # Pale Violet Red
708 "#202020", # Onyx
709 ]
710 arrowheads_lr = ["$\u25C0$", "$\u25B6$"] # Left- & right-pointing
711 arrowheads_ud = ["$\u25B2$", "$\u25BC$"] # Up- & down-pointing
713 # Set titles for the subplots.
714 ax1.set_title("Original variance plane", color=original_variance_color)
715 ax2.set_title("Corrected variance plane", color=corrected_variance_color)
716 ax3.set_title("Image + background ($\\mathit{uniform}$)")
717 ax4.set_title("Histogram of variances")
719 # Collect all vertical and horizontal line positions to find the
720 # amp boundaries.
721 vlines, hlines = set(), set()
722 for bbox in self.amp_bbox_list:
723 # Adjst by 0.5 for merging of lines at the boundaries.
724 vlines.update({bbox.minX - 0.5, bbox.maxX + 0.5})
725 hlines.update({bbox.minY - 0.5, bbox.maxY + 0.5})
727 # Filter lines at the edges of the overall image bbox.
728 image_bbox = exposure.getBBox()
729 vlines = {x for x in vlines if image_bbox.minX < x < image_bbox.maxX}
730 hlines = {y for y in hlines if image_bbox.minY < y < image_bbox.maxY}
732 # Plot image and variance planes.
733 for plane, arr, ax in zip(
734 ("variance", "variance_corrected", "image"), (arr1, arr2, arr3), (ax1, ax2, ax3)
735 ):
736 # We skip 'variance_corrected' in the loop below because we use
737 # the same normalization and colormap as 'variance' for it.
738 if plane in ["variance", "image"]:
739 # Get the normalization.
740 vmin, vmax = arr.min(), arr.max()
741 norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
743 # Get the thresholds corresponding to per-amp backgrounds
744 # and their positions in the normalized color scale.
745 thresholds = (
746 amp_background_variance_ADU_list
747 if plane.startswith("variance")
748 else amp_background_image_ADU_list
749 )
750 threshold_positions = [norm(t) for t in thresholds]
751 threshold = np.mean(thresholds)
752 threshold_position = np.mean(threshold_positions)
754 # Create a custom colormap with two distinct colors for the
755 # sky and source contributions.
756 border = (threshold - vmin) / (vmax - vmin)
757 colors1 = plt.cm.Purples_r(np.linspace(0, 1, int(border * 256)))
758 colors2 = plt.cm.Greens(np.linspace(0, 1, int((1 - border) * 256)))
759 colors = np.vstack((colors1, colors2))
760 cmap = mcolors.LinearSegmentedColormap.from_list("cmap", colors)
762 # Plot the array with the custom colormap and normalization.
763 im = ax.imshow(arr, cmap=cmap, norm=norm)
765 # Add colorbars to the plot.
766 cbar = fig.colorbar(im, aspect=colorbar_aspect, pad=0)
768 # Change the number of ticks on the colorbar for better
769 # spacing. Needs to be done before modifying the tick labels.
770 cbar.ax.locator_params(nbins=7)
772 # Enhance readability by scaling down colorbar tick labels.
773 unit = "ADU$^2$" if plane.startswith("variance") else "ADU"
774 adjust_tick_scale(cbar.ax, {"y": f"Value [{{scale}} {unit}]"})
776 # Mark min and max per-amp thresholds with dotted lines on the
777 # colorbar.
778 for tp in [min(thresholds), max(thresholds)]:
779 cbar.ax.axhline(tp, color="white", linestyle="-", linewidth=1.5, alpha=0.4)
780 cbar.ax.axhline(tp, color=sky_variance_color, linestyle=":", linewidth=1.5, alpha=0.9)
782 # Mark mean threshold with facing arrowheads on the colorbar.
783 cbar.ax.annotate(
784 arrowheads_lr[1], # Right-pointing arrowhead
785 xy=(0, threshold_position),
786 xycoords="axes fraction",
787 textcoords="offset points",
788 xytext=(0, 0),
789 ha="left",
790 va="center",
791 fontsize=6,
792 color=sky_variance_color,
793 clip_on=False,
794 alpha=0.9,
795 )
796 cbar.ax.annotate(
797 arrowheads_lr[0], # Left-pointing arrowhead
798 xy=(1, threshold_position),
799 xycoords="axes fraction",
800 textcoords="offset points",
801 xytext=(0, 0),
802 ha="right",
803 va="center",
804 fontsize=6,
805 color=sky_variance_color,
806 clip_on=False,
807 alpha=0.9,
808 )
810 # Add text inside the colorbar to label the average threshold
811 # position.
812 sky_level_text = "$\u27E8$" + "Sky" + "$\u27E9$" # <Sky>
813 sky_level_text_artist = cbar.ax.text(
814 0.5,
815 threshold_position,
816 sky_level_text,
817 va="center",
818 ha="center",
819 transform=cbar.ax.transAxes,
820 fontsize=8,
821 color=sky_variance_color,
822 rotation="vertical",
823 alpha=0.9,
824 path_effects=outline_effect(2),
825 )
827 # Setup renderer and transformation.
828 renderer = fig.canvas.get_renderer()
829 transform = cbar.ax.transAxes.inverted()
831 # Transform the bounding box and calculate adjustment for
832 # 'sky_level_text_artist' for when it goes beyond the colorbar.
833 sky_level_text_bbox = sky_level_text_artist.get_window_extent(renderer).transformed(transform)
834 adjustment = 1.4 * sky_level_text_bbox.height / 2
836 if sky_level_text_bbox.ymin < 0:
837 sky_level_text_artist.set_y(adjustment)
838 elif sky_level_text_bbox.ymax > 1:
839 sky_level_text_artist.set_y(1 - adjustment)
841 # Draw amp boundaries as vertical and/or horizontal lines.
842 line_color = "white" if np.mean(norm(arr)) > 0.5 else "#808080"
843 for x in vlines:
844 ax.axvline(x=x, color=line_color, linestyle="--", linewidth=1, alpha=0.7)
845 for y in hlines:
846 ax.axhline(y=y, color=line_color, linestyle="--", linewidth=1, alpha=0.7)
847 # Hide all x and y tick marks.
848 ax.tick_params(axis="both", which="both", bottom=False, top=False, left=False, right=False)
849 # Hide all x and y tick labels.
850 ax.set_xticklabels([])
851 ax.set_yticklabels([])
853 # Additional ax2 annotations:
854 # Labels amplifiers with their respective gains for a visual check.
855 for bbox, name, gain, color in zip(
856 self.amp_bbox_list, self.amp_name_list_simplified, self.amp_gain_list, amp_colors
857 ):
858 # Get the center of the bbox to label the gain value.
859 bbox_center = (bbox.minX + bbox.maxX) / 2, (bbox.minY + bbox.maxY) / 2
860 # Label the gain value at the center of each amplifier segment.
861 ax2.text(
862 *bbox_center,
863 f"gain$_{{\\rm \\, {name} \\,}}$: {gain:.3f}",
864 fontsize=9,
865 color=color,
866 alpha=0.95,
867 ha="center",
868 va="center",
869 path_effects=outline_effect(2),
870 )
872 # Additional ax3 annotations:
873 # Label sources with numbers on the image plane.
874 for i, pos in enumerate(self.positions, start=1):
875 ax3.text(
876 *pos,
877 f"{i}",
878 fontsize=7,
879 color=sky_variance_color,
880 path_effects=outline_effect(1.5),
881 alpha=0.9,
882 )
884 # Now we use ax4 to plot the histograms of the variance planes for
885 # comparison.
886 # Plot the histogram of the original variance plane.
887 hist_values, bins, _ = ax4.hist(
888 arr1.flatten(),
889 bins=80,
890 histtype="step",
891 color=original_variance_color,
892 alpha=0.9,
893 label="Original variance",
894 )
895 # Fill the area under the step.
896 ax4.fill_between(
897 bins[:-1],
898 hist_values,
899 step="post",
900 color=original_variance_color,
901 alpha=0.09,
902 hatch="/////",
903 label=" ",
904 )
905 # Plot the histogram of the corrected variance plane.
906 ax4.hist(
907 arr2.flatten(),
908 bins=80,
909 histtype="bar",
910 color=corrected_variance_color,
911 alpha=0.9,
912 label="Corrected variance",
913 )
914 adjust_tick_scale(ax4, {"x": "Variance [{scale} ADU$^2$]", "y": "Number of pixels / {scale}"})
915 ax4.yaxis.set_label_position("right")
916 ax4.yaxis.tick_right()
917 ax4.axvline(
918 background_mean_variance_ADU,
919 color=sky_variance_color,
920 linestyle="--",
921 linewidth=1,
922 alpha=0.9,
923 label="Average sky variance\nacross all amps",
924 )
926 # Use colored arrowheads to mark true amp variances.
927 sorted_vars = sorted(amp_background_variance_ADU_list)
928 count = {v: 0 for v in sorted_vars}
929 for i, (x, name, gain, color) in enumerate(
930 zip(
931 amp_background_variance_ADU_list,
932 self.amp_name_list_simplified,
933 self.amp_gain_list,
934 amp_colors,
935 )
936 ):
937 arrowhead = arrowheads_ud[int(gain < average_gain)]
938 arrowhead_text = ax4.annotate(
939 arrowhead,
940 xy=(x, 0),
941 xycoords=("data", "axes fraction"),
942 textcoords="offset points",
943 xytext=(0, 0),
944 ha="center",
945 va="bottom",
946 fontsize=6.5,
947 color=color,
948 clip_on=False,
949 alpha=0.85,
950 path_effects=outline_effect(1.5),
951 )
952 if i == 0:
953 # Draw the canvas once to make sure the renderer is active.
954 fig.canvas.draw()
955 # Get the bounding box of the text annotation in axes
956 # fraction.
957 bbox_axes = arrowhead_text.get_window_extent().transformed(ax4.transAxes.inverted())
958 # Get the height of the text annotation in axes fraction.
959 height = bbox_axes.height
960 # Increment the arrowhead y positions to avoid overlap.
961 var_idxs = np.where(sorted_vars == x)[0]
962 if len(var_idxs) > 1:
963 q = count[x]
964 count[x] += 1
965 else:
966 q = 0
967 arrowhead_text.xy = (x, var_idxs[q] * height)
969 # Create a proxy artist for the legend since annotations are
970 # not shown in the legend.
971 label = "True variance of" if i == 0 else "$\u21AA$"
972 ax4.scatter(
973 [],
974 [],
975 color=color,
976 marker=arrowhead,
977 s=15,
978 label=f"{label} {name}",
979 alpha=0.85,
980 path_effects=outline_effect(1.5),
981 )
983 # Group the legend handles and label them.
984 adjust_legend_with_groups(
985 ax4,
986 [(0, 1), 2, 3, *range(4, 4 + len(self.amp_name_list_simplified))],
987 colors="match",
988 handlelength=1.9,
989 )
991 # Align the histogram (bottom right panel) with the colorbar of the
992 # corrected variance plane (top right panel) for aesthetic reasons.
993 pos2 = ax2.get_position()
994 pos4 = ax4.get_position()
995 fig.canvas.draw() # Render to ensure accurate colorbar width
996 cbar_width = cbar.ax.get_position().width
997 ax4.set_position([pos2.x0, pos4.y0, pos2.width + cbar_width, pos4.height])
999 # Increase all axes spines' linewidth by 20% for a bolder look.
1000 for ax in fig.get_axes():
1001 for spine in ax.spines.values():
1002 spine.set_linewidth(spine.get_linewidth() * 1.2)
1004 # Save the figure.
1005 filename = f"variance_plane_gain{average_gain}_sigma{gain_sigma_factor}_sky{sky_level}.png"
1006 fig.savefig(filename, dpi=300)
1007 print(f"Saved plot of variance plane before and after correction in {filename}")
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Memory/resource test case for this module.

    All test methods are inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`; nothing is added or overridden here.
    """
def setup_module(module):
    """Module-level setup hook (the xunit-style name pytest recognizes).

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up. It is accepted to satisfy the hook
        signature and is not used.
    """
    # Initialize the LSST test utilities once, before this module's tests.
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When the file is executed directly rather than collected by a test
    # runner, perform the same LSST test initialization that
    # `setup_module` does, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()