Coverage for tests / test_variance_plane.py: 8%

328 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 09:00 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""The utility function under test is part of the `meas_algorithms` package. 

23This unit test has been relocated here to avoid circular dependencies. 

24""" 

25 

26import re 

27import unittest 

28from contextlib import nullcontext 

29 

30import galsim 

31import lsst.utils.tests 

32import matplotlib.colors as mcolors 

33import matplotlib.pyplot as plt 

34import numpy as np 

35from lsst.ip.isr.isrMock import IsrMock 

36from lsst.meas.algorithms import remove_signal_from_variance 

37from lsst.utils.tests import methodParametersProduct 

38from matplotlib.legend_handler import HandlerTuple 

39from matplotlib.patheffects import withStroke 

40from matplotlib.ticker import FixedLocator, FuncFormatter 

41 

# Set to True to save a figure comparing the variance plane before and after
# correction. The figure is produced only inside
# ``VariancePlaneTestCase.test_variance_signal_removal``, and only for one
# representative parameter combination of that test.
SAVE_PLOT = False

45 

46 

def outline_effect(lw, alpha=0.8):
    """Build a white outline path effect that makes text readable on busy
    backgrounds.

    Parameters
    ----------
    lw : `float`
        Width of the white stroke drawn around the text.
    alpha : `float`, optional
        Opacity of that stroke.

    Returns
    -------
    `list` of `matplotlib.patheffects.withStroke`
        A one-element list, ready to pass as ``path_effects=``.
    """
    white_stroke = withStroke(linewidth=lw, foreground="white", alpha=alpha)
    return [white_stroke]

63 

64 

class CustomHandler(HandlerTuple):
    """Legend handler for grouped (tuple) handles that re-applies the handle
    box transform to every artist the tuple handler creates.
    """

    def create_artists(self, *args):
        # Matplotlib passes the handle-box transform as the final positional
        # argument of ``create_artists``.
        handle_box_transform = args[-1]
        created_artists = super().create_artists(*args)
        for created in created_artists:
            created.set_transform(handle_box_transform)
        return created_artists

73 

74 

def get_valid_color(handle):
    """Extract a visible color from a Matplotlib handle.

    The getters ``get_facecolor``, ``get_edgecolor`` and ``get_color`` are
    tried in that order. The first one that exists and yields a color that is
    not fully transparent wins.

    Parameters
    ----------
    handle : `matplotlib.artist.Artist`
        The handle from which to extract the color.

    Returns
    -------
    color : `str` or `tuple` or `numpy.ndarray`
        The color extracted from the handle, or 'default' if no valid color is
        found.
    """
    for attr in ("get_facecolor", "get_edgecolor", "get_color"):
        getter = getattr(handle, attr, None)
        if getter is None:
            continue
        color = getter()
        # Collections return an (N, 4) RGBA array: use the first entry, or
        # move on when the collection carries no colors at all. A 1-D RGBA
        # array is already a single color and must not be indexed further.
        if isinstance(color, np.ndarray) and color.ndim == 2:
            if color.shape[0] == 0:
                continue
            color = color[0]
        # "none" is Matplotlib's name for a fully transparent color.
        if isinstance(color, str) and color.lower() == "none":
            continue
        # An RGBA color with alpha = 0 is invisible; keep searching.
        if len(color) == 4 and color[3] == 0:
            continue
        return color
    return "default"  # No valid color was found.

100 

101 

def get_emptier_side(ax):
    """Analyze a matplotlib Axes object to determine which side (left or right)
    has more whitespace, considering cases where artists' bounding boxes span
    both sides of the midpoint.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object to analyze.

    Returns
    -------
    more_whitespace_side : `str`
        'left' if the left side has more whitespace, or 'right' if the right
        side does. Ties resolve to 'left'.
    """
    # Work in inches throughout. Both the Axes' own extent and each artist's
    # extent go through the same display->inches transform, so the midpoint
    # and the artist bounding boxes share one coordinate system. (Data
    # coordinates from ``get_xlim`` would not be comparable with them.)
    to_inches = ax.figure.dpi_scale_trans.inverted()
    axes_bbox = ax.get_window_extent().transformed(to_inches)
    midpoint = (axes_bbox.x0 + axes_bbox.x1) / 2

    left_area, right_area = 0.0, 0.0

    for artist in ax.get_children():
        # Skip if artist is invisible or lacks a bounding box.
        if not artist.get_visible() or not hasattr(artist, "get_window_extent"):
            continue
        bbox = artist.get_window_extent().transformed(to_inches)
        # Null/degenerate extents (infinite bounds) would poison the sums.
        if not np.all(np.isfinite([bbox.x0, bbox.width, bbox.height])):
            continue
        area = bbox.width * bbox.height
        if bbox.x0 < midpoint < bbox.x1:
            # Split the area proportionally across the midpoint.
            left_fraction = (midpoint - bbox.x0) / bbox.width
            left_area += area * left_fraction
            right_area += area * (1 - left_fraction)
        elif bbox.x0 + bbox.width / 2 < midpoint:
            # Entirely on the left.
            left_area += area
        else:
            # Entirely on the right.
            right_area += area

    # The side with less occupied area has more whitespace.
    return "left" if left_area <= right_area else "right"

149 

150 

def adjust_legend_with_groups(ax, combine_groups, colors="default", yloc="upper", **kwargs):
    """Adjusts the legend of a given Axes object by combining specified handles
    based on provided groups, setting the marker location and text alignment
    based on the inferable emptier side, and optionally setting the text color
    of legend entries to a provided list or inferring colors from the handles.
    Additionally, allows specifying the vertical location of the legend within
    the plot.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        The Axes object for which to adjust the legend.
    combine_groups : `list` of `int` or iterable of `int`
        A list that can contain a mix of individual integers and/or iterables
        (lists, tuples, sets, or 1-D NumPy arrays) of integers. An individual
        integer (Python or NumPy) specifies the index of a single legend
        entry. An iterable of integers specifies a group of indices to be
        combined into a single legend entry.
    colors : `list` of `str` or `tuple`, or `str`, optional
        Specifies the colors for the legend entries. This parameter can be:
        - A sequence of color specifications (list or array), where each
          element is a string (for named colors or hex values) or a tuple (for
          RGB or RGBA values). It explicitly assigns colors to each legend
          entry post-combination.
        - A single string value:
          - "match": Colors are inferred from the properties of the first
            handle in each group that corresponds to a non-white-space label.
            This aims to match the legend text color with the color of the
            plotted data.
          - "default": The function does not alter the default colors
            assigned by Matplotlib, preserving the automatic color assignment
            for all legend entries.
    yloc : `str`, optional
        The vertical location of the legend within the Axes. Valid options are
        'upper', 'lower', or 'middle' ('center' is accepted as a synonym).
        This parameter is combined with the inferable emptier side ('left' or
        'right') to determine the legend's placement. For example,
        'upper right' or 'lower left'.
    **kwargs :
        Keyword arguments forwarded to the ``ax.legend`` function.

    Raises
    ------
    ValueError
        If an element of ``combine_groups`` is neither an integer nor a
        supported iterable of integers.
    """
    handles, labels = ax.get_legend_handles_labels()
    new_handles = []
    new_labels = []

    # Compare against the sentinel strings only when ``colors`` is a string.
    # ``==`` on a NumPy array of colors is element-wise and cannot be used as
    # an ``if`` condition.
    infer_colors = isinstance(colors, str) and colors == "match"
    keep_default_colors = isinstance(colors, str) and colors == "default"
    if infer_colors:
        colors = []

    for group in combine_groups:
        # Assume the first non-white-space label represents the group. If no
        # such label is found, just use the first label in the group which is
        # in fact empty.
        if isinstance(group, (list, tuple, set, np.ndarray)):
            group = list(group)
            label_index = next((i for i in group if labels[i].strip()), group[0])
            combined_handle = tuple(handles[i] for i in group)
            combined_label = labels[label_index]
        elif isinstance(group, (int, np.integer)):
            label_index = group
            combined_handle = handles[group]
            combined_label = labels[group]
        else:
            raise ValueError("Invalid value in 'combine_groups'")
        new_handles.append(combined_handle)
        new_labels.append(combined_label)
        if infer_colors:
            # Infer the color from the representative handle of the group.
            colors.append(get_valid_color(handles[label_index]))

    # Determine the emptier side to decide legend and text alignment.
    emptier_side = get_emptier_side(ax)
    markerfirst = emptier_side != "right"

    # Matplotlib legend locations say "center", not "middle", for the vertical
    # midpoint (e.g. "center left"). Translate the documented option.
    vertical = "center" if yloc == "middle" else yloc

    # Create the legend with custom adjustments.
    legend = ax.legend(
        new_handles,
        new_labels,
        handler_map={tuple: CustomHandler()},
        loc=f"{vertical} {emptier_side}",
        fontsize=8,
        frameon=False,
        markerfirst=markerfirst,
        **kwargs,
    )

    # Right- or left-align the legend text based on the emptier side.
    for text in legend.get_texts():
        text.set_ha(emptier_side)

    # Set legend text colors if necessary.
    if not keep_default_colors:
        for text, color in zip(legend.get_texts(), colors):
            if not (isinstance(color, str) and color == "default"):
                text.set_color(color)

249 

250 

251def adjust_tick_scale(ax, axis_label_templates): 

252 """Scales down tick labels to make them more readable and updates axis 

