Coverage for python / lsst / meas / algorithms / computeRoughPsfShapelets.py: 0%

260 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 08:25 +0000

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ComputeRoughPsfShapeletsTask", "ComputeRoughPsfShapeletsConfig"] 

25 

26import itertools 

27from typing import TYPE_CHECKING, Any, Literal 

28 

29import numpy as np 

30import scipy.signal 

31import scipy.stats 

32from sklearn.covariance import MinCovDet 

33from sklearn.neighbors import KernelDensity 

34 

35from lsst.afw.geom import ellipses 

36from lsst.afw.image import ImageD, ImageF, MaskedImageF 

37from lsst.afw.table import Point2DKey, QuadrupoleKey, Schema, SourceCatalog 

38from lsst.geom import Box2D 

39from lsst.pex.config import Config, Field, ListField 

40from lsst.pipe.base import AlgorithmError, Struct, Task 

41from lsst.shapelet import LAGUERRE, ShapeletFunction, computeOffset 

42 

43from ._algorithmsLib import SpanSetMoments 

44 

45if TYPE_CHECKING: 

46 import matplotlib.axes 

47 import matplotlib.figure 

48 import matplotlib.image 

49 import matplotlib.patches 

50 

51 

class NoStarsForShapeletsError(AlgorithmError):
    """Exception raised when ComputeRoughPsfShapeletsTask fails to find any
    usable stars.
    """

    @property
    def metadata(self) -> dict[str, Any]:
        # No extra diagnostic metadata is attached to this error.
        return {}

60 

61 

class ComputeRoughPsfShapeletsConfig(Config):
    """Configuration for `ComputeRoughPsfShapeletsTask`."""

    bad_mask_planes = ListField(
        "Mask planes to identify pixels to drop from the calculation.",
        dtype=str,
        default=["SAT", "SUSPECT", "INTRP"],
    )
    bad_pixel_max_fraction = Field(
        "Maximum fraction of a footprint's pixels that can be bad (according "
        "to bad_mask_planes) before that footprint is fully dropped.",
        dtype=float,
        default=0.25,
    )
    bad_pixel_exclusion_radius = Field(
        "If a bad pixel (according to bad_mask_planes) falls within this "
        "radius of the unweighted centroid of a footprint, drop that footprint "
        "entirely.",
        dtype=float,
        default=2.0,
    )
    max_footprint_area = Field(
        "Footprints with a pixel count larger than this threshold are dropped before computing moments.",
        dtype=int,
        default=10000,
    )
    min_snr = Field(
        "Minimum flux S/N for inclusion in the star sample.",
        dtype=float,
        default=50.0,
    )
    max_radius_factor = Field(
        "Maximum multiple of the mode of the radius distribution for inclusion in the star sample.",
        dtype=float,
        default=2.0,
    )
    max_shape_distance = Field(
        "Maximum Mahalanobis distance (distance from the center of the shape "
        "distribution in elliptical sigma units) to select a star from the candidate "
        "sample, comparing the shape of that source vs. a robust estimate of the "
        "distribution of the shapes of all sources.",
        dtype=float,
        default=3.0,
    )
    min_n_stars = Field(
        "Minimum number of stars to select. The S/N, radius, and shape distance thresholds are "
        "relaxed as needed to meet this target.",
        dtype=int,
        default=10,
    )
    max_n_stars = Field(
        "Maximum number of stars to select. High shape-distance sources "
        "are dropped as needed to meet this target.",
        dtype=int,
        default=20,
    )
    logarithmic_shapes = Field(
        "If True, transform the (xx, yy, xy) moments to conformal shear "
        "and log trace radius when selecting stars in order to map the ellipse "
        "parameters to a space with (-inf, inf) bounds on all quantities (but "
        "a less-linear relationship to the pixel data).",
        dtype=bool,
        default=False,
    )
    radius_mode_min = Field(
        "Minimum radius in pixels at which to start searching for the first mode. "
        "This just needs to be large enough to avoid unphysically small garbage (e.g. CRs).",
        dtype=float,
        default=1.0,
    )
    radius_kde_bandwidth = Field(
        "Bandwidth of the Gaussian kernel density estimator used to find the "
        "first mode of the radius distribution.",
        dtype=float,
        default=1.0,
    )
    shapelet_order = Field(
        "Order of the shapelet expansion fit to the stars.",
        dtype=int,
        default=4,
    )
    shapelet_scale_factor = Field(
        "Scale factor to apply to the moments ellipse when computing the ellipse for the shapelet basis.",
        dtype=float,
        default=1.0,
    )

    def validate(self) -> None:
        """Validate the configuration.

        Raises
        ------
        ValueError
            Raised if ``min_n_stars > max_n_stars``, if ``shapelet_order`` is
            negative, or if ``shapelet_scale_factor`` or
            ``radius_kde_bandwidth`` is not positive. The base-class
            validation called first may raise its own errors.
        """
        # Run the base-class per-field checks before the cross-field ones.
        super().validate()
        if self.min_n_stars > self.max_n_stars:
            raise ValueError(
                f"min_n_stars={self.min_n_stars} is greater than max_n_stars={self.max_n_stars}."
            )
        if self.shapelet_order < 0:
            raise ValueError(f"shapelet order {self.shapelet_order} must be nonnegative.")
        if self.shapelet_scale_factor <= 0.0:
            raise ValueError(f"shapelet scale factor {self.shapelet_scale_factor} must be positive.")
        if self.radius_kde_bandwidth <= 0.0:
            # The kernel density estimator cannot be built with a non-positive
            # bandwidth; fail at configuration time instead of mid-run.
            raise ValueError(f"radius KDE bandwidth {self.radius_kde_bandwidth} must be positive.")

