Coverage for tests/test_variance_plane.py: 8%

324 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 03:29 -0700

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""The utility function under test is part of the `meas_algorithms` package. 

23This unit test has been relocated here to avoid circular dependencies. 

24""" 

25 

26import re 

27import unittest 

28from contextlib import nullcontext 

29 

30import galsim 

31import lsst.utils.tests 

32import matplotlib.colors as mcolors 

33import matplotlib.pyplot as plt 

34import numpy as np 

35from lsst.ip.isr.isrMock import IsrMock 

36from lsst.meas.algorithms import remove_signal_from_variance 

37from lsst.utils.tests import methodParametersProduct 

38from matplotlib.legend_handler import HandlerTuple 

39from matplotlib.patheffects import withStroke 

40from matplotlib.ticker import FixedLocator, FuncFormatter 

41 

# Set to True to save the plot of the variance plane before and after
# correction for a representative test case. The flag is consulted in
# ``VariancePlaneTestCase.test_variance_signal_removal``, which additionally
# restricts plotting to a few parameter combinations (no averaging across
# amps, a sky level of 4e6, an average gain of 1.7, and no predefined gain).
SAVE_PLOT = False

45 

46 

def outline_effect(lw, alpha=0.8):
    """Build a white-outline path effect that makes text easier to read.

    Parameters
    ----------
    lw : `float`
        Line width of the outline stroke.
    alpha : `float`, optional
        Alpha (opacity) value applied to the outline stroke.

    Returns
    -------
    effects : `list` of `matplotlib.patheffects.withStroke`
        A single-element list holding the stroke effect.
    """
    stroke = withStroke(linewidth=lw, foreground="white", alpha=alpha)
    return [stroke]

63 

64 

class CustomHandler(HandlerTuple):
    """Legend handler for entries that group several handles together."""

    def create_artists(self, *args):
        """Create the grouped legend artists and re-apply the transform.

        All positional arguments are forwarded unchanged to the parent
        handler. The last positional argument is then set as the transform of
        every artist that the parent created.
        """
        created = super().create_artists(*args)
        if created:
            # The final positional argument is used as the transform.
            *_, trans = args
            for artist in created:
                artist.set_transform(trans)
        return created

73 

74 

def get_valid_color(handle):
    """Extract a usable color from a Matplotlib handle.

    The face color, edge color, and plain color getters are tried in that
    order; the first one that yields a color which is not fully transparent
    wins.

    Parameters
    ----------
    handle : `matplotlib.artist.Artist`
        The handle from which to extract the color.

    Returns
    -------
    color : `str` or `tuple` or `numpy.ndarray`
        The color extracted from the handle, or 'default' if no valid color is
        found.
    """
    for attr in ("get_facecolor", "get_edgecolor", "get_color"):
        if not hasattr(handle, attr):
            continue
        color = getattr(handle, attr)()
        if color is None:
            # Nothing usable from this getter; try the next one.
            continue
        if isinstance(color, np.ndarray):
            if color.size == 0:
                # A collection without any color of this kind.
                continue
            if color.ndim > 1:
                # A stack of colors (one row per element): use the first row.
                # A 1-D array is already a single color and is left intact.
                color = color[0]
        # If the color is RGBA with alpha = 0, continue the search. Strings
        # (named or hex colors) are never treated as RGBA sequences.
        if not isinstance(color, str) and len(color) == 4 and color[3] == 0:
            continue
        return color
    return "default"  # If no valid color is found

100 

101 

def get_emptier_side(ax):
    """Determine which horizontal half of an Axes is less occupied.

    The bounding boxes of all visible child artists are accumulated as
    occupied area on the left and right of the Axes' horizontal midpoint. A
    bounding box that straddles the midpoint contributes to both sides in
    proportion to how much of its width lies on each side.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to analyze.

    Returns
    -------
    more_whitespace_side : `str`
        'left' if the left side has more whitespace (or on a tie), or 'right'
        if the right side does.
    """
    to_inches = ax.figure.dpi_scale_trans.inverted()

    # The midpoint must live in the same coordinate system as the artists'
    # bounding boxes. Those come from ``get_window_extent`` (display space,
    # converted to inches here), so the midpoint is taken from the Axes' own
    # window extent rather than from the data-space x-limits.
    axes_bbox = ax.get_window_extent().transformed(to_inches)
    midpoint = (axes_bbox.x0 + axes_bbox.x1) / 2

    left_area, right_area = 0, 0

    for artist in ax.get_children():
        # Skip if artist is invisible or lacks a bounding box.
        if not artist.get_visible() or not hasattr(artist, "get_window_extent"):
            continue
        bbox = artist.get_window_extent().transformed(to_inches)
        area = bbox.width * bbox.height
        if bbox.x0 < midpoint < bbox.x1:
            # The bbox spans the midpoint: split its area between both sides.
            left_proportion = (midpoint - bbox.x0) / bbox.width
            left_area += area * left_proportion
            right_area += area * (1 - left_proportion)
        elif bbox.x0 + bbox.width / 2 < midpoint:
            # Entirely on the left.
            left_area += area
        else:
            # Entirely on the right.
            right_area += area

    # The side with the smaller occupied area has more whitespace.
    return "left" if left_area <= right_area else "right"

149 

150 

def adjust_legend_with_groups(ax, combine_groups, colors="default", yloc="upper", **kwargs):
    """Adjusts the legend of a given Axes object by combining specified handles
    based on provided groups, setting the marker location and text alignment
    based on the inferable emptier side, and optionally setting the text color
    of legend entries to a provided list or inferring colors from the handles.
    Additionally, allows specifying the vertical location of the legend within
    the plot.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object for which to adjust the legend.
    combine_groups : `list` of `int` or iterable of `int`
        A list that can contain a mix of individual integers and/or iterables
        (lists, tuples, or sets) of integers. An individual integer specifies
        the index of a single legend entry. An iterable of integers specifies
        a group of indices to be combined into a single legend entry.
    colors : `list` of `str` or `tuple`, or `str`, optional
        Specifies the colors for the legend entries. This parameter can be:
        - A list of color specifications, where each element is a string (for
          named colors or hex values) or a tuple (for RGB or RGBA values). This
          list explicitly assigns colors to each legend entry post-combination.
        - A single string value:
            - "match": Colors are inferred from the properties of the first
              handle in each group that corresponds to a non-white-space label.
              This aims to match the legend text color with the color of the
              plotted data.
            - "default": The function does not alter the default colors
              assigned by Matplotlib, preserving the automatic color assignment
              for all legend entries.
    yloc : `str`, optional
        The vertical location of the legend within the Axes. Valid options are
        'upper', 'lower', or 'middle' ('center' is accepted as well). This
        parameter is combined with the inferable emptier side ('left' or
        'right') to determine the legend's placement. For example,
        'upper right' or 'lower left'.
    **kwargs :
        Keyword arguments forwarded to the ``ax.legend`` function.

    Raises
    ------
    ValueError
        Raised if an element of ``combine_groups`` is neither an `int` nor a
        `list`, `tuple`, or `set`.
    """

    handles, labels = ax.get_legend_handles_labels()
    new_handles = []
    new_labels = []

    # Compare against the keywords only when 'colors' really is a string, so
    # that array-like color lists never take part in an element-wise ``==``.
    infer_colors = isinstance(colors, str) and colors == "match"
    if infer_colors:
        colors = []

    for group in combine_groups:
        # Assume the first non-white-space label represents the group. If no
        # such label is found, just use the first label in the group which is
        # in fact empty.
        if isinstance(group, (list, tuple, set)):
            group = list(group)  # Just in case
            label_index = next((i for i in group if labels[i].strip()), group[0])
            combined_handle = tuple(handles[i] for i in group)
            combined_label = labels[label_index]
        elif isinstance(group, int):
            label_index = group
            combined_handle = handles[group]
            combined_label = labels[group]
        else:
            raise ValueError("Invalid value in 'combine_groups'")
        new_handles.append(combined_handle)
        new_labels.append(combined_label)
        if infer_colors:
            # Attempt to infer color from the representative handle in the
            # group.
            colors.append(get_valid_color(handles[label_index]))

    # Determine the emptier side to decide legend and text alignment.
    emptier_side = get_emptier_side(ax)
    markerfirst = emptier_side != "right"

    # Matplotlib names the vertical middle "center" (e.g. 'center left');
    # "middle left" is not a valid legend location, so translate it.
    vertical = "center" if yloc == "middle" else yloc

    # Create the legend with custom adjustments.
    legend = ax.legend(
        new_handles,
        new_labels,
        handler_map={tuple: CustomHandler()},
        loc=f"{vertical} {emptier_side}",
        fontsize=8,
        frameon=False,
        markerfirst=markerfirst,
        **kwargs,
    )

    # Right- or left-align the legend text based on the emptier side.
    for text in legend.get_texts():
        text.set_ha(emptier_side)

    # Set legend text colors if necessary.
    if not (isinstance(colors, str) and colors == "default"):
        for text, color in zip(legend.get_texts(), colors):
            if not (isinstance(color, str) and color == "default"):
                text.set_color(color)

249 

250 

def adjust_tick_scale(ax, axis_label_templates):
    """Scales down tick labels to make them more readable and updates axis
    labels accordingly.

    Calculates a power of 10 scale factor (common divisor) to reduce the
    magnitude of tick labels. It automatically determines which axes to adjust
    based on the provided axis label templates, which should include `{scale}`
    for inserting the scale factor dynamically.

    The divisor is derived from the text of the current tick labels. Blank
    labels are ignored, and if no non-zero label is available the ticks are
    left unscaled.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to modify.
    axis_label_templates : `dict`
        Templates for axis labels, including '{scale}' for scale factor
        insertion. Keys should be one or more of axes names ('x', 'y', 'z') and
        values should be the corresponding label templates.
    """

    def parse_label(text):
        """Converts the text of a tick label to a float."""
        # Matplotlib renders negative tick labels with a Unicode minus sign
        # (U+2212) by default, which ``float`` does not accept.
        return float(text.replace("\u2212", "-"))

    def trailing_zeros(value):
        """Counts the trailing zeros of an integer-valued float.

        Non-integer values have no power-of-10 divisor in common with
        integers, so they count as zero trailing zeros.
        """
        if not value.is_integer():
            return 0
        digits = str(abs(int(value)))
        return len(digits) - len(digits.rstrip("0"))

    def format_tick(val, pos, divisor):
        """Formats tick labels using the determined divisor."""
        scaled = val / divisor
        return str(int(scaled)) if scaled.is_integer() else str(scaled)

    # Iterate through the specified axes and adjust their tick labels and axis
    # labels.
    for axis, label_template in axis_label_templates.items():
        axis_obj = getattr(ax, f"{axis}axis")

        # Gather the numeric values of the current tick labels.
        texts = [label.get_text() for label in getattr(ax, f"get_{axis}ticklabels")()]
        values = [parse_label(text) for text in texts if text.strip()]

        # Calculate the power of 10 divisor based on the minimum number of
        # trailing zeros in the non-zero tick labels.
        zero_counts = [trailing_zeros(value) for value in values if value != 0]
        divisor = 10 ** min(zero_counts, default=0)

        # Set a formatter for the axis ticks that scales them according to the
        # common divisor. The divisor is bound as a default argument so that
        # each axis keeps its own value; a plain closure would see whatever
        # the last loop iteration assigned by the time the figure is drawn.
        axis_obj.set_major_formatter(
            FuncFormatter(lambda val, pos, divisor=divisor: format_tick(val, pos, divisor))
        )

        # Ensure the tick positions remain unchanged despite the new
        # formatting.
        axis_obj.set_major_locator(FixedLocator(getattr(ax, f"get_{axis}ticks")()))

        # Prepare 'scale', empty if divisor <= 1.
        scale = f"{int(divisor)}" if divisor > 1 else ""

        # If 'scale' is empty, remove whitespace around "{scale}" in the
        # template. Also remove any trailing "/{scale}".
        if scale == "":
            label_template = re.sub(r"\s*{\s*scale\s*}\s*", "{scale}", label_template)
            label_template = label_template.replace("/{scale}", "")

        # Always strip remaining whitespace from the template.
        label_template = label_template.strip()

        # Set the formatted axis label.
        label_text = label_template.format(scale=scale)
        getattr(ax, f"set_{axis}label")(label_text, labelpad=8)

316 

317 

318class VariancePlaneTestCase(lsst.utils.tests.TestCase): 

def setUp(self):
    """Create the `IsrMock` instance that the tests draw exposures from."""
    # Testing with a single detector that has 8 amplifiers in a 4x2
    # configuration. Each amplifier measures 100x51 in dimensions.
    settings = {
        "isLsstLike": True,
        "doAddBias": False,
        "doAddDark": False,
        "doAddFlat": False,
        "doAddFringe": False,
        "doGenerateImage": True,
        "doGenerateData": True,
        "doGenerateAmpDict": True,
    }
    config = IsrMock.ConfigClass()
    for field_name, value in settings.items():
        setattr(config, field_name, value)
    self.mock = IsrMock(config=config)

332 

def tearDown(self):
    """Drop the mock that was created in ``setUp``."""
    delattr(self, "mock")

335 

def buildExposure(
    self,
    average_gain,
    gain_sigma_factor,
    sky_level,
    add_signal=True,
):
    """Simulate an exposure with a uniform sky background and, optionally,
    a mix of source profiles, together with a signal-polluted variance
    plane. Meant for testing under controlled conditions.

    As a side effect, the following attributes are set on the test case:
    ``amp_name_list``, ``amp_name_list_simplified``, ``background``,
    ``amp_bbox_list``, ``amp_gain_list`` and, only when ``add_signal`` is
    true, ``positions``.

    Parameters
    ----------
    average_gain : `float`
        The average gain value of amplifiers in e-/ADU.
    gain_sigma_factor : `float`
        The standard deviation of the gain values as a factor of the
        ``average_gain``.
    sky_level : `float`
        The background sky level in e-/arcsec^2.
    add_signal : `bool`, optional
        Whether to draw the simulated sources onto the image. If false,
        the image only receives the background noise.

    Returns
    -------
    exposure : `~lsst.afw.image.Exposure`
        An exposure object with simulated sources and background. The units
        are in detector counts (ADU).
    """

    def to_galsim_bounds(box):
        """Convert a bounding box with min/max attributes to galsim bounds."""
        return galsim.BoundsI(box.minX, box.maxX, box.minY, box.maxY)

    # Fix the seeds of both random number generators for reproducibility.
    seed = galsim.BaseDeviate(1905).raw() + 1
    np.random.seed(seed)
    rng = galsim.BaseDeviate(seed)

    # Pull the exposure out of the mock and collect the amplifier names.
    exposure = self.mock.getExposure()
    amps = exposure.getDetector().getAmplifiers()
    num_amps = len(amps)
    self.amp_name_list = [amp.getName() for amp in amps]
    # Simplified names have every ':' and ',' removed.
    self.amp_name_list_simplified = [
        name.replace(":", "").replace(",", "") for name in self.amp_name_list
    ]

    # Nominal instrument and observation parameters.
    pixel_scale = 0.2  # arcsec/pixel
    self.background = sky_level * pixel_scale**2  # e-/pixel

    # Bounding boxes of the exposure and of each amplifier, along with
    # their galsim counterparts.
    exp_bbox = exposure.getBBox()
    image_bounds = to_galsim_bounds(exp_bbox)
    self.amp_bbox_list = [amp.getBBox() for amp in amps]
    amp_bounds_list = [to_galsim_bounds(box) for box in self.amp_bbox_list]

    # Draw normally distributed values and shift them to have zero mean,
    # so that the per-amplifier gains differ slightly from each other
    # while averaging to `average_gain`. This is to test the
    # `average_across_amps` option in the `remove_signal_from_variance`
    # function.
    deviations = np.random.normal(average_gain, gain_sigma_factor * average_gain, size=num_amps)
    deviations = deviations - np.mean(deviations)
    self.amp_gain_list = list(average_gain + deviations)

    # Wrap the exposure's image array in a galsim image to draw onto. Per
    # the original author's note, the exposure image is modified in place
    # through this wrapper.
    image = galsim.ImageF(exposure.image.array, bounds=image_bounds)

    if add_signal:
        # A mix of source types: extended sources with assorted profiles
        # plus point-like sources with minimal half-light radii that
        # resemble hot pixels post-deconvolution. All flux values are
        # given in electrons and half-light radii in pixels. The goal is
        # for each amplifier to predominantly contain at least one
        # source, enhancing the representativeness of test conditions.
        source_params = [
            {"type": "Sersic", "n": 3, "flux": 1.6e5, "half_light_radius": 3.5, "g1": -0.3, "g2": 0.2},
            {"type": "Sersic", "n": 1, "flux": 9.3e5, "half_light_radius": 2.1, "g1": 0.25, "g2": 0.12},
            {"type": "Sersic", "n": 4, "flux": 1.0e5, "half_light_radius": 1.1, "g1": 0.0, "g2": 0.0},
            {"type": "Sersic", "n": 3, "flux": 1.1e6, "half_light_radius": 4.2, "g1": 0.0, "g2": 0.2},
            {"type": "Sersic", "n": 5, "flux": 1.1e5, "half_light_radius": 3.6, "g1": 0.22, "g2": -0.05},
            {"type": "Sersic", "n": 2, "flux": 4.3e5, "half_light_radius": 2.0, "g1": 0.0, "g2": 0.0},
            {"type": "Sersic", "n": 6, "flux": 1.2e6, "half_light_radius": 11.0, "g1": -0.16, "g2": 0.7},
            {"type": "Exponential", "flux": 1.3e6, "half_light_radius": 1.9, "g1": 0.3, "g2": -0.1},
            {"type": "Exponential", "flux": 1.8e6, "half_light_radius": 5.0, "g1": 0.0, "g2": 0.14},
            {"type": "Exponential", "flux": 6.6e6, "half_light_radius": 4.8, "g1": 0.26, "g2": 0.5},
            {"type": "Exponential", "flux": 7.0e5, "half_light_radius": 3.1, "g1": -0.3, "g2": 0.0},
            {"type": "DeVaucouleurs", "flux": 1.6e5, "half_light_radius": 3.5, "g1": 0.2, "g2": 0.4},
            {"type": "DeVaucouleurs", "flux": 2.0e5, "half_light_radius": 1.6, "g1": -0.06, "g2": -0.2},
            {"type": "DeVaucouleurs", "flux": 8.3e5, "half_light_radius": 5.1, "g1": 0.29, "g2": 0.0},
            {"type": "DeVaucouleurs", "flux": 4.5e5, "half_light_radius": 2.5, "g1": 0.4, "g2": 0.3},
            {"type": "DeVaucouleurs", "flux": 6.2e5, "half_light_radius": 4.9, "g1": -0.08, "g2": -0.01},
            {"type": "Gaussian", "flux": 4.7e6, "half_light_radius": 2.5, "g1": 0.07, "g2": -0.35},
            {"type": "Gaussian", "flux": 5.8e6, "half_light_radius": 3.1, "g1": 0.03, "g2": 0.4},
            {"type": "Gaussian", "flux": 2.3e5, "half_light_radius": 0.5, "g1": 0.0, "g2": 0.0},
            {"type": "Gaussian", "flux": 1.6e6, "half_light_radius": 3.0, "g1": 0.18, "g2": -0.29},
            {"type": "Gaussian", "flux": 3.5e5, "half_light_radius": 4.6, "g1": 0.5, "g2": 0.35},
            {"type": "Gaussian", "flux": 5.9e5, "half_light_radius": 9.5, "g1": 0.1, "g2": 0.55},
            {"type": "Gaussian", "flux": 4.0e5, "half_light_radius": 1.0, "g1": 0.0, "g2": 0.0},
        ]

        # Profile type name -> galsim constructor.
        profile_constructors = {
            "Sersic": galsim.Sersic,
            "Exponential": galsim.Exponential,
            "DeVaucouleurs": galsim.DeVaucouleurs,
            "Gaussian": galsim.Gaussian,
        }

        # One random position per source inside the exposure bounds,
        # keeping a margin from the edges.
        margin_x, margin_y = 0.05 * exp_bbox.width, 0.05 * exp_bbox.height
        lower_corner = [exp_bbox.minX + margin_x, exp_bbox.minY + margin_y]
        upper_corner = [exp_bbox.maxX - margin_x, exp_bbox.maxY - margin_y]
        self.positions = np.random.uniform(
            lower_corner,
            upper_corner,
            (len(source_params), 2),
        ).tolist()

        # Draw the sources onto the image, one cutout at a time.
        for params, (x, y) in zip(source_params, self.positions):
            # Split off the entries that are not constructor arguments:
            # the profile type and the two shear components.
            constructor = profile_constructors[params.pop("type")]
            g1 = params.pop("g1")
            g2 = params.pop("g2")

            # The extent of the cutout should be large enough to contain
            # the entire object above the background level. Some empirical
            # factor is used to mitigate artifacts.
            half_extent = 10 * params["half_light_radius"] * (1 + 2 * np.sqrt(g1**2 + g2**2))

            # Build the profile from the remaining params and shear it.
            galsim_object = constructor(**params).shear(galsim.Shear(g1=g1, g2=g2))

            pos = galsim.PositionD(x, y)

            # Cutout bounds centered on the object, clipped to the image;
            # the unclipped box may extend beyond the image bounds.
            cutout_bounds = galsim.BoundsI(
                int(x - half_extent),
                int(x + half_extent),
                int(y - half_extent),
                int(y + half_extent),
            )
            cutout_bounds = cutout_bounds & image.bounds

            # There must be some overlap with the image.
            assert cutout_bounds.isDefined(), "No overlap with image bounds"

            sub_image = image[cutout_bounds]

            # Draw the object onto the image within the sub-image bounds.
            galsim_object.drawImage(
                image=sub_image,
                offset=pos - sub_image.true_center,
                method="real_space",  # Save memory, usable w/o convolution
                add_to_image=True,  # Add flux to existing image
                scale=pixel_scale,
            )

    # Add a constant background to the entire image (both in e-/pixel).
    image += self.background

    # Add noise to the image which is in electrons. No `sky_level` is
    # given here to avoid double-counting it, as it's already included as
    # the background.
    image.addNoise(galsim.PoissonNoise(rng))

    # Subtract off the background to get the sky-subtracted image.
    image -= self.background

    # Divide each amplifier's image segment by its respective gain. After
    # this step, the image will be in ADUs.
    for amp_bounds, amp_gain in zip(amp_bounds_list, self.amp_gain_list):
        image[amp_bounds] /= amp_gain

    # The exposure is expected to have been modified in place already;
    # the image is nevertheless set explicitly to be extra sure.
    exposure.image.array = image.array

    # Create a variance plane for the exposure while including signal as a
    # pollutant. Note that the exposure image is pre-adjusted for gain,
    # unlike 'self.background'. Thus, the background is divided by the
    # corresponding gain before being added to the image. This leads to
    # the variance plane being in units of ADU^2.
    for amp_bbox, amp_gain in zip(self.amp_bbox_list, self.amp_gain_list):
        exposure.variance[amp_bbox].array = (
            exposure.image[amp_bbox].array + self.background / amp_gain
        ) / amp_gain

    return exposure

534 

def test_no_signal_handling(self):
    """Check that an image without signal leaves the variance plane
    nearly untouched.
    """
    # Build an exposure that contains no sources.
    build_kwargs = {
        "average_gain": 1.4,
        "gain_sigma_factor": 0.01,
        "sky_level": 4e6,
        "add_signal": False,
    }
    exposure = self.buildExposure(**build_kwargs)
    # Run the correction on a copy; there should be no signal to remove.
    updated_variance = remove_signal_from_variance(exposure, in_place=False)
    # The corrected plane must nearly match the original one.
    original_array = exposure.variance.array
    updated_array = updated_variance.array
    self.assertFloatsAlmostEqual(original_array, updated_array, rtol=0.013)

547 

def test_in_place_handling(self):
    """Check that the in-place mode updates the exposure's own variance
    plane.
    """
    # Build an exposure that contains sources.
    build_kwargs = {
        "average_gain": 1.4,
        "gain_sigma_factor": 0.01,
        "sky_level": 4e6,
        "add_signal": True,
    }
    exposure = self.buildExposure(**build_kwargs)
    # Remove the signal from the variance plane, modifying the exposure.
    updated_variance = remove_signal_from_variance(exposure, in_place=True)
    # The plane stored on the exposure must be identical to the returned
    # one.
    exposure_array = exposure.variance.array
    returned_array = updated_variance.array
    self.assertFloatsEqual(exposure_array, returned_array)

559 

560 @methodParametersProduct( 

561 average_gain=[1.4, 1.7], 

562 predefined_gain_type=["average", "per-amp", None], 

563 gain_sigma_factor=[0, 0.008], 

564 sky_level=[2e6, 4e6], 

565 average_across_amps=[False, True], 

566 ) 

567 def test_variance_signal_removal( 

568 self, average_gain, predefined_gain_type, gain_sigma_factor, sky_level, average_across_amps 

569 ): 

570 exposure = self.buildExposure( 

571 average_gain=average_gain, 

572 gain_sigma_factor=gain_sigma_factor, 

573 sky_level=sky_level, 

574 add_signal=True, 

575 ) 

576 

577 # Save the original variance plane for comparison, assuming it has 

578 # Poisson contribution from the source signal. 

579 signal_polluted_variance = exposure.variance.clone() 

580 

581 # Check that the variance plane has no negative values. 

582 self.assertTrue( 

583 np.all(signal_polluted_variance.array >= 0), 

584 "Variance plane has negative values (pre correction)", 

585 ) 

586 

587 if predefined_gain_type == "average": 

588 predefined_gain = average_gain 

589 predefined_gains = None 

590 elif predefined_gain_type == "per-amp": 

591 predefined_gain = None 

592 predefined_gains = {name: gain for name, gain in zip(self.amp_name_list, self.amp_gain_list)} 

593 elif predefined_gain_type is None: 

594 # Allow the 'remove_signal_from_variance' function to estimate the 

595 # gain itself before it attempts to remove the signal from the 

596 # variance plane. 

597 predefined_gain = None 

598 predefined_gains = None 

599 

600 # Set the relative tolerance for the variance plane checks. 

601 if predefined_gain_type == "average" or (predefined_gain_type is None and average_across_amps): 

602 # Relax the tolerance if we are simply averaging across amps to 

603 # roughly estimate the overall gain. 

604 rtol = 0.015 

605 estimate_average_gain = True 

606 else: 

607 # Tighten tolerance for the 'predefined_gain_type' of 'per-amp' or 

608 # for a more accurate per-amp gain estimation strategy. 

609 rtol = 2e-7 

610 estimate_average_gain = False 

611 

612 # Remove the signal from the variance plane. 

613 signal_free_variance = remove_signal_from_variance( 

614 exposure, 

615 gain=predefined_gain, 

616 gains=predefined_gains, 

617 average_across_amps=average_across_amps, 

618 in_place=False, 

619 ) 

620 

621 # Check that the variance plane has been modified. 

622 self.assertFloatsNotEqual(signal_polluted_variance.array, signal_free_variance.array) 

623 

624 # Check that the corrected variance plane has no negative values. 

625 self.assertTrue( 

626 np.all(signal_free_variance.array >= 0), "Variance plane has negative values (post correction)" 

627 ) 

628 

629 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list): 

630 # Calculate the true variance in theoretical terms. 

631 true_var_amp = self.background / gain**2 

632 # Pair each variance with the appropriate context manager before 

633 # looping through them. 

634 var_context_pairs = [ 

635 # For the signal-free variance, directly execute the checks. 

636 (signal_free_variance, nullcontext()), 

637 # For the signal-polluted variance, expect AssertionError 

638 # unless we are averaging across amps. 

639 ( 

640 signal_polluted_variance, 

641 nullcontext() if estimate_average_gain else self.assertRaises(AssertionError), 

642 ), 

643 ] 

644 for var, context_manager in var_context_pairs: 

645 # Extract the segment of the variance plane for the amplifier. 

646 var_amp = var[bbox] 

647 with context_manager: 

648 if var is signal_polluted_variance and estimate_average_gain: 

649 # Skip rigorous checks on the signal-polluted variance, 

650 # if we are averaging across amps. 

651 pass 

652 else: 

653 # Get the variance value at the first pixel of the 

654 # segment to compare with the rest of the pixels and 

655 # the true variance. 

656 v00 = var_amp.array[0, 0] 

657 # Assert that the variance plane is almost uniform 

658 # across the segment because the signal has been 

659 # removed from it and the background is constant. 

660 self.assertFloatsAlmostEqual(var_amp.array, v00, rtol=rtol) 

661 # Assert that the variance plane is almost equal to the 

662 # true variance across the segment. 

663 self.assertFloatsAlmostEqual(v00, true_var_amp, rtol=rtol) 

664 

665 if ( 

666 SAVE_PLOT 

667 and not average_across_amps 

668 and gain_sigma_factor in (0, 0.008) 

669 and sky_level == 4e6 

670 and average_gain == 1.7 

671 and predefined_gain_type is None 

672 ): 

673 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8.5)) 

674 plt.subplots_adjust(wspace=0.17, hspace=0.17) 

675 colorbar_aspect = 12 

676 

677 amp_background_variance_ADU_list = [self.background / gain**2 for gain in self.amp_gain_list] 

678 amp_background_image_ADU_list = [self.background / gain for gain in self.amp_gain_list] 

679 # Calculate the mean value that corresponds to the background for 

680 # the variance plane, adjusting for the gain. 

681 background_mean_variance_ADU = np.mean( 

682 [self.background / gain**2 for gain in self.amp_gain_list] 

683 ) 

684 

685 # Extract the variance planes and the image from the exposure. 

686 arr1 = signal_polluted_variance.array # Variance with signal 

687 arr2 = signal_free_variance.array # Variance without signal 

688 exp_im = exposure.image.clone() # Clone of the image plane 

689 

690 # Incorporate the gain-adjusted background into the image plane to 

691 # enable combined visualization of sources with the background. 

692 for gain, bbox in zip(self.amp_gain_list, self.amp_bbox_list): 

693 exp_im[bbox].array += self.background / gain 

694 arr3 = exp_im.array 

695 

696 # Define colors visually distinct from each other for the subplots. 

697 original_variance_color = "#8A2BE2" # Periwinkle 

698 corrected_variance_color = "#618B3C" # Lush Forest Green 

699 sky_variance_color = "#c3423f" # Crimson Red 

700 amp_colors = [ 

701 "#1f77b4", # Muted Blue 

702 "#ff7f0e", # Vivid Orange 

703 "#2ca02c", # Kelly Green 

704 "#d62728", # Brick Red 

705 "#9467bd", # Soft Purple 

706 "#8B4513", # Saddle Brown 

707 "#e377c2", # Pale Violet Red 

708 "#202020", # Onyx 

709 ] 

710 arrowheads_lr = ["$\u25C0$", "$\u25B6$"] # Left- & right-pointing 

711 arrowheads_ud = ["$\u25B2$", "$\u25BC$"] # Up- & down-pointing 

712 

713 # Set titles for the subplots. 

714 ax1.set_title("Original variance plane", color=original_variance_color) 

715 ax2.set_title("Corrected variance plane", color=corrected_variance_color) 

716 ax3.set_title("Image + background ($\\mathit{uniform}$)") 

717 ax4.set_title("Histogram of variances") 

718 

719 # Collect all vertical and horizontal line positions to find the 

720 # amp boundaries. 

721 vlines, hlines = set(), set() 

722 for bbox in self.amp_bbox_list: 

723 # Adjst by 0.5 for merging of lines at the boundaries. 

724 vlines.update({bbox.minX - 0.5, bbox.maxX + 0.5}) 

725 hlines.update({bbox.minY - 0.5, bbox.maxY + 0.5}) 

726 

727 # Filter lines at the edges of the overall image bbox. 

728 image_bbox = exposure.getBBox() 

729 vlines = {x for x in vlines if image_bbox.minX < x < image_bbox.maxX} 

730 hlines = {y for y in hlines if image_bbox.minY < y < image_bbox.maxY} 

731 

732 # Plot image and variance planes. 

733 for plane, arr, ax in zip( 

734 ("variance", "variance_corrected", "image"), (arr1, arr2, arr3), (ax1, ax2, ax3) 

735 ): 

736 # We skip 'variance_corrected' in the loop below because we use 

737 # the same normalization and colormap as 'variance' for it. 

738 if plane in ["variance", "image"]: 

739 # Get the normalization. 

740 vmin, vmax = arr.min(), arr.max() 

741 norm = mcolors.Normalize(vmin=vmin, vmax=vmax) 

742 

743 # Get the thresholds corresponding to per-amp backgrounds 

744 # and their positions in the normalized color scale. 

745 thresholds = ( 

746 amp_background_variance_ADU_list 

747 if plane.startswith("variance") 

748 else amp_background_image_ADU_list 

749 ) 

750 threshold_positions = [norm(t) for t in thresholds] 

751 threshold = np.mean(thresholds) 

752 threshold_position = np.mean(threshold_positions) 

753 

754 # Create a custom colormap with two distinct colors for the 

755 # sky and source contributions. 

756 border = (threshold - vmin) / (vmax - vmin) 

757 colors1 = plt.cm.Purples_r(np.linspace(0, 1, int(border * 256))) 

758 colors2 = plt.cm.Greens(np.linspace(0, 1, int((1 - border) * 256))) 

759 colors = np.vstack((colors1, colors2)) 

760 cmap = mcolors.LinearSegmentedColormap.from_list("cmap", colors) 

761 

762 # Plot the array with the custom colormap and normalization. 

763 im = ax.imshow(arr, cmap=cmap, norm=norm) 

764 

765 # Add colorbars to the plot. 

766 cbar = fig.colorbar(im, aspect=colorbar_aspect, pad=0) 

767 

768 # Change the number of ticks on the colorbar for better 

769 # spacing. Needs to be done before modifying the tick labels. 

770 cbar.ax.locator_params(nbins=7) 

771 

772 # Enhance readability by scaling down colorbar tick labels. 

773 unit = "ADU$^2$" if plane.startswith("variance") else "ADU" 

774 adjust_tick_scale(cbar.ax, {"y": f"Value [{{scale}} {unit}]"}) 

775 

776 # Mark min and max per-amp thresholds with dotted lines on the 

777 # colorbar. 

778 for tp in [min(thresholds), max(thresholds)]: 

779 cbar.ax.axhline(tp, color="white", linestyle="-", linewidth=1.5, alpha=0.4) 

780 cbar.ax.axhline(tp, color=sky_variance_color, linestyle=":", linewidth=1.5, alpha=0.9) 

781 

782 # Mark mean threshold with facing arrowheads on the colorbar. 

783 cbar.ax.annotate( 

784 arrowheads_lr[1], # Right-pointing arrowhead 

785 xy=(0, threshold_position), 

786 xycoords="axes fraction", 

787 textcoords="offset points", 

788 xytext=(0, 0), 

789 ha="left", 

790 va="center", 

791 fontsize=6, 

792 color=sky_variance_color, 

793 clip_on=False, 

794 alpha=0.9, 

795 ) 

796 cbar.ax.annotate( 

797 arrowheads_lr[0], # Left-pointing arrowhead 

798 xy=(1, threshold_position), 

799 xycoords="axes fraction", 

800 textcoords="offset points", 

801 xytext=(0, 0), 

802 ha="right", 

803 va="center", 

804 fontsize=6, 

805 color=sky_variance_color, 

806 clip_on=False, 

807 alpha=0.9, 

808 ) 

809 

810 # Add text inside the colorbar to label the average threshold 

811 # position. 

812 sky_level_text = "$\u27E8$" + "Sky" + "$\u27E9$" # <Sky> 

813 sky_level_text_artist = cbar.ax.text( 

814 0.5, 

815 threshold_position, 

816 sky_level_text, 

817 va="center", 

818 ha="center", 

819 transform=cbar.ax.transAxes, 

820 fontsize=8, 

821 color=sky_variance_color, 

822 rotation="vertical", 

823 alpha=0.9, 

824 path_effects=outline_effect(2), 

825 ) 

826 

827 # Setup renderer and transformation. 

828 renderer = fig.canvas.get_renderer() 

829 transform = cbar.ax.transAxes.inverted() 

830 

831 # Transform the bounding box and calculate adjustment for 

832 # 'sky_level_text_artist' for when it goes beyond the colorbar. 

833 sky_level_text_bbox = sky_level_text_artist.get_window_extent(renderer).transformed(transform) 

834 adjustment = 1.4 * sky_level_text_bbox.height / 2 

835 

836 if sky_level_text_bbox.ymin < 0: 

837 sky_level_text_artist.set_y(adjustment) 

838 elif sky_level_text_bbox.ymax > 1: 

839 sky_level_text_artist.set_y(1 - adjustment) 

840 

841 # Draw amp boundaries as vertical and/or horizontal lines. 

842 line_color = "white" if np.mean(norm(arr)) > 0.5 else "#808080" 

843 for x in vlines: 

844 ax.axvline(x=x, color=line_color, linestyle="--", linewidth=1, alpha=0.7) 

845 for y in hlines: 

846 ax.axhline(y=y, color=line_color, linestyle="--", linewidth=1, alpha=0.7) 

847 # Hide all x and y tick marks. 

848 ax.tick_params(axis="both", which="both", bottom=False, top=False, left=False, right=False) 

849 # Hide all x and y tick labels. 

850 ax.set_xticklabels([]) 

851 ax.set_yticklabels([]) 

852 

853 # Additional ax2 annotations: 

854 # Labels amplifiers with their respective gains for a visual check. 

855 for bbox, name, gain, color in zip( 

856 self.amp_bbox_list, self.amp_name_list_simplified, self.amp_gain_list, amp_colors 

857 ): 

858 # Get the center of the bbox to label the gain value. 

859 bbox_center = (bbox.minX + bbox.maxX) / 2, (bbox.minY + bbox.maxY) / 2 

860 # Label the gain value at the center of each amplifier segment. 

861 ax2.text( 

862 *bbox_center, 

863 f"gain$_{{\\rm \\, {name} \\,}}$: {gain:.3f}", 

864 fontsize=9, 

865 color=color, 

866 alpha=0.95, 

867 ha="center", 

868 va="center", 

869 path_effects=outline_effect(2), 

870 ) 

871 

872 # Additional ax3 annotations: 

873 # Label sources with numbers on the image plane. 

874 for i, pos in enumerate(self.positions, start=1): 

875 ax3.text( 

876 *pos, 

877 f"{i}", 

878 fontsize=7, 

879 color=sky_variance_color, 

880 path_effects=outline_effect(1.5), 

881 alpha=0.9, 

882 ) 

883 

884 # Now we use ax4 to plot the histograms of the variance planes for 

885 # comparison. 

886 # Plot the histogram of the original variance plane. 

887 hist_values, bins, _ = ax4.hist( 

888 arr1.flatten(), 

889 bins=80, 

890 histtype="step", 

891 color=original_variance_color, 

892 alpha=0.9, 

893 label="Original variance", 

894 ) 

895 # Fill the area under the step. 

896 ax4.fill_between( 

897 bins[:-1], 

898 hist_values, 

899 step="post", 

900 color=original_variance_color, 

901 alpha=0.09, 

902 hatch="/////", 

903 label=" ", 

904 ) 

905 # Plot the histogram of the corrected variance plane. 

906 ax4.hist( 

907 arr2.flatten(), 

908 bins=80, 

909 histtype="bar", 

910 color=corrected_variance_color, 

911 alpha=0.9, 

912 label="Corrected variance", 

913 ) 

914 adjust_tick_scale(ax4, {"x": "Variance [{scale} ADU$^2$]", "y": "Number of pixels / {scale}"}) 

915 ax4.yaxis.set_label_position("right") 

916 ax4.yaxis.tick_right() 

917 ax4.axvline( 

918 background_mean_variance_ADU, 

919 color=sky_variance_color, 

920 linestyle="--", 

921 linewidth=1, 

922 alpha=0.9, 

923 label="Average sky variance\nacross all amps", 

924 ) 

925 

926 # Use colored arrowheads to mark true amp variances. 

927 sorted_vars = sorted(amp_background_variance_ADU_list) 

928 count = {v: 0 for v in sorted_vars} 

929 for i, (x, name, gain, color) in enumerate( 

930 zip( 

931 amp_background_variance_ADU_list, 

932 self.amp_name_list_simplified, 

933 self.amp_gain_list, 

934 amp_colors, 

935 ) 

936 ): 

937 arrowhead = arrowheads_ud[int(gain < average_gain)] 

938 arrowhead_text = ax4.annotate( 

939 arrowhead, 

940 xy=(x, 0), 

941 xycoords=("data", "axes fraction"), 

942 textcoords="offset points", 

943 xytext=(0, 0), 

944 ha="center", 

945 va="bottom", 

946 fontsize=6.5, 

947 color=color, 

948 clip_on=False, 

949 alpha=0.85, 

950 path_effects=outline_effect(1.5), 

951 ) 

952 if i == 0: 

953 # Draw the canvas once to make sure the renderer is active. 

954 fig.canvas.draw() 

955 # Get the bounding box of the text annotation in axes 

956 # fraction. 

957 bbox_axes = arrowhead_text.get_window_extent().transformed(ax4.transAxes.inverted()) 

958 # Get the height of the text annotation in axes fraction. 

959 height = bbox_axes.height 

960 # Increment the arrowhead y positions to avoid overlap. 

961 var_idxs = np.where(sorted_vars == x)[0] 

962 if len(var_idxs) > 1: 

963 q = count[x] 

964 count[x] += 1 

965 else: 

966 q = 0 

967 arrowhead_text.xy = (x, var_idxs[q] * height) 

968 

969 # Create a proxy artist for the legend since annotations are 

970 # not shown in the legend. 

971 label = "True variance of" if i == 0 else "$\u21AA$" 

972 ax4.scatter( 

973 [], 

974 [], 

975 color=color, 

976 marker=arrowhead, 

977 s=15, 

978 label=f"{label} {name}", 

979 alpha=0.85, 

980 path_effects=outline_effect(1.5), 

981 ) 

982 

983 # Group the legend handles and label them. 

984 adjust_legend_with_groups( 

985 ax4, 

986 [(0, 1), 2, 3, *range(4, 4 + len(self.amp_name_list_simplified))], 

987 colors="match", 

988 handlelength=1.9, 

989 ) 

990 

991 # Align the histogram (bottom right panel) with the colorbar of the 

992 # corrected variance plane (top right panel) for aesthetic reasons. 

993 pos2 = ax2.get_position() 

994 pos4 = ax4.get_position() 

995 fig.canvas.draw() # Render to ensure accurate colorbar width 

996 cbar_width = cbar.ax.get_position().width 

997 ax4.set_position([pos2.x0, pos4.y0, pos2.width + cbar_width, pos4.height]) 

998 

999 # Increase all axes spines' linewidth by 20% for a bolder look. 

1000 for ax in fig.get_axes(): 

1001 for spine in ax.spines.values(): 

1002 spine.set_linewidth(spine.get_linewidth() * 1.2) 

1003 

1004 # Save the figure. 

1005 filename = f"variance_plane_gain{average_gain}_sigma{gain_sigma_factor}_sky{sky_level}.png" 

1006 fig.savefig(filename, dpi=300) 

1007 print(f"Saved plot of variance plane before and after correction in {filename}") 

1008 

1009 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Memory test case with no additions of its own.

    Everything this class does is inherited from
    ``lsst.utils.tests.MemoryTestCase``; the body is deliberately empty.

    NOTE(review): presumably this pulls in the standard LSST resource-leak
    checks -- confirm against the ``lsst.utils.tests`` documentation.
    """

1012 

1013 

def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module : module
        Unused by this function.
        NOTE(review): presumably supplied by the test runner's
        module-level setup hook -- confirm.
    """
    lsst.utils.tests.init()

1016 

1017 

if __name__ == "__main__":
    # Direct execution as a script: initialize the LSST test utilities
    # first, then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()