253 labels accordingly. 

254 

255 Calculates a power of 10 scale factor (common divisor) to reduce the 

256 magnitude of tick labels. It automatically determines which axes to adjust 

257 based on the provided axis label templates, which should include `{scale}` 

258 for inserting the scale factor dynamically. 

259 

260 Parameters 

261 ---------- 

262 ax : `~matplotlib.axes.Axes` 

263 The Axes object to modify. 

264 axis_label_templates : `dict` 

265 Templates for axis labels, including '{scale}' for scale factor 

266 insertion. Keys should be one or more of axes names ('x', 'y', 'z') and 

267 values should be the corresponding label templates. 

268 """ 

269 

270 def trailing_zeros(n): 

271 """Determines the number of trailing zeros in a number.""" 

272 return len(n := str(int(float(n))) if float(n).is_integer() else n) - len(n.rstrip("0")) 

273 

274 def format_tick(val, pos, divisor): 

275 """Formats tick labels using the determined divisor.""" 

276 return str(int(val / divisor)) if (val / divisor).is_integer() else str(val / divisor) 

277 

278 # Iterate through the specified axes and adjust their tick labels and axis 

279 # labels. 

280 for axis in axis_label_templates.keys(): 

281 # Gather current tick labels. 

282 labels = [label.get_text() for label in getattr(ax, f"get_{axis}ticklabels")()] 

283 

284 # Calculate the power of 10 divisor based on the minimum number of 

285 # trailing zeros in the tick labels. 

286 divisor = 10 ** min(trailing_zeros(label) for label in labels if float(label) != 0) 

287 

288 # Set a formatter for the axis ticks that scales them according to the 

289 # common divisor. 

290 getattr(ax, f"{axis}axis").set_major_formatter( 

291 FuncFormatter(lambda val, pos: format_tick(val, pos, divisor)) 

292 ) 

293 

294 # Ensure the tick positions remain unchanged despite the new 

295 # formatting. 

296 getattr(ax, f"{axis}axis").set_major_locator(FixedLocator(getattr(ax, f"get_{axis}ticks")())) 

297 

298 # Prepare 'scale', empty if divisor <= 1. 

299 scale = f"{int(divisor)}" if divisor > 1 else "" 

300 

301 # Fetch the corresponding label template for the axis. 

302 label_template = axis_label_templates[axis] 

303 

304 # If 'scale' is empty, remove whitespace around "{scale}" in the 

305 # template. Also remove any trailing "/{scale}". 

306 if scale == "": 

307 label_template = re.sub(r"\s*{\s*scale\s*}\s*", "{scale}", label_template) 

308 label_template = label_template.replace("/{scale}", "") 

309 

310 # Always strip remaining whitespace from the template. 

311 label_template = label_template.strip() 

312 

313 # Set the formatted axis label. 

314 label_text = label_template.format(scale=scale) 

315 getattr(ax, f"set_{axis}label")(label_text, labelpad=8) 

316 

317 

318class VariancePlaneTestCase(lsst.utils.tests.TestCase): 

    def setUp(self):
        """Create the mock detector and a source-filled galsim image shared by
        the tests.

        Attributes set here and read by ``buildExposure`` and the tests:
        ``mock``, ``pixel_scale`` (arcsec/pixel), the random generators
        ``np_rng`` (NumPy) and ``rng`` (galsim), per-amplifier metadata
        (``num_amps``, ``amp_name_list``, ``amp_name_list_simplified``,
        ``amp_bbox_list``, ``amp_bounds_list``), the source ``positions``, and
        two galsim images: ``signal_free_raw_image`` (built from the mock
        exposure's pixel array) and ``raw_image`` (a copy of it onto which the
        sources are drawn).
        """
        # Testing with a single detector that has 8 amplifiers in a 4x2
        # configuration. Each amplifier measures 100x51 in dimensions.
        config = IsrMock.ConfigClass()
        config.isLsstLike = True
        config.doAddBias = False
        config.doAddDark = False
        config.doAddFlat = False
        config.doAddFringe = False
        config.doGenerateImage = True
        config.doGenerateData = True
        config.doGenerateAmpDict = True
        self.mock = IsrMock(config=config)
        self.pixel_scale = 0.2  # arcsec/pixel

        # Set the random seed for reproducibility. Note that ``np_rng`` is
        # also consumed below (source positions) before any call to
        # ``buildExposure``, so the order of draws matters for the simulated
        # data.
        random_seed = galsim.BaseDeviate(1905).raw() + 1
        self.np_rng = np.random.Generator(np.random.MT19937(4))
        self.rng = galsim.BaseDeviate(random_seed)

        # Get the exposure, detector, and amps from the mock.
        exposure = self.mock.getExposure()
        detector = exposure.getDetector()
        amps = detector.getAmplifiers()

        # Set amp-related attributes for use in the test cases.
        self.num_amps = len(amps)
        self.amp_name_list = [amp.getName() for amp in amps]
        table = str.maketrans("", "", ":,")  # Remove ':' and ',' from names
        self.amp_name_list_simplified = [name.translate(table) for name in self.amp_name_list]

        # Get the bounding boxes for the exposure and amplifiers and convert
        # them to galsim bounds.
        exp_bbox = exposure.getBBox()
        image_bounds = galsim.BoundsI(exp_bbox.minX, exp_bbox.maxX, exp_bbox.minY, exp_bbox.maxY)
        self.amp_bbox_list = [amp.getBBox() for amp in amps]
        self.amp_bounds_list = [galsim.BoundsI(b.minX, b.maxX, b.minY, b.maxY) for b in self.amp_bbox_list]

        # Create a raw galsim image to potentially draw the sources onto. The
        # exposure image that is passed to this method will be modified in
        # place but it won't be used.
        # NOTE(review): the sources below are drawn onto the copy
        # ``self.raw_image``, not onto ``signal_free_raw_image``, so the mock
        # exposure's pixels do not appear to be modified here — confirm and
        # update the comment above.
        self.signal_free_raw_image = galsim.ImageF(exposure.image.array, bounds=image_bounds)
        self.raw_image = self.signal_free_raw_image.copy()

        # Define parameters for a mix of source types, including extended
        # sources with assorted profiles as well as point sources simulated
        # with minimal half-light radii to resemble hot pixels
        # post-deconvolution. All flux values are given in electrons and
        # half-light radii in pixels. The goal is for each amplifier to
        # predominantly contain at least one source, enhancing the
        # representativeness of test conditions.
        source_params = [
            {"type": "Sersic", "n": 3, "flux": 1.6e5, "half_light_radius": 3.5, "g1": -0.3, "g2": 0.2},
            {"type": "Sersic", "n": 1, "flux": 9.3e5, "half_light_radius": 2.1, "g1": 0.25, "g2": 0.12},
            {"type": "Sersic", "n": 4, "flux": 1.0e5, "half_light_radius": 1.1, "g1": 0.0, "g2": 0.0},
            {"type": "Sersic", "n": 3, "flux": 1.1e6, "half_light_radius": 4.2, "g1": 0.0, "g2": 0.2},
            {"type": "Sersic", "n": 5, "flux": 1.1e5, "half_light_radius": 3.6, "g1": 0.22, "g2": -0.05},
            {"type": "Sersic", "n": 2, "flux": 4.3e5, "half_light_radius": 2.0, "g1": 0.0, "g2": 0.0},
            {"type": "Sersic", "n": 6, "flux": 1.2e6, "half_light_radius": 11.0, "g1": -0.16, "g2": 0.7},
            {"type": "Exponential", "flux": 1.3e6, "half_light_radius": 1.9, "g1": 0.3, "g2": -0.1},
            {"type": "Exponential", "flux": 1.8e6, "half_light_radius": 5.0, "g1": 0.0, "g2": 0.14},
            {"type": "Exponential", "flux": 6.6e6, "half_light_radius": 4.8, "g1": 0.26, "g2": 0.5},
            {"type": "Exponential", "flux": 7.0e5, "half_light_radius": 3.1, "g1": -0.3, "g2": 0.0},
            {"type": "DeVaucouleurs", "flux": 1.6e5, "half_light_radius": 3.5, "g1": 0.2, "g2": 0.4},
            {"type": "DeVaucouleurs", "flux": 2.0e5, "half_light_radius": 1.6, "g1": -0.06, "g2": -0.2},
            {"type": "DeVaucouleurs", "flux": 8.3e5, "half_light_radius": 5.1, "g1": 0.29, "g2": 0.0},
            {"type": "DeVaucouleurs", "flux": 4.5e5, "half_light_radius": 2.5, "g1": 0.4, "g2": 0.3},
            {"type": "DeVaucouleurs", "flux": 6.2e5, "half_light_radius": 4.9, "g1": -0.08, "g2": -0.01},
            {"type": "Gaussian", "flux": 4.7e6, "half_light_radius": 2.5, "g1": 0.07, "g2": -0.35},
            {"type": "Gaussian", "flux": 5.8e6, "half_light_radius": 3.1, "g1": 0.03, "g2": 0.4},
            {"type": "Gaussian", "flux": 2.3e5, "half_light_radius": 0.5, "g1": 0.0, "g2": 0.0},
            {"type": "Gaussian", "flux": 1.6e6, "half_light_radius": 3.0, "g1": 0.18, "g2": -0.29},
            {"type": "Gaussian", "flux": 3.5e5, "half_light_radius": 4.6, "g1": 0.5, "g2": 0.35},
            {"type": "Gaussian", "flux": 5.9e5, "half_light_radius": 9.5, "g1": 0.1, "g2": 0.55},
            {"type": "Gaussian", "flux": 4.0e5, "half_light_radius": 1.0, "g1": 0.0, "g2": 0.0},
        ]

        # Mapping of profile types to their galsim constructors.
        profile_constructors = {
            "Sersic": galsim.Sersic,
            "Exponential": galsim.Exponential,
            "DeVaucouleurs": galsim.DeVaucouleurs,
            "Gaussian": galsim.Gaussian,
        }

        # Generate random positions within exposure bounds, avoiding edges by a
        # margin of 5% of the exposure size on each axis.
        margin_x, margin_y = 0.05 * exp_bbox.width, 0.05 * exp_bbox.height
        self.positions = self.np_rng.uniform(
            [exp_bbox.minX + margin_x, exp_bbox.minY + margin_y],
            [exp_bbox.maxX - margin_x, exp_bbox.maxY - margin_y],
            (len(source_params), 2),
        ).tolist()

        # Loop over the sources and draw them onto the image cutout by cutout.
        # The ``pop`` calls below mutate the dicts in ``source_params``; that
        # is harmless because the list is rebuilt on every ``setUp`` call.
        for i, params in enumerate(source_params):
            # Dynamically get constructor and remove type from params.
            constructor = profile_constructors[params.pop("type")]

            # Get shear parameters and remove them from params.
            g1, g2 = params.pop("g1"), params.pop("g2")

            # The extent of the cutout should be large enough to contain the
            # entire object above the background level. Some empirical factor
            # is used to mitigate artifacts.
            half_extent = 10 * params["half_light_radius"] * (1 + 2 * np.sqrt(g1**2 + g2**2))

            # Pass the remaining params to the constructor and apply shear.
            galsim_object = constructor(**params).shear(galsim.Shear(g1=g1, g2=g2))

            # Retrieve the position of the object.
            x, y = self.positions[i]
            pos = galsim.PositionD(x, y)

            # Get the bounds of the sub-image based on the object position.
            sub_image_bounds = galsim.BoundsI(
                *map(int, [x - half_extent, x + half_extent, y - half_extent, y + half_extent])
            )

            # Identify the overlap region, which could be partially outside
            # the image bounds.
            sub_image_bounds = sub_image_bounds & self.raw_image.bounds

            # Check that there is some overlap.
            assert sub_image_bounds.isDefined(), "No overlap with image bounds"

            # Get the sub-image cutout.
            sub_image = self.raw_image[sub_image_bounds]

            # Draw the object onto the image within the sub-image bounds,
            # centering it on ``pos`` via the offset from the cutout center.
            galsim_object.drawImage(
                image=sub_image,
                offset=pos - sub_image.true_center,
                method="real_space",  # Saves memory, usable w/o convolution
                add_to_image=True,  # Add flux to existing image
                scale=self.pixel_scale,
            )

456 

457 def tearDown(self): 

458 del self.mock 

459 del self.raw_image 

460 del self.signal_free_raw_image 

461 

    def buildExposure(
        self,
        average_gain,
        gain_sigma_factor,
        sky_level,
        add_signal=True,
    ):
        """Build and return an exposure with different types of simulated
        source profiles and a background sky level. It's intended for testing
        and analysis, providing a way to generate exposures with controlled
        conditions.

        Parameters
        ----------
        average_gain : `float`
            The average gain value of amplifiers in e-/ADU.
        gain_sigma_factor : `float`
            The standard deviation of the gain values as a factor of the
            ``average_gain``.
        sky_level : `float`
            The background sky level in e-/arcsec^2.
        add_signal : `bool`, optional
            Whether to add sources to the exposure. If set to False, the
            exposure will only contain background noise.

        Returns
        -------
        exposure : `~lsst.afw.image.Exposure`
            An exposure object with simulated sources and background. The units
            are in detector counts (ADU). Its variance plane is in ADU^2 and
            still includes the Poisson contribution of the sources.

        Notes
        -----
        Side effects: sets ``self.background`` (sky level in e-/pixel) and
        ``self.amp_gain_list`` (per-amplifier gains, in the order of
        ``self.amp_bbox_list``). Both are read later by the tests. Draws
        from ``self.np_rng`` (gains) and ``self.rng`` (Poisson noise).
        """

        # Get the exposure from the mock.
        exposure = self.mock.getExposure()

        # Convert the background sky level from e-/arcsec^2 to e-/pixel.
        self.background = sky_level * self.pixel_scale**2

        # Generate random deviations from the average gain across amplifiers
        # and adjust them to ensure their sum equals zero. This reflects
        # real-world detectors, with amplifier gains normally distributed due
        # to manufacturing and operational variations. The normal's location
        # (``average_gain``) cancels out after subtracting the sample mean, so
        # only the spread ``gain_sigma_factor * average_gain`` matters.
        deviations = self.np_rng.normal(average_gain, gain_sigma_factor * average_gain, size=self.num_amps)
        deviations -= np.mean(deviations)

        # Set the gain for amplifiers to be slightly different from each other
        # while averaging to `average_gain`. This is to test the
        # `average_across_amps` option in the `remove_signal_from_variance`
        # function.
        self.amp_gain_list = [average_gain + deviation for deviation in deviations]

        # Add a constant background to the entire image (both in e-/pixel).
        # The addition creates a new galsim image, so the cached raw images
        # from ``setUp`` are not modified.
        if add_signal:
            image = self.raw_image + self.background
        else:
            image = self.signal_free_raw_image + self.background

        # Add noise to the image which is in electrons. Note that we won't
        # specify a `sky_level` here to avoid double-counting it, as it's
        # already included in the image as the background.
        image.addNoise(galsim.PoissonNoise(self.rng))

        # Subtract off the background to get the sky-subtracted image.
        image -= self.background

        # Adjust each amplifier's image segment by its respective gain. After
        # this step, the image will be in ADUs.
        for bounds, gain in zip(self.amp_bounds_list, self.amp_gain_list):
            image[bounds] /= gain

        # We know that the exposure has already been modified in place, but
        # just to be extra sure, we'll set the exposure image explicitly.
        # NOTE(review): ``image`` is a new galsim image produced by the
        # addition above, so it does not look like it is backed by this
        # exposure's pixels. This assignment appears to be what actually
        # writes the data — confirm and update the comment.
        exposure.image.array = image.array

        # Create a variance plane for the exposure while including signal as a
        # pollutant. Note that the exposure image is pre-adjusted for gain,
        # unlike 'self.background'. Thus, we divide the background by the
        # corresponding gain before adding it to the image. This leads to the
        # variance plane being in units of ADU^2: per pixel it equals
        # (sky-subtracted electrons + background electrons) / gain**2.
        for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list):
            exposure.variance[bbox].array = (
                (exposure.image[bbox].array + self.background / gain) / gain
            ).astype(np.float32)

        return exposure

547 

548 def test_no_signal_handling(self): 

549 """Test that the function does nearly nothing when given an image with 

