Coverage for tests/test_variance_plane.py: 8%

329 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 04:06 -0700

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""The utility function under test is part of the `meas_algorithms` package. 

23This unit test has been relocated here to avoid circular dependencies. 

24""" 

25 

26import re 

27import unittest 

28from contextlib import nullcontext 

29 

30import galsim 

31import lsst.utils.tests 

32import matplotlib.colors as mcolors 

33import matplotlib.pyplot as plt 

34import numpy as np 

35from lsst.ip.isr.isrMock import IsrMock 

36from lsst.meas.algorithms import remove_signal_from_variance 

37from lsst.utils.tests import methodParametersProduct 

38from matplotlib.legend_handler import HandlerTuple 

39from matplotlib.patheffects import withStroke 

40from matplotlib.ticker import FixedLocator, FuncFormatter 

41 

# Flip this flag to True to write out a figure that compares the variance
# plane before and after signal removal for one representative test case.
SAVE_PLOT = False

45 

46 

def outline_effect(lw, alpha=0.8):
    """Build a white outline stroke used to make text stand out.

    Parameters
    ----------
    lw : `float`
        Width of the white stroke drawn around the text.
    alpha : `float`, optional
        Opacity of the stroke.

    Returns
    -------
    effects : `list` of `matplotlib.patheffects.withStroke`
        A single-element list holding the stroke effect.
    """
    stroke = withStroke(linewidth=lw, foreground="white", alpha=alpha)
    return [stroke]

63 

64 

class CustomHandler(HandlerTuple):
    """Legend handler for tuples of artists that share one legend entry.

    Behaves like `HandlerTuple`, except that every artist it creates is
    assigned the transform passed as the last positional argument of
    ``create_artists``.
    """

    def create_artists(self, *args):
        created = super().create_artists(*args)
        transform = args[-1]
        for artist in created:
            artist.set_transform(transform)
        return created

73 

74 

def get_valid_color(handle):
    """Extract a usable color from a Matplotlib handle.

    The getters ``get_facecolor``, ``get_edgecolor`` and ``get_color`` are
    tried in that order. A getter is skipped when it:

    - is missing;
    - returns `None`;
    - returns an empty color (e.g. an empty collection);
    - returns the string ``"none"`` or an empty string;
    - returns a fully transparent RGBA value.

    Parameters
    ----------
    handle : `matplotlib.artist.Artist`
        The handle from which to extract the color.

    Returns
    -------
    color : `str`, `tuple` or `numpy.ndarray`
        The first usable color found, or ``"default"`` if none of the getters
        yields one.
    """
    for attr in ("get_facecolor", "get_edgecolor", "get_color"):
        if not hasattr(handle, attr):
            continue
        color = getattr(handle, attr)()
        if color is None:
            continue
        if isinstance(color, str):
            # "none" is Matplotlib's spelling of "no color"; an empty string
            # is not a color at all.
            if not color.strip() or color.lower() == "none":
                continue
            return color
        # Collections report an (N, 4) array of colors: use the first row, or
        # keep searching if the collection is empty.
        if isinstance(color, np.ndarray) and color.ndim > 1:
            if color.shape[0] == 0:
                continue
            color = color[0]
        if len(color) == 0:
            continue
        # An RGBA color with alpha = 0 is invisible; continue the search.
        if len(color) == 4 and color[3] == 0:
            continue
        return color
    return "default"  # No valid color found.

100 

101 

def get_emptier_side(ax):
    """Determine which horizontal half of an Axes has more whitespace.

    The visible children of the Axes are inspected. Each child's bounding-box
    area is assigned to the left or right half. A box that straddles the
    midpoint is split in proportion to how much of it lies on each side.

    All geometry is evaluated in display (pixel) coordinates. The midpoint is
    taken from the Axes' own window extent rather than from the data limits,
    because artist window extents are reported in display space.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to analyze.

    Returns
    -------
    more_whitespace_side : `str`
        'left' if the left side has more whitespace (or the two are tied),
        otherwise 'right'.
    """
    # Midpoint of the Axes area, in the same units as the artist extents.
    axes_bbox = ax.get_window_extent()
    midpoint = (axes_bbox.x0 + axes_bbox.x1) / 2

    left_area, right_area = 0.0, 0.0

    for artist in ax.get_children():
        # Skip if the artist is invisible or lacks a bounding box.
        if not artist.get_visible() or not hasattr(artist, "get_window_extent"):
            continue
        bbox = artist.get_window_extent()
        # Some artists report non-finite extents; they carry no usable area.
        if not np.all(np.isfinite([bbox.x0, bbox.x1, bbox.width, bbox.height])):
            continue
        area = bbox.width * bbox.height
        if bbox.x0 < midpoint < bbox.x1:
            # Split the area according to the fraction on each side.
            left_fraction = (midpoint - bbox.x0) / bbox.width
            left_area += area * left_fraction
            right_area += area * (1 - left_fraction)
        elif bbox.x0 + bbox.width / 2 < midpoint:
            left_area += area  # Entirely on the left.
        else:
            right_area += area  # Entirely on the right.

    # The side with less occupied area has more whitespace.
    return "left" if left_area <= right_area else "right"

149 

150 

def adjust_legend_with_groups(ax, combine_groups, colors="default", yloc="upper", **kwargs):
    """Adjust the legend of an Axes by combining handles into groups.

    Specified handles are combined into single legend entries. The legend is
    placed on the emptier horizontal side of the Axes, and its marker order
    and text alignment are set to match that side. Optionally, legend text
    colors are set from a provided list or inferred from the handles.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object for which to adjust the legend.
    combine_groups : `list` of `int` or iterable of `int`
        A list that can contain a mix of individual integers and/or iterables
        (lists, tuples, or sets) of integers. An individual integer (Python
        or numpy) specifies the index of a single legend entry. A non-empty
        iterable of integers specifies a group of indices to be combined into
        a single legend entry.
    colors : `list` of `str` or `tuple`, or `str`, optional
        Specifies the colors for the legend entries. This parameter can be:

        - A list of color specifications, where each element is a string
          (for named colors or hex values) or a tuple (for RGB or RGBA
          values). This list explicitly assigns colors to each legend entry
          after combination. An element equal to "default" leaves that
          entry unchanged.
        - A single string value:

          - "match": Colors are inferred from the handle that provides each
            group's label (the first with a non-whitespace label), so the
            legend text matches the color of the plotted data.
          - "default": The default colors assigned by Matplotlib are kept
            for all legend entries.
          - Any other string is treated as one color applied to every entry.
    yloc : `str`, optional
        The vertical location of the legend within the Axes: 'upper',
        'lower', or 'middle' ('center' is accepted as an alias of 'middle').
        It is combined with the emptier side ('left' or 'right') to form the
        Matplotlib ``loc``, e.g. 'upper right' or 'center left'.
    **kwargs :
        Keyword arguments forwarded to the ``ax.legend`` function.

    Raises
    ------
    ValueError
        Raised if ``combine_groups`` contains an empty group or an element
        that is neither an integer nor a list/tuple/set of integers, or if
        ``yloc`` is not a recognized vertical location.
    """
    # Matplotlib spells the vertical middle as "center" in ``loc`` strings.
    vertical_locations = {"upper": "upper", "lower": "lower", "middle": "center", "center": "center"}
    if yloc not in vertical_locations:
        raise ValueError(f"Invalid 'yloc' {yloc!r}; expected 'upper', 'lower' or 'middle'.")

    handles, labels = ax.get_legend_handles_labels()
    new_handles = []
    new_labels = []

    # Compare against keywords only when ``colors`` is a string; comparing
    # an array-like with a string is ambiguous.
    infer_colors = isinstance(colors, str) and colors == "match"
    inferred_colors = []

    for group in combine_groups:
        if isinstance(group, (list, tuple, set)):
            indices = list(group)
            if not indices:
                raise ValueError("Invalid value in 'combine_groups': empty group")
            # The first non-white-space label represents the group. If there
            # is none, fall back to the (empty) label of the first index.
            label_index = next((i for i in indices if labels[i].strip()), indices[0])
            combined_handle = tuple(handles[i] for i in indices)
        elif isinstance(group, (int, np.integer)):
            label_index = group
            combined_handle = handles[group]
        else:
            raise ValueError("Invalid value in 'combine_groups'")
        new_handles.append(combined_handle)
        new_labels.append(labels[label_index])
        if infer_colors:
            # Infer the color from the handle that provided the label.
            inferred_colors.append(get_valid_color(handles[label_index]))

    if infer_colors:
        colors = inferred_colors
    elif isinstance(colors, str) and colors != "default":
        # A single color string applies to every entry. Iterating it would
        # otherwise assign one character per entry.
        colors = [colors] * len(new_labels)

    # Determine the emptier side to decide legend placement and alignment.
    emptier_side = get_emptier_side(ax)
    markerfirst = emptier_side != "right"

    # Create the legend with custom adjustments.
    legend = ax.legend(
        new_handles,
        new_labels,
        handler_map={tuple: CustomHandler()},
        loc=f"{vertical_locations[yloc]} {emptier_side}",
        fontsize=8,
        frameon=False,
        markerfirst=markerfirst,
        **kwargs,
    )

    # Right- or left-align the legend text based on the emptier side.
    for text in legend.get_texts():
        text.set_ha(emptier_side)

    # Set legend text colors if necessary.
    if not (isinstance(colors, str) and colors == "default"):
        for text, color in zip(legend.get_texts(), colors):
            if not (isinstance(color, str) and color == "default"):
                text.set_color(color)

249 

250 

def adjust_tick_scale(ax, axis_label_templates):
    """Scale down tick labels to make them more readable and update axis
    labels accordingly.

    A power-of-10 divisor is computed for each requested axis. Its exponent
    is the smallest number of trailing zeros among the non-zero,
    integer-valued tick labels. Labels that cannot be parsed as numbers, and
    non-integer labels, contribute no trailing zeros. The ticks are then
    reformatted as value / divisor, the tick positions are frozen, and the
    axis label is set from its template with ``{scale}`` replaced by the
    divisor (or removed when the divisor is 1).

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to modify.
    axis_label_templates : `dict`
        Templates for axis labels, including '{scale}' for scale-factor
        insertion. Keys should be one or more axis names ('x', 'y', 'z') and
        values should be the corresponding label templates.
    """

    def parse_label(text):
        """Return the float value of a tick label, or None if unparsable."""
        # Matplotlib renders negative numbers with a Unicode minus sign.
        text = text.replace("\N{MINUS SIGN}", "-").strip()
        try:
            return float(text)
        except ValueError:
            return None

    def trailing_zeros(value):
        """Count trailing zeros of an integer-valued number (0 otherwise)."""
        if not value.is_integer():
            return 0
        digits = str(abs(int(value)))
        return len(digits) - len(digits.rstrip("0"))

    def format_tick(val, pos, divisor):
        """Format a tick value scaled by ``divisor``."""
        scaled = val / divisor
        return str(int(scaled)) if scaled.is_integer() else str(scaled)

    # Iterate through the specified axes and adjust their tick labels and
    # axis labels.
    for axis, label_template in axis_label_templates.items():
        # Gather the current tick labels as numbers.
        tick_texts = [label.get_text() for label in getattr(ax, f"get_{axis}ticklabels")()]
        values = [parse_label(text) for text in tick_texts]

        # Power-of-10 divisor from the minimum number of trailing zeros among
        # non-zero labels. If there are none, fall back to a divisor of 1.
        exponent = min((trailing_zeros(v) for v in values if v is not None and v != 0), default=0)
        divisor = 10**exponent

        axis_obj = getattr(ax, f"{axis}axis")
        # Bind ``divisor`` now via a default argument. A plain closure would
        # be evaluated at draw time and see the divisor of the last axis.
        axis_obj.set_major_formatter(
            FuncFormatter(lambda val, pos, divisor=divisor: format_tick(val, pos, divisor))
        )

        # Keep the tick positions unchanged despite the new formatting.
        axis_obj.set_major_locator(FixedLocator(getattr(ax, f"get_{axis}ticks")()))

        # 'scale' is empty when no scaling is applied.
        scale = f"{int(divisor)}" if divisor > 1 else ""

        # If 'scale' is empty, remove whitespace around "{scale}" in the
        # template and drop any "/{scale}".
        if scale == "":
            label_template = re.sub(r"\s*\{\s*scale\s*\}\s*", "{scale}", label_template)
            label_template = label_template.replace("/{scale}", "")

        # Always strip remaining whitespace from the template.
        label_template = label_template.strip()

        # Set the formatted axis label.
        getattr(ax, f"set_{axis}label")(label_template.format(scale=scale), labelpad=8)

316 

317 

318class VariancePlaneTestCase(lsst.utils.tests.TestCase): 

319 def setUp(self): 

320 # Testing with a single detector that has 8 amplifiers in a 4x2 

321 # configuration. Each amplifier measures 100x51 in dimensions. 

322 config = IsrMock.ConfigClass() 

323 config.isLsstLike = True 

324 config.doAddBias = False 

325 config.doAddDark = False 

326 config.doAddFlat = False 

327 config.doAddFringe = False 

328 config.doGenerateImage = True 

329 config.doGenerateData = True 

330 config.doGenerateAmpDict = True 

331 self.mock = IsrMock(config=config) 

332 self.pixel_scale = 0.2 # arcsec/pixel 

333 

334 # Set the random seed for reproducibility. 

335 random_seed = galsim.BaseDeviate(1905).raw() + 1 

336 np.random.seed(random_seed) 

337 self.rng = galsim.BaseDeviate(random_seed) 

338 

339 # Get the exposure, detector, and amps from the mock. 

340 exposure = self.mock.getExposure() 

341 detector = exposure.getDetector() 

342 amps = detector.getAmplifiers() 

343 

344 # Set amp-related attributes for use in the test cases. 

345 self.num_amps = len(amps) 

346 self.amp_name_list = [amp.getName() for amp in amps] 

347 table = str.maketrans("", "", ":,") # Remove ':' and ',' from names 

348 self.amp_name_list_simplified = [name.translate(table) for name in self.amp_name_list] 

349 

350 # Get the bounding boxes for the exposure and amplifiers and convert 

351 # them to galsim bounds. 

352 exp_bbox = exposure.getBBox() 

353 image_bounds = galsim.BoundsI(exp_bbox.minX, exp_bbox.maxX, exp_bbox.minY, exp_bbox.maxY) 

354 self.amp_bbox_list = [amp.getBBox() for amp in amps] 

355 self.amp_bounds_list = [galsim.BoundsI(b.minX, b.maxX, b.minY, b.maxY) for b in self.amp_bbox_list] 

356 

357 # Create a raw galsim image to potentially draw the sources onto. The 

358 # exposure image that is passed to this method will be modified in 

359 # place but it won't be used. 

360 self.signal_free_raw_image = galsim.ImageF(exposure.image.array, bounds=image_bounds) 

361 self.raw_image = self.signal_free_raw_image.copy() 

362 

363 # Define parameters for a mix of source types, including extended 

364 # sources with assorted profiles as well as point sources simulated 

365 # with minimal half-light radii to resemble hot pixels 

366 # post-deconvolution. All flux values are given in electrons and 

367 # half-light radii in pixels. The goal is for each amplifier to 

368 # predominantly contain at least one source, enhancing the 

369 # representativeness of test conditions. 

370 source_params = [ 

371 {"type": "Sersic", "n": 3, "flux": 1.6e5, "half_light_radius": 3.5, "g1": -0.3, "g2": 0.2}, 

372 {"type": "Sersic", "n": 1, "flux": 9.3e5, "half_light_radius": 2.1, "g1": 0.25, "g2": 0.12}, 

373 {"type": "Sersic", "n": 4, "flux": 1.0e5, "half_light_radius": 1.1, "g1": 0.0, "g2": 0.0}, 

374 {"type": "Sersic", "n": 3, "flux": 1.1e6, "half_light_radius": 4.2, "g1": 0.0, "g2": 0.2}, 

375 {"type": "Sersic", "n": 5, "flux": 1.1e5, "half_light_radius": 3.6, "g1": 0.22, "g2": -0.05}, 

376 {"type": "Sersic", "n": 2, "flux": 4.3e5, "half_light_radius": 2.0, "g1": 0.0, "g2": 0.0}, 

377 {"type": "Sersic", "n": 6, "flux": 1.2e6, "half_light_radius": 11.0, "g1": -0.16, "g2": 0.7}, 

378 {"type": "Exponential", "flux": 1.3e6, "half_light_radius": 1.9, "g1": 0.3, "g2": -0.1}, 

379 {"type": "Exponential", "flux": 1.8e6, "half_light_radius": 5.0, "g1": 0.0, "g2": 0.14}, 

380 {"type": "Exponential", "flux": 6.6e6, "half_light_radius": 4.8, "g1": 0.26, "g2": 0.5}, 

381 {"type": "Exponential", "flux": 7.0e5, "half_light_radius": 3.1, "g1": -0.3, "g2": 0.0}, 

382 {"type": "DeVaucouleurs", "flux": 1.6e5, "half_light_radius": 3.5, "g1": 0.2, "g2": 0.4}, 

383 {"type": "DeVaucouleurs", "flux": 2.0e5, "half_light_radius": 1.6, "g1": -0.06, "g2": -0.2}, 

384 {"type": "DeVaucouleurs", "flux": 8.3e5, "half_light_radius": 5.1, "g1": 0.29, "g2": 0.0}, 

385 {"type": "DeVaucouleurs", "flux": 4.5e5, "half_light_radius": 2.5, "g1": 0.4, "g2": 0.3}, 

386 {"type": "DeVaucouleurs", "flux": 6.2e5, "half_light_radius": 4.9, "g1": -0.08, "g2": -0.01}, 

387 {"type": "Gaussian", "flux": 4.7e6, "half_light_radius": 2.5, "g1": 0.07, "g2": -0.35}, 

388 {"type": "Gaussian", "flux": 5.8e6, "half_light_radius": 3.1, "g1": 0.03, "g2": 0.4}, 

389 {"type": "Gaussian", "flux": 2.3e5, "half_light_radius": 0.5, "g1": 0.0, "g2": 0.0}, 

390 {"type": "Gaussian", "flux": 1.6e6, "half_light_radius": 3.0, "g1": 0.18, "g2": -0.29}, 

391 {"type": "Gaussian", "flux": 3.5e5, "half_light_radius": 4.6, "g1": 0.5, "g2": 0.35}, 

392 {"type": "Gaussian", "flux": 5.9e5, "half_light_radius": 9.5, "g1": 0.1, "g2": 0.55}, 

393 {"type": "Gaussian", "flux": 4.0e5, "half_light_radius": 1.0, "g1": 0.0, "g2": 0.0}, 

394 ] 

395 

396 # Mapping of profile types to their galsim constructors. 

397 profile_constructors = { 

398 "Sersic": galsim.Sersic, 

399 "Exponential": galsim.Exponential, 

400 "DeVaucouleurs": galsim.DeVaucouleurs, 

401 "Gaussian": galsim.Gaussian, 

402 } 

403 

404 # Generate random positions within exposure bounds, avoiding edges by a 

405 # margin. 

406 margin_x, margin_y = 0.05 * exp_bbox.width, 0.05 * exp_bbox.height 

407 self.positions = np.random.uniform( 

408 [exp_bbox.minX + margin_x, exp_bbox.minY + margin_y], 

409 [exp_bbox.maxX - margin_x, exp_bbox.maxY - margin_y], 

410 (len(source_params), 2), 

411 ).tolist() 

412 

413 # Loop over the sources and draw them onto the image cutout by cutout. 

414 for i, params in enumerate(source_params): 

415 # Dynamically get constructor and remove type from params. 

416 constructor = profile_constructors[params.pop("type")] 

417 

418 # Get shear parameters and remove them from params. 

419 g1, g2 = params.pop("g1"), params.pop("g2") 

420 

421 # The extent of the cutout should be large enough to contain the 

422 # entire object above the background level. Some empirical factor 

423 # is used to mitigate artifacts. 

424 half_extent = 10 * params["half_light_radius"] * (1 + 2 * np.sqrt(g1**2 + g2**2)) 

425 

426 # Pass the remaining params to the constructor and apply shear. 

427 galsim_object = constructor(**params).shear(galsim.Shear(g1=g1, g2=g2)) 

428 

429 # Retrieve the position of the object. 

430 x, y = self.positions[i] 

431 pos = galsim.PositionD(x, y) 

432 

433 # Get the bounds of the sub-image based on the object position. 

434 sub_image_bounds = galsim.BoundsI( 

435 *map(int, [x - half_extent, x + half_extent, y - half_extent, y + half_extent]) 

436 ) 

437 

438 # Identify the overlap region, which could be partially outside 

439 # the image bounds. 

440 sub_image_bounds = sub_image_bounds & self.raw_image.bounds 

441 

442 # Check that there is some overlap. 

443 assert sub_image_bounds.isDefined(), "No overlap with image bounds" 

444 

445 # Get the sub-image cutout. 

446 sub_image = self.raw_image[sub_image_bounds] 

447 

448 # Draw the object onto the image within the the sub-image bounds. 

449 galsim_object.drawImage( 

450 image=sub_image, 

451 offset=pos - sub_image.true_center, 

452 method="real_space", # Saves memory, usable w/o convolution 

453 add_to_image=True, # Add flux to existing image 

454 scale=self.pixel_scale, 

455 ) 

456 

457 def tearDown(self): 

458 del self.mock 

459 del self.raw_image 

460 del self.signal_free_raw_image 

461 

462 def buildExposure( 

463 self, 

464 average_gain, 

465 gain_sigma_factor, 

466 sky_level, 

467 add_signal=True, 

468 ): 

469 """Build and return an exposure with different types of simulated 