156 

157 

class ComputeRoughPsfShapeletsTask(Task):
    """A task that computes a rough shapelet expansion of the PSF from a set
    of high S/N detections.

    Notes
    -----
    This task is expected to be run early in single-epoch processing - just
    after background subtraction and an initial high S/N detection phase, and
    before any deblending or measurement - in order to identify out-of-focus
    or otherwise bad PSFs.

    Given a background-subtracted `lsst.afw.image.MaskedImage`, an
    `lsst.afw.table.SourceCatalog` with footprints attached, and a random
    number generator seed, the `run` method will:

    - Compute the *unweighted* 0th-2nd moments of every non-child source over
      the footprint (except certain configurable masked pixels). This is
      delegated to the `compute_raw_moments` method (which uses the C++
      `SpanSetMoments` class for the pixel-level processing). Unweighted
      moments are used to avoid "latching onto" a small piece of PSF
      substructure, but can be much noisier than the Gaussian-weighted moments
      we usually use.

    - Select a "candidate" sample of sources with successfully measured
      moments that satisfy a S/N cut and a radius cut (determined from the
      first mode of the radius distribution, via kernel density estimation),
      and then use a robust covariance estimator
      (`sklearn.covariance.MinCovDet`) to select presumed isolated stars that
      are close to the center of that distribution, in 3-parameter shape
      space. This is delegated to the `select_stars` method.

    - Fit a single shapelet expansion to the selected stars. This is
      mostly delegated to the `SpanSetMoments.fit_shapelets` method.

    The radial shapelet terms at 0th, 2nd, and 4th order are expected to form
    a space in which donut-shaped PSFs are well-separated from those with
    monotonic profiles. Other terms *may* be useful in identifying other
    kinds of undesirable PSF structure.
    """

    ConfigClass = ComputeRoughPsfShapeletsConfig
    _DefaultName = "computeRoughPsfShapelets"
    config: ComputeRoughPsfShapeletsConfig

    def __init__(
        self,
        config: ComputeRoughPsfShapeletsConfig | None = None,
        *,
        schema: Schema,
        **kwargs: Any,
    ):
        """Construct the task and add its output fields to ``schema``.

        Parameters
        ----------
        config
            Task configuration.
        schema
            Schema of the catalogs that will be passed to `run`; output fields
            are added to it in place.
        **kwargs
            Forwarded to `lsst.pipe.base.Task`.
        """
        super().__init__(config=config, **kwargs)
        self.schema = schema
        # NOTE(review): this is the only field with a "RawPsfMoments_" prefix;
        # every other field added here uses "RoughPsfShapelets_" (including the
        # matching "_fluxErr"). It looks like a leftover from a rename, but it
        # is kept as-is because downstream code may read the column by name.
        self._flux_key = schema.addField("RawPsfMoments_flux", type=float, doc="Unweighted zeroth moment.")
        self._flux_err_key = schema.addField(
            "RoughPsfShapelets_fluxErr",
            type=float,
            doc="Uncertainty on the unweighted zeroth moment.",
        )
        self._center_key = Point2DKey.addFields(
            schema, "RoughPsfShapelets", "Center from unweighted first moments.", "pixels"
        )
        self._shape_key = QuadrupoleKey.addFields(
            schema, "RoughPsfShapelets", "Shape from unweighted second moments."
        )
        self._flag_key = schema.addField(
            "RoughPsfShapelets_flag",
            type="Flag",
            doc="Flag set if the raw PSF moments were not computed.",
        )
        self._candidate_key = schema.addField(
            "RoughPsfShapelets_candidate",
            type="Flag",
            doc="Flag set if this source passed the radius_fraction cut (see configuration).",
        )
        self._used_key = schema.addField(
            "RoughPsfShapelets_used",
            type="Flag",
            doc=(
                "Flag set if this source passed the radius_fraction and shape_distance cuts "
                "(see configuration) and was used to fit the shapelet expansion."
            ),
        )

    def run(self, *, masked_image: MaskedImageF, catalog: SourceCatalog, seed: int) -> Struct:
        """Compute raw moments, select stars, and fit a shapelet expansion to
        them.

        Parameters
        ----------
        masked_image
            Masked image to measure on. Must be background-subtracted.
        catalog
            Catalog of detections to extract footprints from and fill output
            columns of. Its schema must be a superset of ``self.schema``.
        seed
            A random-number generator seed, used for the robust covariance
            estimator.

        Returns
        -------
        `lsst.pipe.base.Struct`
            A struct of results containing:

            - ``shapelet`` (`lsst.shapelet.ShapeletFunction`): A
              Gauss-Laguerre (polar shapelet) expansion of the PSF.
            - ``radial`` (`numpy.ndarray`): the purely radial coefficients
              of the shapelet expansion (orders 0, 2, 4, ...).
            - all attributes returned by the `select_stars` method.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if no usable moments or stars could be found.
        """
        moments = self.compute_raw_moments(masked_image=masked_image, catalog=catalog)
        result = self.select_stars(catalog, seed=seed)
        star_moments = [moments[star_id] for star_id in result.star_ids]
        result.shapelet = SpanSetMoments.fit_shapelets(
            masked_image,
            star_moments,
            self.config.shapelet_order,
            self.config.shapelet_scale_factor,
        )
        result.shapelet.getEllipse().setCore(result.mean_shape)
        result.shapelet.changeBasisType(LAGUERRE)
        # Pick the coefficient at the start of each even order's block in the
        # Laguerre basis; these are the terms reported as "radial".
        result.radial = result.shapelet.getCoefficients()[
            [computeOffset(i) for i in range(0, self.config.shapelet_order + 1, 2)]
        ]
        self.log.info("Rough PSF shapelet radial terms: %s.", result.radial)
        return result

    def compute_raw_moments(
        self, *, masked_image: MaskedImageF, catalog: SourceCatalog
    ) -> dict[int, SpanSetMoments]:
        """Compute the unweighted moments of the footprints in a catalog.

        Child sources (nonzero parent) and sources whose footprint area
        exceeds ``max_footprint_area`` are skipped and flagged.

        Parameters
        ----------
        masked_image
            Masked image to measure on. Must be background-subtracted.
        catalog
            Catalog of detections to extract footprints from and fill output
            columns of. Its schema must be a superset of ``self.schema``.

        Returns
        -------
        `dict` [`int`, `SpanSetMoments`]
            Objects used to construct and hold the unweighted moments and the
            pixel region used to compute them, keyed by source ID. Only
            sources whose moments have no flags set are included.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if no source had its moments measured successfully.
        """
        bitmask = masked_image.mask.getPlaneBitMask(self.config.bad_mask_planes)
        all_moments: dict[int, SpanSetMoments] = {}
        for record in catalog:
            if record.getParent() != 0:
                # Only parent (non-child) sources are measured.
                record.set(self._flag_key, True)
                self.log.debug("Skipping child source %s", record.getId())
                continue
            footprint_spans = record.getFootprint().getSpans()
            area = footprint_spans.getArea()
            if area > self.config.max_footprint_area:
                record.set(self._flag_key, True)
                self.log.debug(
                    "Skipping source %s with footprint area %d > %d.",
                    record.getId(),
                    area,
                    self.config.max_footprint_area,
                )
                continue
            moments = SpanSetMoments.compute(
                footprint_spans,
                masked_image=masked_image,
                bad_bitmask=bitmask,
                bad_pixel_max_fraction=self.config.bad_pixel_max_fraction,
                bad_pixel_exclusion_radius=self.config.bad_pixel_exclusion_radius,
            )
            record.set(self._flux_key, moments.flux)
            record.set(self._flux_err_key, moments.variance**0.5)
            record.set(self._center_key, moments.center)
            record.set(self._shape_key, moments.shape)
            record.set(self._flag_key, moments.any_flags_set)
            if not moments.any_flags_set:
                all_moments[record.getId()] = moments
        if not all_moments:
            raise NoStarsForShapeletsError("No raw moments could be measured.")
        self.log.verbose(
            "Successfully measured raw moments for %d of %d sources.",
            len(all_moments),
            len(catalog),
        )
        return all_moments

    def select_stars(self, catalog: SourceCatalog, seed: int) -> Struct:
        """Select probable stars from the distribution of second moments.

        Sets the ``RoughPsfShapelets_candidate`` flag on sources that pass
        the S/N and radius cuts, and ``RoughPsfShapelets_used`` on those that
        also pass the shape-distance cut.

        Parameters
        ----------
        catalog
            Catalog of detections to extract footprints from and fill output
            columns of. Its schema must be a superset of ``self.schema``.
        seed
            A random-number generator seed, used for the robust covariance
            estimator.

        Returns
        -------
        `lsst.pipe.base.Struct`
            A struct of results containing:

            - ``star_ids`` (`numpy.ndarray`): the source IDs that are expected
              to be stars.
            - ``mean_shape`` (`lsst.afw.geom.ellipses.BaseCore`): the mean of
              the shape distribution.
            - ``shape_covariance`` (`numpy.ndarray`): the covariance of the
              distribution of shapes; a 3x3 matrix. This uses the same
              parameterization of the shapes as ``mean_shape``.
            - ``radius_cut`` (`float`): the intended radius cut (i.e. the
              mode of the radius distribution multiplied by the
              ``max_radius_factor`` configuration option).
            - ``radius_kde`` (`sklearn.neighbors.KernelDensity`): kernel
              density estimator on the radius distribution, used to determine
              the radius cut.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if too few sources survive a cut, or if the radius
            distribution has no mode.
        """
        # Cut on flags and S/N first.
        indices = np.flatnonzero(np.logical_not(catalog[self._flag_key]))
        snr = catalog[self._flux_key][indices] / catalog[self._flux_err_key][indices]
        indices = indices[
            self._threshold_with_bounds(
                snr,
                threshold=self.config.min_snr,
                min_count=self.config.min_n_stars,
                max_count=len(catalog),
                name="S/N",
                kind=">",
            )
        ]
        # Cut on radius next, relative to the first mode of the radius
        # distribution.
        radii = np.array(
            [catalog[index].get(self._shape_key).getTraceRadius() for index in indices],
            dtype=np.float64,
        )
        radius_mode, radius_kde = self._find_first_radius_mode(radii)
        radius_cut = self.config.max_radius_factor * radius_mode
        indices = indices[
            self._threshold_with_bounds(
                radii,
                threshold=radius_cut,
                min_count=self.config.min_n_stars,
                max_count=len(catalog),
                name="radius",
                kind="<",
            )
        ]
        # Build the shape-space sample of candidates for the robust covariance.
        shape_data = np.zeros((len(indices), 3), dtype=np.float64)
        for n, index in enumerate(indices):
            record = catalog[index]
            record.set(self._candidate_key, True)
            shape = record.get(self._shape_key)
            if self.config.logarithmic_shapes:
                shape = ellipses.SeparableConformalShearLogTraceRadius(shape)
            shape_data[n, :] = shape.getParameterVector()
        shape_dist = MinCovDet(random_state=seed).fit(shape_data)
        # MinCovDet.mahalanobis returns *squared* Mahalanobis distances; take
        # the square root so max_shape_distance is in sigma units, as the
        # configuration documents.
        m_distances = np.sqrt(shape_dist.mahalanobis(shape_data))
        indices = indices[
            self._threshold_with_bounds(
                m_distances,
                threshold=self.config.max_shape_distance,
                min_count=self.config.min_n_stars,
                max_count=self.config.max_n_stars,
                name="Mahalanobis distance",
                kind="<",
            )
        ]
        for index in indices:
            catalog[index].set(self._used_key, True)
        star_ids = catalog["id"][indices]
        if self.config.logarithmic_shapes:
            mean_shape = ellipses.SeparableConformalShearLogTraceRadius(shape_dist.location_)
        else:
            mean_shape = ellipses.Quadrupole(shape_dist.location_)
        return Struct(
            star_ids=star_ids,
            mean_shape=mean_shape,
            shape_covariance=shape_dist.covariance_,
            radius_cut=radius_cut,
            radius_kde=radius_kde,
        )

    def plot_selection(
        self, figure: matplotlib.figure.Figure, *, catalog: SourceCatalog, results: Struct
    ) -> matplotlib.figure.Figure:
        """Create plots of the shape distribution space used to select stars.

        Parameters
        ----------
        figure
            Matplotlib figure to plot to.
        catalog
            Catalog of sources with columns populated by the `run` method (at
            least through the `select_stars` step).
        results
            Result struct returned by `run` or `select_stars`.

        Returns
        -------
        `matplotlib.figure.Figure`
            The same ``figure`` that was passed in.
        """
        from matplotlib.lines import Line2D

        unflagged = np.logical_not(catalog[self._flag_key])
        # Rows for flagged sources stay zero-filled; they are excluded from
        # every mask and from the bounds computed below.
        shape_data = np.zeros((len(catalog), 3), dtype=np.float64)
        radii = np.zeros(len(catalog), dtype=np.float64)
        for n, record in enumerate(catalog):
            if record[self._flag_key]:
                continue
            shape = record[self._shape_key]
            if self.config.logarithmic_shapes:
                shape = ellipses.SeparableConformalShearLogTraceRadius(shape)
            shape_data[n, :] = shape.getParameterVector()
            radii[n] = shape.getTraceRadius()
        used_mask = catalog[self._used_key]
        candidate_mask = np.logical_and(catalog[self._candidate_key], np.logical_not(used_mask))
        measured_mask = np.logical_and(unflagged, np.logical_not(catalog[self._candidate_key]))
        # Set up the axes.
        axes = figure.subplot_mosaic(
            [
                ["radius", "radius", "radius"],
                ["hist0", ".", "."],
                ["scatter01", "hist1", "."],
                ["scatter02", "scatter12", "hist2"],
            ],
            gridspec_kw=dict(bottom=0.2, top=1.0),
        )
        axes["scatter01"].sharex(axes["hist0"])
        axes["scatter02"].sharex(axes["hist0"])
        axes["scatter12"].sharex(axes["hist1"])
        axes["scatter12"].sharey(axes["scatter02"])
        for tk in itertools.chain(
            axes["hist0"].get_xticklabels(),
            axes["scatter01"].get_xticklabels(),
            axes["hist1"].get_xticklabels(),
            axes["scatter12"].get_yticklabels(),
        ):
            tk.set_visible(False)
        # Move hist y axes to the outside.
        for ax in [axes["hist0"], axes["hist1"], axes["hist2"]]:
            ax.yaxis.set_label_position("right")
            ax.yaxis.tick_right()
        # Add labels to outer axes.
        names = ["Ixx", "Iyy", "Ixy"] if not self.config.logarithmic_shapes else ["𝜂1", "𝜂2", "ln(r)"]
        axes["scatter02"].set_xlabel(names[0])
        axes["scatter12"].set_xlabel(names[1])
        axes["hist2"].set_xlabel(names[2])
        axes["scatter01"].set_ylabel(names[1])
        axes["scatter02"].set_ylabel(names[2])
        # Make the plots: bounds are mean +/- 3 sigma, clamped to the range of
        # the measured (unflagged) data.
        mu = results.mean_shape.getParameterVector()
        sigma = np.diagonal(results.shape_covariance) ** 0.5
        good_shapes = shape_data[unflagged]
        lower_bounds = [max(mu[i] - 3 * sigma[i], np.nanmin(good_shapes[:, i])) for i in range(3)]
        upper_bounds = [min(mu[i] + 3 * sigma[i], np.nanmax(good_shapes[:, i])) for i in range(3)]
        grids = [np.linspace(lower_bounds[i], upper_bounds[i], 50) for i in range(3)]
        # Radius histograms span [smallest measured radius, 15 px]; the upper
        # edge is never allowed below the lower one (which would raise).
        radius_min = np.nanmin(radii[unflagged])
        radius_range = (radius_min, max(15.0, radius_min))
        axes["radius"].axvline(results.radius_cut, color="k")
        for color, mask, alpha in [
            ("grey", measured_mask, 0.5),
            ("blue", candidate_mask, 0.75),
            ("green", used_mask, 1.0),
        ]:
            axes["radius"].hist(
                radii[mask],
                color=color,
                alpha=alpha,
                bins=16,
                histtype="step",
                range=radius_range,
                density=True,
            )
            for i in range(3):
                axes[f"hist{i}"].hist(
                    shape_data[mask, i],
                    bins=16,
                    range=(lower_bounds[i], upper_bounds[i]),
                    density=True,
                    color=color,
                    histtype="step",
                    alpha=alpha,
                )
                for j in range(3):
                    if (ax := axes.get(f"scatter{i}{j}")) is not None:
                        ax.scatter(
                            shape_data[mask, i],
                            shape_data[mask, j],
                            c=color,
                            s=4,
                            alpha=alpha,
                            edgecolors=None,
                        )
        # Overlay the robust Gaussian model: 1-d marginals on the histograms
        # and 1/2/3-sigma contours of the 2-d marginals on the scatter plots.
        for i in range(3):
            axes[f"hist{i}"].plot(
                grids[i], scipy.stats.norm.pdf(grids[i], loc=mu[i], scale=sigma[i]), "k", alpha=0.5
            )
            axes[f"hist{i}"].set_xlim(lower_bounds[i], upper_bounds[i])
            for j in range(3):
                if (ax := axes.get(f"scatter{i}{j}")) is not None:
                    sigma_ellipse = ellipses.Quadrupole(
                        results.shape_covariance[i, i],
                        results.shape_covariance[j, j],
                        results.shape_covariance[i, j],
                    )
                    for factor in [1, 2, 3]:
                        self._draw_ellipse(
                            ax,
                            sigma_ellipse,
                            x=mu[i],
                            y=mu[j],
                            scale=factor,
                            fill=False,
                            edgecolor="k",
                            alpha=0.5,
                        )
                    ax.set_xlim(lower_bounds[i], upper_bounds[i])
                    ax.set_ylim(lower_bounds[j], upper_bounds[j])
        figure.legend(
            [
                Line2D([], [], color="green", alpha=1.0),
                Line2D([], [], color="blue", alpha=0.75),
                Line2D([], [], color="gray", alpha=0.5),
            ],
            [
                f"RoughPsfShapelets_used ({used_mask.sum()})",
                f"RoughPsfShapelets_candidate & ~RoughPsfShapelets_used ({candidate_mask.sum()})",
                f"~RoughPsfShapelets_flag & ~RoughPsfShapelets_candidate ({measured_mask.sum()})",
            ],
            loc="lower center",
        )
        return figure

    def plot_shapelets(
        self,
        figure: matplotlib.figure.Figure,
        *,
        image: ImageF,
        catalog: SourceCatalog,
        results: Struct,
        n_stars: int = 3,
        stamp_size: float = 2.0,
    ) -> matplotlib.figure.Figure:
        """Create data/model/residual plots of stars and the shapelet model.

        Parameters
        ----------
        figure
            Matplotlib figure to plot to. Its size is reset based on
            ``stamp_size`` and ``n_stars``.
        image
            The image the stars were measured on.
        catalog
            Catalog of sources with columns populated by the `run` method.
        results
            Result struct returned by `run`.
        n_stars
            Number of stars to include.
        stamp_size
            Stamp size in inches.

        Returns
        -------
        `matplotlib.figure.Figure`
            The same ``figure`` that was passed in.
        """
        from matplotlib.colors import Normalize

        # Three stamps per row plus room for the two colorbars.
        width = stamp_size * 3 + 1.5
        figure.set_size_inches(w=width, h=stamp_size * n_stars)
        axes = figure.subplot_mosaic(
            [
                ["image_cbar", f"d{star_id}", f"m{star_id}", f"r{star_id}", "res_cbar"]
                for star_id in results.star_ids[:n_stars]
            ],
            gridspec_kw=dict(
                wspace=0.01, hspace=0.01, left=0.5 / width, right=1.0 - 0.5 / width, bottom=0.01, top=0.99
            ),
            width_ratios=[0.25, stamp_size, stamp_size, stamp_size, 0.25],
        )
        for name, ax in axes.items():
            if not name.endswith("cbar"):
                ax.axis("off")
        # Color normalizations are taken from the first star and shared by all
        # rows so the stamps are directly comparable.
        norm: Normalize | None = None
        res_norm: Normalize | None = None
        for star_id in results.star_ids[:n_stars]:
            record = catalog.find(star_id)
            star_bbox = record.getFootprint().getBBox()
            # Evaluate the shared shapelet model at this star's own center and
            # moments ellipse.
            star_model = ImageD(star_bbox)
            star_center = record[self._center_key]
            star_ellipse = ellipses.Ellipse(ellipses.Axes(record[self._shape_key]), star_center)
            star_shapelet = ShapeletFunction(results.shapelet)
            star_shapelet.setEllipse(star_ellipse)
            star_shapelet.evaluate().addToImage(star_model)
            if norm is None:
                norm = Normalize(vmin=star_model.array.min(), vmax=star_model.array.max())
            # Normalize the data by the measured flux before comparing.
            star_image = image[star_bbox].clone()
            star_image /= record[self._flux_key]
            self._draw_image(axes[f"d{star_id}"], star_image, norm=norm, cmap="YlGnBu")
            self._draw_ellipse(axes[f"d{star_id}"], star_ellipse, fill=False, edgecolor="blue", alpha=0.5)
            image_plot = self._draw_image(axes[f"m{star_id}"], star_model, norm=norm, cmap="YlGnBu")
            self._draw_ellipse(axes[f"m{star_id}"], star_ellipse, fill=False, edgecolor="blue", alpha=0.5)
            star_image -= star_model.convertF()
            amax = np.abs(star_image.array).max()
            if res_norm is None:
                res_norm = Normalize(vmin=-amax, vmax=amax)
            res_plot = self._draw_image(axes[f"r{star_id}"], star_image, norm=res_norm, cmap="RdBu")
            self._draw_ellipse(axes[f"r{star_id}"], star_ellipse, fill=False, edgecolor="blue", alpha=0.5)
        figure.colorbar(image_plot, cax=axes["image_cbar"], location="left")
        figure.colorbar(res_plot, cax=axes["res_cbar"], location="right")
        return figure

    def _threshold_with_bounds(
        self,
        values: np.ndarray,
        threshold: float,
        min_count: int,
        max_count: int,
        name: str,
        kind: Literal["<", ">"],
    ) -> np.ndarray:
        """Return the indices of an array that satisfy an inequality
        and/or lower and upper bounds on the number of indices returned.

        Parameters
        ----------
        values
            Array of values to threshold on.
        threshold
            Threshold value that selected elements must be above or below.
            With ``kind="<"`` values strictly below the threshold pass; with
            ``kind=">"`` values greater than *or equal to* it pass.
        min_count
            The minimum number of indices returned. When thresholding would
            yield fewer than this number, the threshold is ignored. Note that
            the number of indices may still be less than this if the size of
            ``values`` is less than this.
        max_count
            The maximum number of indices returned.
        name
            Name of the quantity being thresholded, for log messages.
        kind
            Whether the threshold is a upper bound (``<``) or lower bound
            (``>``). This also sets how values are ranked when they are added
            or dropped to satisfy the count constraints.

        Returns
        -------
        indices
            Indices into ``values``, ordered best-first (ascending for ``<``,
            descending for ``>``).

        Raises
        ------
        NoStarsForShapeletsError
            Raised if ``values`` has fewer than ``min_count`` elements.
        """
        if min_count > len(values):
            raise NoStarsForShapeletsError(
                f"Not enough sources ({len(values)}) for {name} cut that must yield at least {min_count}."
            )
        sorter = values.argsort()
        # Number of values strictly below the threshold.
        n = np.searchsorted(values[sorter], threshold)
        if kind == ">":
            # Rank from largest to smallest and count values >= threshold.
            sorter = sorter[::-1]
            n = len(sorter) - n
        if n < min_count:
            self.log.verbose(
                "Applying a %s %s %f cut yields only %d sources; keeping the top %d (%s %s %f) instead.",
                name,
                kind,
                threshold,
                n,
                min_count,
                name,
                kind,
                values[sorter[min_count - 1]],
            )
            n = min_count
        elif n > max_count:
            self.log.verbose(
                "%d sources have %s %s %f; keeping only the top %d (%s %s %f) instead.",
                n,
                name,
                kind,
                threshold,
                max_count,
                name,
                kind,
                values[sorter[max_count - 1]],
            )
            n = max_count
        else:
            self.log.verbose("Keeping %d sources with %s %s %f.", n, name, kind, threshold)
        return sorter[:n]

    def _find_first_radius_mode(self, radii: np.ndarray) -> tuple[float, KernelDensity]:
        """Find the first peak in a 1-d distribution of radii.

        A Gaussian KDE is fit to all radii, then evaluated at the sample
        radii that are >= ``radius_mode_min``. The smallest-radius local
        maximum of that density is returned.

        Parameters
        ----------
        radii
            1-d array of radii.

        Returns
        -------
        mode : `float`
            The sample radius at the first local maximum of the density.
        kde : `sklearn.neighbors.KernelDensity`
            The fitted kernel density estimator.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if no radius is >= ``radius_mode_min`` or if the density
            has no interior local maximum.
        """
        kde = KernelDensity(bandwidth=self.config.radius_kde_bandwidth).fit(radii.reshape(-1, 1))
        sorted_radii = np.sort(radii)
        # Drop radii below the configured minimum (e.g. cosmic rays).
        sorted_radii = sorted_radii[sorted_radii.searchsorted(self.config.radius_mode_min):]
        if not sorted_radii.size:
            # score_samples would raise on an empty array; report it as a
            # selection failure instead.
            raise NoStarsForShapeletsError(
                f"No sources have a radius >= radius_mode_min={self.config.radius_mode_min}."
            )
        scores = kde.score_samples(sorted_radii.reshape(-1, 1))
        peaks, _ = scipy.signal.find_peaks(scores)
        if not peaks.size:
            raise NoStarsForShapeletsError("Radius distribution has no mode.")
        return sorted_radii[peaks.min()], kde

    @staticmethod
    def _draw_ellipse(
        axes: matplotlib.axes.Axes,
        ellipse: ellipses.BaseCore | ellipses.Ellipse,
        *,
        x: float | None = None,
        y: float | None = None,
        scale: float = 1.0,
        **kwargs: Any,
    ) -> matplotlib.patches.Ellipse:
        """Add a matplotlib ellipse patch for an afw ellipse to ``axes``.

        Parameters
        ----------
        axes
            Axes to draw on.
        ellipse
            Ellipse core or full ellipse (with center) to draw.
        x, y
            Center of the patch. Defaults to the ellipse's center for a full
            `~lsst.afw.geom.ellipses.Ellipse`, or 0 for a bare core.
        scale
            Multiplicative factor applied to the semi-axes.
        **kwargs
            Forwarded to `matplotlib.patches.Ellipse`.

        Returns
        -------
        `matplotlib.patches.Ellipse`
            The patch that was added.
        """
        from matplotlib.patches import Ellipse as EllipsePatch

        if isinstance(ellipse, ellipses.Ellipse):
            if x is None:
                x = ellipse.getCenter().getX()
            if y is None:
                y = ellipse.getCenter().getY()
            ellipse = ellipse.getCore()
        else:
            if x is None:
                x = 0.0
            if y is None:
                y = 0.0
            ellipse = ellipses.Axes(ellipse)
        patch = EllipsePatch(
            (x, y),
            ellipse.getA() * 2 * scale,  # factor of 2 for radius->diameter
            ellipse.getB() * 2 * scale,
            angle=ellipse.getTheta() * 180.0 / np.pi,
            **kwargs,
        )
        axes.add_patch(patch)
        return patch

    @staticmethod
    def _draw_image(
        axes: matplotlib.axes.Axes, image: ImageF | ImageD, **kwargs: Any
    ) -> matplotlib.image.AxesImage:
        """Show an afw image on ``axes`` in its own pixel coordinates.

        Parameters
        ----------
        axes
            Axes to draw on.
        image
            Image to show; its bounding box sets the plot extent.
        **kwargs
            Forwarded to `matplotlib.axes.Axes.imshow`.

        Returns
        -------
        `matplotlib.image.AxesImage`
            The image artist.
        """
        fp_bbox = Box2D(image.getBBox())
        return axes.imshow(
            image.array,
            interpolation="nearest",
            origin="lower",
            aspect="equal",
            extent=(fp_bbox.x.min, fp_bbox.x.max, fp_bbox.y.min, fp_bbox.y.max),
            **kwargs,
        )