550 no signal. 

551 """ 

552 # Create an exposure with no signal. 

553 exposure = self.buildExposure( 

554 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=False 

555 ) 

556 # Remove the signal from the variance plane, if any. 

557 updated_variance = remove_signal_from_variance(exposure, in_place=False) 

558 # Check that the variance plane is nearly the same as the original. 

559 self.assertFloatsAlmostEqual(exposure.variance.array, updated_variance.array, rtol=0.013) 

560 

561 def test_in_place_handling(self): 

562 """Make sure the function is tested to handle in-place operations.""" 

563 # Create an exposure with signal. 

564 exposure = self.buildExposure( 

565 average_gain=1.4, gain_sigma_factor=0.01, sky_level=4e6, add_signal=True 

566 ) 

567 # Remove the signal from the variance plane. 

568 updated_variance = remove_signal_from_variance(exposure, in_place=True) 

569 # Retrieve the variance plane from the exposure and check that it is 

570 # identical to the returned variance plane. 

571 self.assertFloatsEqual(exposure.variance.array, updated_variance.array) 

572 

573 @methodParametersProduct( 

574 average_gain=[1.4, 1.7], 

575 predefined_gain_type=["average", "per-amp", None], 

576 gain_sigma_factor=[0, 0.008], 

577 sky_level=[2e6, 4e6], 

578 average_across_amps=[False, True], 

579 ) 

580 def test_variance_signal_removal( 

581 self, average_gain, predefined_gain_type, gain_sigma_factor, sky_level, average_across_amps 

582 ): 

583 exposure = self.buildExposure( 

584 average_gain=average_gain, 

585 gain_sigma_factor=gain_sigma_factor, 

586 sky_level=sky_level, 

587 add_signal=True, 

588 ) 

589 

590 # Save the original variance plane for comparison, assuming it has 

591 # Poisson contribution from the source signal. 

592 signal_polluted_variance = exposure.variance.clone() 

593 

594 # Check that the variance plane has no negative values. 

595 self.assertTrue( 

596 np.all(signal_polluted_variance.array >= 0), 

597 "Variance plane has negative values (pre correction)", 

598 ) 

599 

600 if predefined_gain_type == "average": 

601 predefined_gain = average_gain 

602 predefined_gains = None 

603 elif predefined_gain_type == "per-amp": 

604 predefined_gain = None 

605 predefined_gains = {name: gain for name, gain in zip(self.amp_name_list, self.amp_gain_list)} 

606 elif predefined_gain_type is None: 

607 # Allow the 'remove_signal_from_variance' function to estimate the 

608 # gain itself before it attempts to remove the signal from the 

609 # variance plane. 

610 predefined_gain = None 

611 predefined_gains = None 

612 

613 # Set the relative tolerance for the variance plane checks. 

614 if predefined_gain_type == "average" or (predefined_gain_type is None and average_across_amps): 

615 # Relax the tolerance if we are simply averaging across amps to 

616 # roughly estimate the overall gain. 

617 rtol = 0.018 

618 estimate_average_gain = True 

619 else: 

620 # Tighten tolerance for the 'predefined_gain_type' of 'per-amp' or 

621 # for a more accurate per-amp gain estimation strategy. 

622 rtol = 3e-7 

623 estimate_average_gain = False 

624 

625 # Remove the signal from the variance plane. 

626 signal_free_variance = remove_signal_from_variance( 

627 exposure, 

628 gain=predefined_gain, 

629 gains=predefined_gains, 

630 average_across_amps=average_across_amps, 

631 in_place=False, 

632 ) 

633 

634 # Check that the variance plane has been modified. 

635 self.assertFloatsNotEqual(signal_polluted_variance.array, signal_free_variance.array) 

636 

637 # Check that the corrected variance plane has no negative values. 

638 self.assertTrue( 

639 np.all(signal_free_variance.array >= 0), "Variance plane has negative values (post correction)" 

640 ) 

641 

642 for bbox, gain in zip(self.amp_bbox_list, self.amp_gain_list): 

643 # Calculate the true variance in theoretical terms. 

644 true_var_amp = self.background / gain**2 

645 # Pair each variance with the appropriate context manager before 

646 # looping through them. 

647 var_context_pairs = [ 

648 # For the signal-free variance, directly execute the checks. 

649 (signal_free_variance, nullcontext()), 

650 # For the signal-polluted variance, expect AssertionError 

651 # unless we are averaging across amps. 

652 ( 

653 signal_polluted_variance, 

654 nullcontext() if estimate_average_gain else self.assertRaises(AssertionError), 

655 ), 

656 ] 

657 for var, context_manager in var_context_pairs: 

658 # Extract the segment of the variance plane for the amplifier. 

659 var_amp = var[bbox] 

660 with context_manager: 

661 if var is signal_polluted_variance and estimate_average_gain: 

662 # Skip rigorous checks on the signal-polluted variance, 

663 # if we are averaging across amps. 

664 pass 

665 else: 

666 # Get the variance value at the first pixel of the 

667 # segment to compare with the rest of the pixels and 

668 # the true variance. 

669 v00 = var_amp.array[0, 0] 

670 # Assert that the variance plane is almost uniform 

671 # across the segment because the signal has been 

672 # removed from it and the background is constant. 

673 self.assertFloatsAlmostEqual(var_amp.array, v00, rtol=rtol) 

674 # Assert that the variance plane is almost equal to the 

675 # true variance across the segment. 

676 self.assertFloatsAlmostEqual(v00, true_var_amp, rtol=rtol) 

677 

678 if ( 

679 SAVE_PLOT 

680 and not average_across_amps 

681 and gain_sigma_factor in (0, 0.008) 

682 and sky_level == 4e6 

683 and average_gain == 1.7 

684 and predefined_gain_type is None 

685 ): 

686 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 8.5)) 

687 plt.subplots_adjust(wspace=0.17, hspace=0.17) 

688 colorbar_aspect = 12 

689 

690 amp_background_variance_ADU_list = [self.background / gain**2 for gain in self.amp_gain_list] 

691 amp_background_image_ADU_list = [self.background / gain for gain in self.amp_gain_list] 

692 # Calculate the mean value that corresponds to the background for 

693 # the variance plane, adjusting for the gain. 

694 background_mean_variance_ADU = np.mean( 

695 [self.background / gain**2 for gain in self.amp_gain_list] 

696 ) 

697 

698 # Extract the variance planes and the image from the exposure. 

699 arr1 = signal_polluted_variance.array # Variance with signal 

700 arr2 = signal_free_variance.array # Variance without signal 

701 exp_im = exposure.image.clone() # Clone of the image plane 

702 

703 # Incorporate the gain-adjusted background into the image plane to 

704 # enable combined visualization of sources with the background. 

705 for gain, bbox in zip(self.amp_gain_list, self.amp_bbox_list): 

706 exp_im[bbox].array += self.background / gain 

707 arr3 = exp_im.array 

708 

709 # Define colors visually distinct from each other for the subplots. 

710 original_variance_color = "#8A2BE2" # Periwinkle 

711 corrected_variance_color = "#618B3C" # Lush Forest Green 

712 sky_variance_color = "#c3423f" # Crimson Red 

713 amp_colors = [ 

714 "#1f77b4", # Muted Blue 

715 "#ff7f0e", # Vivid Orange 

716 "#2ca02c", # Kelly Green 

717 "#d62728", # Brick Red 

718 "#9467bd", # Soft Purple 

719 "#8B4513", # Saddle Brown 

720 "#e377c2", # Pale Violet Red 

721 "#202020", # Onyx 

722 ] 

723 arrowheads_lr = ["$\u25C0$", "$\u25B6$"] # Left- & right-pointing 

724 arrowheads_ud = ["$\u25B2$", "$\u25BC$"] # Up- & down-pointing 

725 

726 # Set titles for the subplots. 

727 ax1.set_title("Original variance plane", color=original_variance_color) 

728 ax2.set_title("Corrected variance plane", color=corrected_variance_color) 

729 ax3.set_title("Image + background ($\\mathit{uniform}$)") 

730 ax4.set_title("Histogram of variances") 

731 

732 # Collect all vertical and horizontal line positions to find the 

733 # amp boundaries. 

734 vlines, hlines = set(), set() 

735 for bbox in self.amp_bbox_list: 

736 # Adjst by 0.5 for merging of lines at the boundaries. 

737 vlines.update({bbox.minX - 0.5, bbox.maxX + 0.5}) 

738 hlines.update({bbox.minY - 0.5, bbox.maxY + 0.5}) 

739 

740 # Filter lines at the edges of the overall image bbox. 

741 image_bbox = exposure.getBBox() 

742 vlines = {x for x in vlines if image_bbox.minX < x < image_bbox.maxX} 

743 hlines = {y for y in hlines if image_bbox.minY < y < image_bbox.maxY} 

744 

745 # Plot image and variance planes. 

746 for plane, arr, ax in zip( 

747 ("variance", "variance_corrected", "image"), (arr1, arr2, arr3), (ax1, ax2, ax3) 

748 ): 

749 # We skip 'variance_corrected' in the loop below because we use 

750 # the same normalization and colormap as 'variance' for it. 

751 if plane in ["variance", "image"]: 

752 # Get the normalization. 

753 vmin, vmax = arr.min(), arr.max() 

754 norm = mcolors.Normalize(vmin=vmin, vmax=vmax) 

755 

756 # Get the thresholds corresponding to per-amp backgrounds 

757 # and their positions in the normalized color scale. 

758 thresholds = ( 

759 amp_background_variance_ADU_list 

760 if plane.startswith("variance") 

761 else amp_background_image_ADU_list 

762 ) 

763 threshold_positions = [norm(t) for t in thresholds] 

764 threshold = np.mean(thresholds) 

765 threshold_position = np.mean(threshold_positions) 

766 

767 # Create a custom colormap with two distinct colors for the 

768 # sky and source contributions. 

769 border = (threshold - vmin) / (vmax - vmin) 

770 colors1 = plt.cm.Purples_r(np.linspace(0, 1, int(border * 256))) 

771 colors2 = plt.cm.Greens(np.linspace(0, 1, int((1 - border) * 256))) 

772 colors = np.vstack((colors1, colors2)) 

773 cmap = mcolors.LinearSegmentedColormap.from_list("cmap", colors) 

774 

775 # Plot the array with the custom colormap and normalization. 

776 im = ax.imshow(arr, cmap=cmap, norm=norm) 

777 

778 # Add colorbars to the plot. 

779 cbar = fig.colorbar(im, aspect=colorbar_aspect, pad=0) 

780 

781 # Change the number of ticks on the colorbar for better 

782 # spacing. Needs to be done before modifying the tick labels. 

783 cbar.ax.locator_params(nbins=7) 

784 

785 # Enhance readability by scaling down colorbar tick labels. 

786 unit = "ADU$^2$" if plane.startswith("variance") else "ADU" 

787 adjust_tick_scale(cbar.ax, {"y": f"Value [{{scale}} {unit}]"}) 

788 

789 # Mark min and max per-amp thresholds with dotted lines on the 

790 # colorbar. 

791 for tp in [min(thresholds), max(thresholds)]: 

792 cbar.ax.axhline(tp, color="white", linestyle="-", linewidth=1.5, alpha=0.4) 

793 cbar.ax.axhline(tp, color=sky_variance_color, linestyle=":", linewidth=1.5, alpha=0.9) 

794 

795 # Mark mean threshold with facing arrowheads on the colorbar. 

796 cbar.ax.annotate( 

797 arrowheads_lr[1], # Right-pointing arrowhead 

798 xy=(0, threshold_position), 

799 xycoords="axes fraction", 

800 textcoords="offset points", 

801 xytext=(0, 0), 

802 ha="left", 

803 va="center", 

804 fontsize=6, 

805 color=sky_variance_color, 

806 clip_on=False, 

807 alpha=0.9, 

808 ) 

809 cbar.ax.annotate( 

810 arrowheads_lr[0], # Left-pointing arrowhead 

811 xy=(1, threshold_position), 

812 xycoords="axes fraction", 

813 textcoords="offset points", 

814 xytext=(0, 0), 

815 ha="right", 

816 va="center", 

817 fontsize=6, 

818 color=sky_variance_color, 

819 clip_on=False, 

820 alpha=0.9, 

821 ) 

822 

823 # Add text inside the colorbar to label the average threshold 

824 # position. 

825 sky_level_text = "$\u27E8$" + "Sky" + "$\u27E9$" # <Sky> 

826 sky_level_text_artist = cbar.ax.text( 

827 0.5, 

828 threshold_position, 

829 sky_level_text, 

830 va="center", 

831 ha="center", 

832 transform=cbar.ax.transAxes, 

833 fontsize=8, 

834 color=sky_variance_color, 

835 rotation="vertical", 

836 alpha=0.9, 

837 path_effects=outline_effect(2), 

838 ) 

839 

840 # Setup renderer and transformation. 

841 renderer = fig.canvas.get_renderer() 

842 transform = cbar.ax.transAxes.inverted() 

843 

844 # Transform the bounding box and calculate adjustment for 

845 # 'sky_level_text_artist' for when it goes beyond the colorbar. 

846 sky_level_text_bbox = sky_level_text_artist.get_window_extent(renderer).transformed(transform) 

847 adjustment = 1.4 * sky_level_text_bbox.height / 2 

848 

849 if sky_level_text_bbox.ymin < 0: 

850 sky_level_text_artist.set_y(adjustment) 

851 elif sky_level_text_bbox.ymax > 1: 

852 sky_level_text_artist.set_y(1 - adjustment) 

853 

854 # Draw amp boundaries as vertical and/or horizontal lines. 

855 line_color = "white" if np.mean(norm(arr)) > 0.5 else "#808080" 

856 for x in vlines: 

857 ax.axvline(x=x, color=line_color, linestyle="--", linewidth=1, alpha=0.7) 

858 for y in hlines: 

859 ax.axhline(y=y, color=line_color, linestyle="--", linewidth=1, alpha=0.7) 

860 # Hide all x and y tick marks. 

861 ax.tick_params(axis="both", which="both", bottom=False, top=False, left=False, right=False) 

862 # Hide all x and y tick labels. 

863 ax.set_xticklabels([]) 

864 ax.set_yticklabels([]) 

865 

866 # Additional ax2 annotations: 

867 # Labels amplifiers with their respective gains for a visual check. 

868 for bbox, name, gain, color in zip( 

869 self.amp_bbox_list, self.amp_name_list_simplified, self.amp_gain_list, amp_colors 

870 ): 

871 # Get the center of the bbox to label the gain value. 

872 bbox_center = (bbox.minX + bbox.maxX) / 2, (bbox.minY + bbox.maxY) / 2 

873 # Label the gain value at the center of each amplifier segment. 

874 ax2.text( 

875 *bbox_center, 

876 f"gain$_{{\\rm \\, {name} \\,}}$: {gain:.3f}", 

877 fontsize=9, 

878 color=color, 

879 alpha=0.95, 

880 ha="center", 

881 va="center", 

882 path_effects=outline_effect(2), 

883 ) 

884 

885 # Additional ax3 annotations: 

886 # Label sources with numbers on the image plane. 

887 for i, pos in enumerate(self.positions, start=1): 

888 ax3.text( 

889 *pos, 

890 f"{i}", 

891 fontsize=7, 

892 color=sky_variance_color, 

893 path_effects=outline_effect(1.5), 

894 alpha=0.9, 

895 ) 

896 

897 # Now we use ax4 to plot the histograms of the variance planes for 

898 # comparison. 

899 # Plot the histogram of the original variance plane. 

900 hist_values, bins, _ = ax4.hist( 

901 arr1.flatten(), 

902 bins=80, 

903 histtype="step", 

904 color=original_variance_color, 

905 alpha=0.9, 

906 label="Original variance", 

907 ) 

908 # Fill the area under the step. 

909 ax4.fill_between( 

910 bins[:-1], 

911 hist_values, 

912 step="post", 

913 color=original_variance_color, 

914 alpha=0.09, 

915 hatch="/////", 

916 label=" ", 

917 ) 

918 # Plot the histogram of the corrected variance plane. 

919 ax4.hist( 

920 arr2.flatten(), 

921 bins=80, 

922 histtype="bar", 

923 color=corrected_variance_color, 

924 alpha=0.9, 

925 label="Corrected variance", 

926 ) 

927 adjust_tick_scale(ax4, {"x": "Variance [{scale} ADU$^2$]", "y": "Number of pixels / {scale}"}) 

928 ax4.yaxis.set_label_position("right") 

929 ax4.yaxis.tick_right() 

930 ax4.axvline( 

931 background_mean_variance_ADU, 

932 color=sky_variance_color, 

933 linestyle="--", 

934 linewidth=1, 

935 alpha=0.9, 

936 label="Average sky variance\nacross all amps", 

937 ) 

938 

939 # Use colored arrowheads to mark true amp variances. 

940 sorted_vars = sorted(amp_background_variance_ADU_list) 

941 count = {v: 0 for v in sorted_vars} 

942 for i, (x, name, gain, color) in enumerate( 

943 zip( 

944 amp_background_variance_ADU_list, 

945 self.amp_name_list_simplified, 

946 self.amp_gain_list, 

947 amp_colors, 

948 ) 

949 ): 

950 arrowhead = arrowheads_ud[int(gain < average_gain)] 

951 arrowhead_text = ax4.annotate( 

952 arrowhead, 

953 xy=(x, 0), 

954 xycoords=("data", "axes fraction"), 

955 textcoords="offset points", 

956 xytext=(0, 0), 

957 ha="center", 

958 va="bottom", 

959 fontsize=6.5, 

960 color=color, 

961 clip_on=False, 

962 alpha=0.85, 

963 path_effects=outline_effect(1.5), 

964 ) 

965 if i == 0: 

966 # Draw the canvas once to make sure the renderer is active. 

967 fig.canvas.draw() 

968 # Get the bounding box of the text annotation in axes 

969 # fraction. 

970 bbox_axes = arrowhead_text.get_window_extent().transformed(ax4.transAxes.inverted()) 

971 # Get the height of the text annotation in axes fraction. 

972 height = bbox_axes.height 

973 # Increment the arrowhead y positions to avoid overlap. 

974 var_idxs = np.where(sorted_vars == x)[0] 

975 if len(var_idxs) > 1: 

976 q = count[x] 

977 count[x] += 1 

978 else: 

979 q = 0 

980 arrowhead_text.xy = (x, var_idxs[q] * height) 

981 

982 # Create a proxy artist for the legend since annotations are 

983 # not shown in the legend. 

984 label = "True variance of" if i == 0 else "$\u21AA$" 

985 ax4.scatter( 

986 [], 

987 [], 

988 color=color, 

989 marker=arrowhead, 

990 s=15, 

991 label=f"{label} {name}", 

992 alpha=0.85, 

993 path_effects=outline_effect(1.5), 

994 ) 

995 

996 # Group the legend handles and label them. 

997 adjust_legend_with_groups( 

998 ax4, 

999 [(0, 1), 2, 3, *range(4, 4 + len(self.amp_name_list_simplified))], 

1000 colors="match", 

1001 handlelength=1.9, 

1002 ) 

1003 

1004 # Align the histogram (bottom right panel) with the colorbar of the 

1005 # corrected variance plane (top right panel) for aesthetic reasons. 

1006 pos2 = ax2.get_position() 

1007 pos4 = ax4.get_position() 

1008 fig.canvas.draw() # Render to ensure accurate colorbar width 

1009 cbar_width = cbar.ax.get_position().width 

1010 ax4.set_position([pos2.x0, pos4.y0, pos2.width + cbar_width, pos4.height]) 

1011 

1012 # Increase all axes spines' linewidth by 20% for a bolder look. 

1013 for ax in fig.get_axes(): 

1014 for spine in ax.spines.values(): 

1015 spine.set_linewidth(spine.get_linewidth() * 1.2) 

1016 

1017 # Save the figure. 

1018 filename = f"variance_plane_gain{average_gain}_sigma{gain_sigma_factor}_sky{sky_level}.png" 

1019 fig.savefig(filename, dpi=300) 

1020 print(f"Saved plot of variance plane before and after correction in {filename}") 

1021 

1022 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Memory test case for this module.

    All behavior is inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`; this subclass only exposes it to the
    test runner.
    """

1025 

1026 

def setup_module(module):
    """Module-level setup hook that initializes the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The module being set up. It is not used by this hook.
    """
    lsst.utils.tests.init()

1029 

1030 

if __name__ == "__main__":
    # When this file is executed directly as a script, initialize the LSST
    # test utilities first, then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()