470 source profiles and a background sky level. It's intended for testing 

471 and analysis, providing a way to generate exposures with controlled 

472 conditions. 

473 

474 Parameters 

475 ---------- 

476 average_gain : `float` 

477 The average gain value of amplifiers in e-/ADU. 

478 gain_sigma_factor : float 

479 The standard deviation of the gain values as a factor of the 

480 ``average_gain``. 

481 sky_level : `float` 

482 The background sky level in e-/arcsec^2. 

483 add_signal : `bool`, optional 

484 Whether to add sources to the exposure. If set to False, the 

485 exposure will only contain background noise. 

486 

487 Returns 

488 ------- 

489 exposure : `~lsst.afw.image.Exposure` 

490 An exposure object with simulated sources and background. The units 

491 are in detector counts (ADU). 

492 """ 

493 

494 # Get the exposure from the mock. 

495 exposure = self.mock.getExposure() 

496 

497 # Convert the background sky level from e-/arcsec^2 to e-/pixel. 

498 self.background = sky_level * self.pixel_scale**2 

499 

500 # Generate random deviations from the average gain across amplifiers 

501 # and adjust them to ensure their sum equals zero. This reflects 

502 # real-world detectors, with amplifier gains normally distributed due 

503 # to manufacturing and operational variations. 

504 deviations = np.random.normal(average_gain, gain_sigma_factor * average_gain, size=self.num_amps) 

505 deviations -= np.mean(deviations) 

506 

507 # Set the gain for amplifiers to be slightly different from each other 

508 # while averaging to `average_gain`. This is to test the 

509 # `average_across_amps` option in the `remove_signal_from_variance` 

510 # function. 

511 self.amp_gain_list = [average_gain + deviation for deviation in deviations] 

512 

513 # Add a constant background to the entire image (both in e-/pixel). 

514 if add_signal: 

515 image = self.raw_image + self.background 

516 else: 

517 image = self.signal_free_raw_image + self.background 

518 

519 # Add noise to the image which is in electrons. Note that we won't 

520 # specify a `sky_level` here to avoid double-counting it, as it's 

521 # already included in the image as the background. 

522 image.addNoise(galsim.PoissonNoise(self.rng)) 

523 

524 # Subtract off the background to get the sky-subtracted image. 

525 image -= self.background 

526 

527 # Adjust each amplifier's image segment by its respective gain. After 

528 # this step, the image will be in ADUs. 

529 for bounds, gain in zip(self.amp_bounds_list, self.amp_gain_list): 

530 image[bounds] /= gain 

531 

532 # We know that the exposure has already been modified in place, but 

533 # just to be extra sure, we'll set the exposure image explicitly. 

534 exposure.image.array = image.array 

535 

536 # Create a variance plane for the exposure while including signal as a 

537 # pollutant. Note that the exposure image is pre-adjusted for gain, 

538 # unlike 'self.background'. Thus, we divide the background by the 

539 # corresponding gain before adding it to the image. This leads to the 

540 # variance plane being in units of ADU^2. 

541 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list): 

542 exposure.variance[bbox].array = (exposure.image[bbox].array + self.background / gain) / gain 

543 

544 return exposure 

545 

546 def test_no_signal_handling(self): 

547 """Test that the function does nearly nothing when given an image with 

548 no signal. 

549 """ 

550 # Create an exposure with no signal. 

551 exposure = self.buildExposure( 

552 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=False 

553 ) 

554 # Remove the signal from the variance plane, if any. 

555 updated_variance = remove_signal_from_variance(exposure, in_place=False) 

556 # Check that the variance plane is nearly the same as the original. 

557 self.assertFloatsAlmostEqual(exposure.variance.array, updated_variance.array, rtol=0.013) 

558 

559 def test_in_place_handling(self): 

560 """Make sure the function is tested to handle in-place operations.""" 

561 # Create an exposure with signal. 

562 exposure = self.buildExposure( 

563 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=True 

564 ) 

565 # Remove the signal from the variance plane. 

566 updated_variance = remove_signal_from_variance(exposure, in_place=True) 

567 # Retrieve the variance plane from the exposure and check that it is 

568 # identical to the returned variance plane. 

569 self.assertFloatsEqual(exposure.variance.array, updated_variance.array) 

570 

571 @methodParametersProduct( 

572 average_gain=[1.4, 1.7], 

573 predefined_gain_type=["average", "per-amp", None], 

574 gain_sigma_factor=[0, 0.008], 

575 sky_level=[2e6, 4e6], 

576 average_across_amps=[False, True], 

577 ) 

578 def test_variance_signal_removal( 

579 self, average_gain, predefined_gain_type, gain_sigma_factor, sky_level, average_across_amps 

580 ): 

581 exposure = self.buildExposure( 

582 average_gain=average_gain, 

583 gain_sigma_factor=gain_sigma_factor, 

584 sky_level=sky_level, 

585 add_signal=True, 

586 ) 

587 

588 # Save the original variance plane for comparison, assuming it has 

589 # Poisson contribution from the source signal. 

590 signal_polluted_variance = exposure.variance.clone() 

591 

592 # Check that the variance plane has no negative values. 

593 self.assertTrue( 

594 np.all(signal_polluted_variance.array >= 0), 

595 "Variance plane has negative values (pre correction)", 

596 ) 

597 

598 if predefined_gain_type == "average": 

599 predefined_gain = average_gain 

600 predefined_gains = None 

601 elif predefined_gain_type == "per-amp": 

602 predefined_gain = None 

603 predefined_gains = {name: gain for name, gain in zip(self.amp_name_list, self.amp_gain_list)} 

604 elif predefined_gain_type is None: 

605 # Allow the 'remove_signal_from_variance' function to estimate the 

606 # gain itself before it attempts to remove the signal from the 

607 # variance plane. 

608 predefined_gain = None 

609 predefined_gains = None 

610 

611 # Set the relative tolerance for the variance plane checks. 

612 if predefined_gain_type == "average" or (predefined_gain_type is None and average_across_amps): 

613 # Relax the tolerance if we are simply averaging across amps to 

614 # roughly estimate the overall gain. 

615 rtol = 0.015 

616 estimate_average_gain = True 

617 else: 

618 # Tighten tolerance for the 'predefined_gain_type' of 'per-amp' or 

619 # for a more accurate per-amp gain estimation strategy. 

620 rtol = 3e-7 

621 estimate_average_gain = False 

622 

623 # Remove the signal from the variance plane. 

624 signal_free_variance = remove_signal_from_variance( 

625 exposure, 

626 gain=predefined_gain, 

627 gains=predefined_gains, 

628 average_across_amps=average_across_amps, 

629 in_place=False, 

630 ) 

631 

632 # Check that the variance plane has been modified. 

633 self.assertFloatsNotEqual(signal_polluted_variance.array, signal_free_variance.array) 

634 

635 # Check that the corrected variance plane has no negative values. 

636 self.assertTrue( 

637 np.all(signal_free_variance.array >= 0), "Variance plane has negative values (post correction)" 

638 ) 

639 

640 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list): 

641 # Calculate the true variance in theoretical terms. 

642 true_var_amp = self.background / gain**2 

643 # Pair each variance with the appropriate context manager before 

644 # looping through them. 

645 var_context_pairs = [ 

646 # For the signal-free variance, directly execute the checks. 

647 (signal_free_variance, nullcontext()), 

648 # For the signal-polluted variance, expect AssertionError 

649 # unless we are averaging across amps. 

650 ( 

651 signal_polluted_variance, 

652 nullcontext() if estimate_average_gain else self.assertRaises(AssertionError), 

653 ), 

654 ] 

655 for var, context_manager in var_context_pairs: 

656 # Extract the segment of the variance plane for the amplifier. 

657 var_amp = var[bbox] 

658 with context_manager: 

659 if var is signal_polluted_variance and estimate_average_gain: 

660 # Skip rigorous checks on the signal-polluted variance, 

661 # if we are averaging across amps. 

662 pass 

663 else: 

664 # Get the variance value at the first pixel of the 

665 # segment to compare with the rest of the pixels and 

666 # the true variance. 

667 v00 = var_amp.array[0, 0] 

668 # Assert that the variance plane is almost uniform 

669 # across the segment because the signal has been 

670 # removed from it and the background is constant. 

671 self.assertFloatsAlmostEqual(var_amp.array, v00, rtol=rtol) 

672 # Assert that the variance plane is almost equal to the 

673 # true variance across the segment. 

674 self.assertFloatsAlmostEqual(v00, true_var_amp, rtol=rtol) 

675 

676 if ( 

677 SAVE_PLOT 

678 and not average_across_amps 

679 and gain_sigma_factor in (0, 0.008) 

680 and sky_level == 4e6 

681 and average_gain == 1.7 

682 and predefined_gain_type is None 

683 ): 

684 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8.5)) 

685 plt.subplots_adjust(wspace=0.17, hspace=0.17) 

686 colorbar_aspect = 12 

687 

688 amp_background_variance_ADU_list = [self.background / gain**2 for gain in self.amp_gain_list] 

689 amp_background_image_ADU_list = [self.background / gain for gain in self.amp_gain_list] 

690 # Calculate the mean value that corresponds to the background for 

691 # the variance plane, adjusting for the gain. 

692 background_mean_variance_ADU = np.mean( 

693 [self.background / gain**2 for gain in self.amp_gain_list] 

694 ) 

695 

696 # Extract the variance planes and the image from the exposure. 

697 arr1 = signal_polluted_variance.array # Variance with signal 

698 arr2 = signal_free_variance.array # Variance without signal 

699 exp_im = exposure.image.clone() # Clone of the image plane 

700 

701 # Incorporate the gain-adjusted background into the image plane to 

702 # enable combined visualization of sources with the background. 

703 for gain, bbox in zip(self.amp_gain_list, self.amp_bbox_list): 

704 exp_im[bbox].array += self.background / gain 

705 arr3 = exp_im.array 

706 

707 # Define colors visually distinct from each other for the subplots. 

708 original_variance_color = "#8A2BE2" # Periwinkle 

709 corrected_variance_color = "#618B3C" # Lush Forest Green 

710 sky_variance_color = "#c3423f" # Crimson Red 

711 amp_colors = [ 

712 "#1f77b4", # Muted Blue 

713 "#ff7f0e", # Vivid Orange 

714 "#2ca02c", # Kelly Green 

715 "#d62728", # Brick Red 

716 "#9467bd", # Soft Purple 

717 "#8B4513", # Saddle Brown 

718 "#e377c2", # Pale Violet Red 

719 "#202020", # Onyx 

720 ] 

721 arrowheads_lr = ["$\u25C0$", "$\u25B6$"] # Left- & right-pointing 

722 arrowheads_ud = ["$\u25B2$", "$\u25BC$"] # Up- & down-pointing 

723 

724 # Set titles for the subplots. 

725 ax1.set_title("Original variance plane", color=original_variance_color) 

726 ax2.set_title("Corrected variance plane", color=corrected_variance_color) 

727 ax3.set_title("Image + background ($\\mathit{uniform}$)") 

728 ax4.set_title("Histogram of variances") 

729 

730 # Collect all vertical and horizontal line positions to find the 

731 # amp boundaries. 

732 vlines, hlines = set(), set() 

733 for bbox in self.amp_bbox_list: 

734 # Adjst by 0.5 for merging of lines at the boundaries. 

735 vlines.update({bbox.minX - 0.5, bbox.maxX + 0.5}) 

736 hlines.update({bbox.minY - 0.5, bbox.maxY + 0.5}) 

737 

738 # Filter lines at the edges of the overall image bbox. 

739 image_bbox = exposure.getBBox() 

740 vlines = {x for x in vlines if image_bbox.minX < x < image_bbox.maxX} 

741 hlines = {y for y in hlines if image_bbox.minY < y < image_bbox.maxY} 

742 

743 # Plot image and variance planes. 

744 for plane, arr, ax in zip( 

745 ("variance", "variance_corrected", "image"), (arr1, arr2, arr3), (ax1, ax2, ax3) 

746 ): 

747 # We skip 'variance_corrected' in the loop below because we use 

748 # the same normalization and colormap as 'variance' for it. 

749 if plane in ["variance", "image"]: 

750 # Get the normalization. 

751 vmin, vmax = arr.min(), arr.max() 

752 norm = mcolors.Normalize(vmin=vmin, vmax=vmax) 

753 

754 # Get the thresholds corresponding to per-amp backgrounds 

755 # and their positions in the normalized color scale. 

756 thresholds = ( 

757 amp_background_variance_ADU_list 

758 if plane.startswith("variance") 

759 else amp_background_image_ADU_list 

760 ) 

761 threshold_positions = [norm(t) for t in thresholds] 

762 threshold = np.mean(thresholds) 

763 threshold_position = np.mean(threshold_positions) 

764 

765 # Create a custom colormap with two distinct colors for the 

766 # sky and source contributions. 

767 border = (threshold - vmin) / (vmax - vmin) 

768 colors1 = plt.cm.Purples_r(np.linspace(0, 1, int(border * 256))) 

769 colors2 = plt.cm.Greens(np.linspace(0, 1, int((1 - border) * 256))) 

770 colors = np.vstack((colors1, colors2)) 

771 cmap = mcolors.LinearSegmentedColormap.from_list("cmap", colors) 

772 

773 # Plot the array with the custom colormap and normalization. 

774 im = ax.imshow(arr, cmap=cmap, norm=norm) 

775 

776 # Add colorbars to the plot. 

777 cbar = fig.colorbar(im, aspect=colorbar_aspect, pad=0) 

778 

779 # Change the number of ticks on the colorbar for better 

780 # spacing. Needs to be done before modifying the tick labels. 

781 cbar.ax.locator_params(nbins=7) 

782 

783 # Enhance readability by scaling down colorbar tick labels. 

784 unit = "ADU$^2$" if plane.startswith("variance") else "ADU" 

785 adjust_tick_scale(cbar.ax, {"y": f"Value [{{scale}} {unit}]"}) 

786 

787 # Mark min and max per-amp thresholds with dotted lines on the 

788 # colorbar. 

789 for tp in [min(thresholds), max(thresholds)]: 

790 cbar.ax.axhline(tp, color="white", linestyle="-", linewidth=1.5, alpha=0.4) 

791 cbar.ax.axhline(tp, color=sky_variance_color, linestyle=":", linewidth=1.5, alpha=0.9) 

792 

793 # Mark mean threshold with facing arrowheads on the colorbar. 

794 cbar.ax.annotate( 

795 arrowheads_lr[1], # Right-pointing arrowhead 

796 xy=(0, threshold_position), 

797 xycoords="axes fraction", 

798 textcoords="offset points", 

799 xytext=(0, 0), 

800 ha="left", 

801 va="center", 

802 fontsize=6, 

803 color=sky_variance_color, 

804 clip_on=False, 

805 alpha=0.9, 

806 ) 

807 cbar.ax.annotate( 

808 arrowheads_lr[0], # Left-pointing arrowhead 

809 xy=(1, threshold_position), 

810 xycoords="axes fraction", 

811 textcoords="offset points", 

812 xytext=(0, 0), 

813 ha="right", 

814 va="center", 

815 fontsize=6, 

816 color=sky_variance_color, 

817 clip_on=False, 

818 alpha=0.9, 

819 ) 

820 

821 # Add text inside the colorbar to label the average threshold 

822 # position. 

823 sky_level_text = "$\u27E8$" + "Sky" + "$\u27E9$" # <Sky> 

824 sky_level_text_artist = cbar.ax.text( 

825 0.5, 

826 threshold_position, 

827 sky_level_text, 

828 va="center", 

829 ha="center", 

830 transform=cbar.ax.transAxes, 

831 fontsize=8, 

832 color=sky_variance_color, 

833 rotation="vertical", 

834 alpha=0.9, 

835 path_effects=outline_effect(2), 

836 ) 

837 

838 # Setup renderer and transformation. 

839 renderer = fig.canvas.get_renderer() 

840 transform = cbar.ax.transAxes.inverted() 

841 

842 # Transform the bounding box and calculate adjustment for 

843 # 'sky_level_text_artist' for when it goes beyond the colorbar. 

844 sky_level_text_bbox = sky_level_text_artist.get_window_extent(renderer).transformed(transform) 

845 adjustment = 1.4 * sky_level_text_bbox.height / 2 

846 

847 if sky_level_text_bbox.ymin < 0: 

848 sky_level_text_artist.set_y(adjustment) 

849 elif sky_level_text_bbox.ymax > 1: 

850 sky_level_text_artist.set_y(1 - adjustment) 

851 

852 # Draw amp boundaries as vertical and/or horizontal lines. 

853 line_color = "white" if np.mean(norm(arr)) > 0.5 else "#808080" 

854 for x in vlines: 

855 ax.axvline(x=x, color=line_color, linestyle="--", linewidth=1, alpha=0.7) 

856 for y in hlines: 

857 ax.axhline(y=y, color=line_color, linestyle="--", linewidth=1, alpha=0.7) 

858 # Hide all x and y tick marks. 

859 ax.tick_params(axis="both", which="both", bottom=False, top=False, left=False, right=False) 

860 # Hide all x and y tick labels. 

861 ax.set_xticklabels([]) 

862 ax.set_yticklabels([]) 

863 

864 # Additional ax2 annotations: 

865 # Labels amplifiers with their respective gains for a visual check. 

866 for bbox, name, gain, color in zip( 

867 self.amp_bbox_list, self.amp_name_list_simplified, self.amp_gain_list, amp_colors 

868 ): 

869 # Get the center of the bbox to label the gain value. 

870 bbox_center = (bbox.minX + bbox.maxX) / 2, (bbox.minY + bbox.maxY) / 2 

871 # Label the gain value at the center of each amplifier segment. 

872 ax2.text( 

873 *bbox_center, 

874 f"gain$_{{\\rm \\, {name} \\,}}$: {gain:.3f}", 

875 fontsize=9, 

876 color=color, 

877 alpha=0.95, 

878 ha="center", 

879 va="center", 

880 path_effects=outline_effect(2), 

881 ) 

882 

883 # Additional ax3 annotations: 

884 # Label sources with numbers on the image plane. 

885 for i, pos in enumerate(self.positions, start=1): 

886 ax3.text( 

887 *pos, 

888 f"{i}", 

889 fontsize=7, 

890 color=sky_variance_color, 

891 path_effects=outline_effect(1.5), 

892 alpha=0.9, 

893 ) 

894 

895 # Now we use ax4 to plot the histograms of the variance planes for 

896 # comparison. 

897 # Plot the histogram of the original variance plane. 

898 hist_values, bins, _ = ax4.hist( 

899 arr1.flatten(), 

900 bins=80, 

901 histtype="step", 

902 color=original_variance_color, 

903 alpha=0.9, 

904 label="Original variance", 

905 ) 

906 # Fill the area under the step. 

907 ax4.fill_between( 

908 bins[:-1], 

909 hist_values, 

910 step="post", 

911 color=original_variance_color, 

912 alpha=0.09, 

913 hatch="/////", 

914 label=" ", 

915 ) 

916 # Plot the histogram of the corrected variance plane. 

917 ax4.hist( 

918 arr2.flatten(), 

919 bins=80, 

920 histtype="bar", 

921 color=corrected_variance_color, 

922 alpha=0.9, 

923 label="Corrected variance", 

924 ) 

925 adjust_tick_scale(ax4, {"x": "Variance [{scale} ADU$^2$]", "y": "Number of pixels / {scale}"}) 

926 ax4.yaxis.set_label_position("right") 

927 ax4.yaxis.tick_right() 

928 ax4.axvline( 

929 background_mean_variance_ADU, 

930 color=sky_variance_color, 

931 linestyle="--", 

932 linewidth=1, 

933 alpha=0.9, 

934 label="Average sky variance\nacross all amps", 

935 ) 

936 

937 # Use colored arrowheads to mark true amp variances. 

938 sorted_vars = sorted(amp_background_variance_ADU_list) 

939 count = {v: 0 for v in sorted_vars} 

940 for i, (x, name, gain, color) in enumerate( 

941 zip( 

942 amp_background_variance_ADU_list, 

943 self.amp_name_list_simplified, 

944 self.amp_gain_list, 

945 amp_colors, 

946 ) 

947 ): 

948 arrowhead = arrowheads_ud[int(gain < average_gain)] 

949 arrowhead_text = ax4.annotate( 

950 arrowhead, 

951 xy=(x, 0), 

952 xycoords=("data", "axes fraction"), 

953 textcoords="offset points", 

954 xytext=(0, 0), 

955 ha="center", 

956 va="bottom", 

957 fontsize=6.5, 

958 color=color, 

959 clip_on=False, 

960 alpha=0.85, 

961 path_effects=outline_effect(1.5), 

962 ) 

963 if i == 0: 

964 # Draw the canvas once to make sure the renderer is active. 

965 fig.canvas.draw() 

966 # Get the bounding box of the text annotation in axes 

967 # fraction. 

968 bbox_axes = arrowhead_text.get_window_extent().transformed(ax4.transAxes.inverted()) 

969 # Get the height of the text annotation in axes fraction. 

970 height = bbox_axes.height 

971 # Increment the arrowhead y positions to avoid overlap. 

972 var_idxs = np.where(sorted_vars == x)[0] 

973 if len(var_idxs) > 1: 

974 q = count[x] 

975 count[x] += 1 

976 else: 

977 q = 0 

978 arrowhead_text.xy = (x, var_idxs[q] * height) 

979 

980 # Create a proxy artist for the legend since annotations are 

981 # not shown in the legend. 

982 label = "True variance of" if i == 0 else "$\u21AA$" 

983 ax4.scatter( 

984 [], 

985 [], 

986 color=color, 

987 marker=arrowhead, 

988 s=15, 

989 label=f"{label} {name}", 

990 alpha=0.85, 

991 path_effects=outline_effect(1.5), 

992 ) 

993 

994 # Group the legend handles and label them. 

995 adjust_legend_with_groups( 

996 ax4, 

997 [(0, 1), 2, 3, *range(4, 4 + len(self.amp_name_list_simplified))], 

998 colors="match", 

999 handlelength=1.9, 

1000 ) 

1001 

1002 # Align the histogram (bottom right panel) with the colorbar of the 

1003 # corrected variance plane (top right panel) for aesthetic reasons. 

1004 pos2 = ax2.get_position() 

1005 pos4 = ax4.get_position() 

1006 fig.canvas.draw() # Render to ensure accurate colorbar width 

1007 cbar_width = cbar.ax.get_position().width 

1008 ax4.set_position([pos2.x0, pos4.y0, pos2.width + cbar_width, pos4.height]) 

1009 

1010 # Increase all axes spines' linewidth by 20% for a bolder look. 

1011 for ax in fig.get_axes(): 

1012 for spine in ax.spines.values(): 

1013 spine.set_linewidth(spine.get_linewidth() * 1.2) 

1014 

1015 # Save the figure. 

1016 filename = f"variance_plane_gain{average_gain}_sigma{gain_sigma_factor}_sky{sky_level}.png" 

1017 fig.savefig(filename, dpi=300) 

1018 print(f"Saved plot of variance plane before and after correction in {filename}") 

1019 

1020 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Test case that adds nothing of its own.

    Every check comes from the inherited
    ``lsst.utils.tests.MemoryTestCase``.
    """

1023 

1024 

def setup_module(module):
    """Initialize the LSST test utilities once for this module.

    This is a module-level setup hook (xunit style). The ``module``
    argument is accepted to match the hook signature; it is not used.
    """
    initialize = lsst.utils.tests.init
    initialize()

1027 

1028 

if __name__ == "__main__":
    # Support running this file directly as a script. Initialize the LSST
    # test utilities first, then hand control to unittest's command-line
    # runner.
    lsst.utils.tests.init()
    unittest.main()