Coverage for python / lsst / meas / algorithms / computeRoughPsfShapelets.py: 0%
260 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:07 +0000
1# This file is part of meas_algorithms.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["ComputeRoughPsfShapeletsTask", "ComputeRoughPsfShapeletsConfig"]
26import itertools
27from typing import TYPE_CHECKING, Any, Literal
29import numpy as np
30import scipy.signal
31import scipy.stats
32from sklearn.covariance import MinCovDet
33from sklearn.neighbors import KernelDensity
35from lsst.afw.geom import ellipses
36from lsst.afw.image import ImageD, ImageF, MaskedImageF
37from lsst.afw.table import Point2DKey, QuadrupoleKey, Schema, SourceCatalog
38from lsst.geom import Box2D
39from lsst.pex.config import Config, Field, ListField
40from lsst.pipe.base import AlgorithmError, Struct, Task
41from lsst.shapelet import LAGUERRE, ShapeletFunction, computeOffset
43from ._algorithmsLib import SpanSetMoments
45if TYPE_CHECKING:
46 import matplotlib.axes
47 import matplotlib.figure
48 import matplotlib.image
49 import matplotlib.patches
class NoStarsForShapeletsError(AlgorithmError):
    """Exception raised when `ComputeRoughPsfShapeletsTask` fails to find
    enough usable stars (e.g. no raw moments could be measured, a selection
    cut cannot yield the minimum number of stars, or the radius distribution
    has no mode).
    """

    @property
    def metadata(self) -> dict[str, Any]:
        """Additional diagnostic metadata; this error attaches none."""
        return {}
class ComputeRoughPsfShapeletsConfig(Config):
    """Configuration for `ComputeRoughPsfShapeletsTask`."""

    bad_mask_planes = ListField(
        "Mask planes to identify pixels to drop from the calculation.",
        dtype=str,
        default=["SAT", "SUSPECT", "INTRP"],
    )
    bad_pixel_max_fraction = Field(
        "Maximum fraction of a footprint's pixels that can be bad (according "
        "to bad_mask_planes) before that footprint is fully dropped.",
        dtype=float,
        default=0.25,
    )
    bad_pixel_exclusion_radius = Field(
        "If a bad pixel (according to bad_mask_planes) falls within this "
        "radius of the unweighted centroid of a footprint, drop that footprint "
        "entirely.",
        dtype=float,
        default=2.0,
    )
    max_footprint_area = Field(
        "Footprints with a pixel count larger than this threshold are dropped before computing moments.",
        dtype=int,
        default=10000,
    )
    min_snr = Field(
        "Minimum flux S/N for inclusion in the star sample.",
        dtype=float,
        default=50.0,
    )
    max_radius_factor = Field(
        "Maximum multiple of the mode of the radius distribution for inclusion in the star sample.",
        dtype=float,
        default=2.0,
    )
    max_shape_distance = Field(
        "Maximum Mahalanobis distance (distance from the center of the shape "
        "distribution in elliptical sigma units) to select a star from the candidate "
        "sample, comparing the shape of that source vs. a robust estimate of the "
        "distribution of the shapes of all sources.",
        dtype=float,
        default=3.0,
    )
    min_n_stars = Field(
        "Minimum number of stars to select. The S/N, radius, and shape distance thresholds are "
        "relaxed as needed to meet this target.",
        dtype=int,
        default=10,
    )
    max_n_stars = Field(
        "Maximum number of stars to select. High shape-distance sources "
        "are dropped as needed to meet this target.",
        dtype=int,
        default=20,
    )
    logarithmic_shapes = Field(
        "If True, transform the (xx, yy, xy) moments to conformal shear "
        "and log trace radius when selecting stars in order to map the ellipse "
        "parameters to a space with (-inf, inf) bounds on all quantities (but "
        "a less-linear relationship to the pixel data).",
        dtype=bool,
        default=False,
    )
    radius_mode_min = Field(
        "Minimum radius in pixels at which to start searching for the first mode. "
        "This just needs to be large enough to avoid unphysically small garbage (e.g. CRs).",
        dtype=float,
        default=1.0,
    )
    radius_kde_bandwidth = Field(
        "Bandwidth of the Gaussian kernel density estimator used to find the "
        "first mode of the radius distribution.",
        dtype=float,
        default=1.0,
    )
    shapelet_order = Field(
        "Order of the shapelet expansion fit to the stars.",
        dtype=int,
        default=4,
    )
    shapelet_scale_factor = Field(
        "Scale factor to apply to the moments ellipse when computing the ellipse for the shapelet basis.",
        dtype=float,
        default=1.0,
    )

    def validate(self) -> None:
        """Validate the configuration.

        Runs the base-class field validation first, then checks the
        cross-field and range constraints specific to this task.

        Raises
        ------
        ValueError
            Raised if ``min_n_stars > max_n_stars``, if ``shapelet_order`` is
            negative, or if ``shapelet_scale_factor`` is not positive.
        """
        # The base implementation validates every individual field; an
        # override that does not call it silently skips those checks.
        super().validate()
        if self.min_n_stars > self.max_n_stars:
            raise ValueError(
                f"min_n_stars={self.min_n_stars} is greater than max_n_stars={self.max_n_stars}."
            )
        if self.shapelet_order < 0:
            raise ValueError(f"shapelet order {self.shapelet_order} must be nonnegative.")
        if self.shapelet_scale_factor <= 0.0:
            raise ValueError(f"shapelet scale factor {self.shapelet_scale_factor} must be positive.")
class ComputeRoughPsfShapeletsTask(Task):
    """A task that computes a rough shapelet expansion of the PSF from a set
    of high S/N detections.

    Notes
    -----
    This task is expected to be run early in single-epoch processing - just
    after background subtraction and an initial high S/N detection phase, and
    before any deblending or measurement - in order to identify out-of-focus
    or otherwise bad PSFs.

    Given a background-subtracted `lsst.afw.image.MaskedImage`, an
    `lsst.afw.table.SourceCatalog` with footprints attached, and a random
    number generator seed, the `run` method will:

    - Compute the *unweighted* 0th-2nd moments of every non-child source over
      the footprint (except certain configurable masked pixels).  This is
      delegated to the `compute_raw_moments` method (which uses the C++
      `SpanSetMoments` class for the pixel-level processing).  Unweighted
      moments are used to avoid "latching onto" a small piece of PSF
      substructure, but can be much noisier than the Gaussian-weighted
      moments we usually use.

    - Select a "candidate" sample of sources with successfully measured
      moments that satisfy a S/N cut and a radius cut (determined from the
      first mode of the radius distribution, via kernel density estimation),
      and then use a robust covariance estimator
      (`sklearn.covariance.MinCovDet`) to select presumed isolated stars that
      are close to the center of that distribution, in 3-parameter shape
      space.  This is delegated to the `select_stars` method.

    - Fit a single shapelet expansion to the selected stars.  This is
      mostly delegated to the `SpanSetMoments.fit_shapelets` method.

    The radial shapelet terms at 0th, 2nd, and 4th order are expected to form
    a space in which donut-shaped PSFs are well-separated from those with
    monotonic profiles.  Other terms *may* be useful in identifying other
    kinds of undesirable PSF structure.
    """

    ConfigClass = ComputeRoughPsfShapeletsConfig
    _DefaultName = "computeRoughPsfShapelets"
    config: ComputeRoughPsfShapeletsConfig

    def __init__(
        self,
        config: ComputeRoughPsfShapeletsConfig | None = None,
        *,
        schema: Schema,
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.schema = schema
        # Every output column shares the "RoughPsfShapelets" prefix so that
        # the flux and its uncertainty form the usual ``_flux``/``_fluxErr``
        # pair.
        self._flux_key = schema.addField(
            "RoughPsfShapelets_flux",
            type=float,
            doc="Unweighted zeroth moment.",
        )
        # Keep the previous (inconsistently prefixed) column name resolvable
        # for existing readers.
        schema.getAliasMap().set("RawPsfMoments_flux", "RoughPsfShapelets_flux")
        self._flux_err_key = schema.addField(
            "RoughPsfShapelets_fluxErr",
            type=float,
            doc="Uncertainty on the unweighted zeroth moment.",
        )
        self._center_key = Point2DKey.addFields(
            schema, "RoughPsfShapelets", "Center from unweighted first moments.", "pixels"
        )
        self._shape_key = QuadrupoleKey.addFields(
            schema, "RoughPsfShapelets", "Shape from unweighted second moments."
        )
        self._flag_key = schema.addField(
            "RoughPsfShapelets_flag",
            type="Flag",
            doc="Flag set if the raw PSF moments were not computed.",
        )
        self._candidate_key = schema.addField(
            "RoughPsfShapelets_candidate",
            type="Flag",
            doc="Flag set if this source passed the S/N and radius cuts (see min_snr, max_radius_factor).",
        )
        self._used_key = schema.addField(
            "RoughPsfShapelets_used",
            type="Flag",
            doc=(
                "Flag set if this source passed the S/N, radius, and shape-distance cuts "
                "(see min_snr, max_radius_factor, max_shape_distance) and was used to fit "
                "the shapelet expansion."
            ),
        )

    def run(self, *, masked_image: MaskedImageF, catalog: SourceCatalog, seed: int) -> Struct:
        """Compute raw moments, select stars, and fit a shapelet expansion to
        them.

        Parameters
        ----------
        masked_image
            Masked image to measure on.  Must be background-subtracted.
        catalog
            Catalog of detections to extract footprints from and fill output
            columns of.  Its schema must be a superset of ``self.schema``.
        seed
            A random-number generator seed, used for the robust covariance
            estimator.

        Returns
        -------
        `lsst.pipe.base.Struct`
            A struct of results containing:

            - ``shapelet`` (`lsst.shapelet.ShapeletFunction`): A
              Gauss-Laguerre (polar shapelet) expansion of the PSF.
            - ``radial`` (`numpy.ndarray`): the purely radial coefficients
              of the shapelet expansion (orders 0, 2, 4, ...).
            - all attributes returned by the `select_stars` method.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if no moments could be measured or too few stars survive
            the selection.
        """
        moments = self.compute_raw_moments(masked_image=masked_image, catalog=catalog)
        result = self.select_stars(catalog, seed=seed)
        star_moments = [moments[star_id] for star_id in result.star_ids]
        result.shapelet = SpanSetMoments.fit_shapelets(
            masked_image,
            star_moments,
            self.config.shapelet_order,
            self.config.shapelet_scale_factor,
        )
        # NOTE(review): this replaces the ellipse core produced by
        # fit_shapelets (which is given shapelet_scale_factor) with the
        # unscaled mean shape; confirm that is intended.
        result.shapelet.getEllipse().setCore(result.mean_shape)
        result.shapelet.changeBasisType(LAGUERRE)
        result.radial = result.shapelet.getCoefficients()[
            [computeOffset(i) for i in range(0, self.config.shapelet_order + 1, 2)]
        ]
        self.log.info("Rough PSF shapelet radial terms: %s.", result.radial)
        return result

    def compute_raw_moments(
        self, *, masked_image: MaskedImageF, catalog: SourceCatalog
    ) -> dict[int, SpanSetMoments]:
        """Compute the unweighted moments of the footprints in a catalog.

        Child sources, sources without a footprint, and sources whose
        footprint area exceeds ``max_footprint_area`` are flagged and skipped.
        All other sources have their flux, flux error, center, shape, and flag
        columns filled.

        Parameters
        ----------
        masked_image
            Masked image to measure on.  Must be background-subtracted.
        catalog
            Catalog of detections to extract footprints from and fill output
            columns of.  Its schema must be a superset of ``self.schema``.

        Returns
        -------
        `dict` [`int`, `SpanSetMoments`]
            Objects used to construct and hold the unweighted moments and the
            pixel region used to compute them, keyed by source ID.  Only
            sources with no flags set are included.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if no source's moments could be measured.
        """
        bitmask = masked_image.mask.getPlaneBitMask(self.config.bad_mask_planes)
        all_moments: dict[int, SpanSetMoments] = {}
        for record in catalog:
            if record.getParent() != 0:
                record.set(self._flag_key, True)
                self.log.debug("Skipping child source %s", record.getId())
                continue
            footprint = record.getFootprint()
            if footprint is None:
                record.set(self._flag_key, True)
                self.log.debug("Skipping source %s with no footprint.", record.getId())
                continue
            footprint_spans = footprint.getSpans()
            area = footprint_spans.getArea()
            if area > self.config.max_footprint_area:
                record.set(self._flag_key, True)
                self.log.debug(
                    "Skipping source %s with footprint area %d > %d.",
                    record.getId(),
                    area,
                    self.config.max_footprint_area,
                )
                continue
            moments = SpanSetMoments.compute(
                footprint_spans,
                masked_image=masked_image,
                bad_bitmask=bitmask,
                bad_pixel_max_fraction=self.config.bad_pixel_max_fraction,
                bad_pixel_exclusion_radius=self.config.bad_pixel_exclusion_radius,
            )
            record.set(self._flux_key, moments.flux)
            record.set(self._flux_err_key, moments.variance**0.5)
            record.set(self._center_key, moments.center)
            record.set(self._shape_key, moments.shape)
            record.set(self._flag_key, moments.any_flags_set)
            if not moments.any_flags_set:
                all_moments[record.getId()] = moments
        if not all_moments:
            raise NoStarsForShapeletsError("No raw moments could be measured.")
        self.log.verbose(
            "Successfully measured raw moments for %d of %d sources.",
            len(all_moments),
            len(catalog),
        )
        return all_moments

    def select_stars(self, catalog: SourceCatalog, seed: int) -> Struct:
        """Select probable stars from the distribution of second moments.

        Parameters
        ----------
        catalog
            Catalog of detections whose moment columns were filled by
            `compute_raw_moments`.  The candidate and used flag columns are
            set in place.
        seed
            A random-number generator seed, used for the robust covariance
            estimator.

        Returns
        -------
        `lsst.pipe.base.Struct`
            A struct of results containing:

            - ``star_ids`` (`numpy.ndarray`): the source IDs that are expected
              to be stars.
            - ``mean_shape`` (`lsst.afw.geom.ellipses.BaseCore`): the mean of
              the shape distribution.
            - ``shape_covariance`` (`numpy.ndarray`): the covariance of the
              distribution of shapes; a 3x3 matrix.  This uses the same
              parameterization of the shapes as ``mean_shape``.
            - ``radius_cut`` (`float`): the intended radius cut (i.e. the
              mode of the radius distribution multiplied by the
              ``max_radius_factor`` configuration option).
            - ``radius_kde`` (`sklearn.neighbors.KernelDensity`): kernel
              density estimator on the radius distribution, used to determine
              the radius cut.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if a cut cannot yield ``min_n_stars`` sources or the
            radius distribution has no mode.
        """
        # Cut on flags and S/N first.
        indices = np.arange(len(catalog), dtype=int)[np.logical_not(catalog[self._flag_key])]
        indices = indices[
            self._threshold_with_bounds(
                catalog[self._flux_key][indices] / catalog[self._flux_err_key][indices],
                threshold=self.config.min_snr,
                min_count=self.config.min_n_stars,
                max_count=len(catalog),
                name="S/N",
                kind=">",
            )
        ]
        # Cut on radius next, relative to the first mode of its distribution.
        radii = np.zeros(indices.size, dtype=np.float64)
        for n, index in enumerate(indices):
            radii[n] = catalog[index].get(self._shape_key).getTraceRadius()
        radius_mode, radius_kde = self._find_first_radius_mode(radii)
        radius_cut = self.config.max_radius_factor * radius_mode
        indices = indices[
            self._threshold_with_bounds(
                radii,
                threshold=radius_cut,
                min_count=self.config.min_n_stars,
                max_count=len(catalog),
                name="radius",
                kind="<",
            )
        ]
        # Finally, cut on distance from the robust center of the shape
        # distribution.
        shape_data = np.zeros((len(indices), 3), dtype=np.float64)
        for n, index in enumerate(indices):
            record = catalog[index]
            record.set(self._candidate_key, True)
            shape = record.get(self._shape_key)
            if self.config.logarithmic_shapes:
                shape = ellipses.SeparableConformalShearLogTraceRadius(shape)
            shape_data[n, :] = shape.getParameterVector()
        shape_dist = MinCovDet(random_state=seed).fit(shape_data)
        # MinCovDet.mahalanobis returns *squared* distances; take the square
        # root so max_shape_distance is in elliptical sigma units, as
        # documented (and as drawn by plot_selection).
        m_distances = np.sqrt(shape_dist.mahalanobis(shape_data))
        indices = indices[
            self._threshold_with_bounds(
                m_distances,
                threshold=self.config.max_shape_distance,
                min_count=self.config.min_n_stars,
                max_count=self.config.max_n_stars,
                name="Mahalanobis distance",
                kind="<",
            )
        ]
        for index in indices:
            catalog[index].set(self._used_key, True)
        star_ids = catalog["id"][indices]
        if self.config.logarithmic_shapes:
            mean_shape = ellipses.SeparableConformalShearLogTraceRadius(shape_dist.location_)
        else:
            mean_shape = ellipses.Quadrupole(shape_dist.location_)
        return Struct(
            star_ids=star_ids,
            mean_shape=mean_shape,
            shape_covariance=shape_dist.covariance_,
            radius_cut=radius_cut,
            radius_kde=radius_kde,
        )

    def plot_selection(
        self, figure: matplotlib.figure.Figure, *, catalog: SourceCatalog, results: Struct
    ) -> matplotlib.figure.Figure:
        """Create plots of the shape distribution space used to select stars.

        Parameters
        ----------
        figure
            Matplotlib figure to plot to.
        catalog
            Catalog of sources with columns populated by the `run` method (at
            least through the `select_stars` step).
        results
            Result struct returned by `run` or `select_stars`.

        Returns
        -------
        `matplotlib.figure.Figure`
            The same figure that was passed in.
        """
        from matplotlib.lines import Line2D

        shape_data = np.zeros((len(catalog), 3), dtype=np.float64)
        radii = np.zeros(len(catalog), dtype=np.float64)
        for n, record in enumerate(catalog):
            if record[self._flag_key]:
                continue
            shape = record[self._shape_key]
            if self.config.logarithmic_shapes:
                shape = ellipses.SeparableConformalShearLogTraceRadius(shape)
            shape_data[n, :] = shape.getParameterVector()
            radii[n] = shape.getTraceRadius()
        # Rows for flagged sources are left at zero above; exclude them when
        # computing plot ranges so they do not distort the axes.
        good_mask = np.logical_not(catalog[self._flag_key])
        good_shapes = shape_data[good_mask]
        used_mask = catalog[self._used_key]
        candidate_mask = np.logical_and(catalog[self._candidate_key], np.logical_not(used_mask))
        measured_mask = np.logical_and(good_mask, np.logical_not(catalog[self._candidate_key]))
        # Set up the axes.
        axes = figure.subplot_mosaic(
            [
                ["radius", "radius", "radius"],
                ["hist0", ".", "."],
                ["scatter01", "hist1", "."],
                ["scatter02", "scatter12", "hist2"],
            ],
            gridspec_kw=dict(bottom=0.2, top=1.0),
        )
        axes["scatter01"].sharex(axes["hist0"])
        axes["scatter02"].sharex(axes["hist0"])
        axes["scatter12"].sharex(axes["hist1"])
        axes["scatter12"].sharey(axes["scatter02"])
        for tk in itertools.chain(
            axes["hist0"].get_xticklabels(),
            axes["scatter01"].get_xticklabels(),
            axes["hist1"].get_xticklabels(),
            axes["scatter12"].get_yticklabels(),
        ):
            tk.set_visible(False)
        # Move hist y axes to the outside.
        for ax in [axes["hist0"], axes["hist1"], axes["hist2"]]:
            ax.yaxis.set_label_position("right")
            ax.yaxis.tick_right()
        # Add labels to outer axes.
        names = ["Ixx", "Iyy", "Ixy"] if not self.config.logarithmic_shapes else ["𝜂1", "𝜂2", "ln(r)"]
        axes["scatter02"].set_xlabel(names[0])
        axes["scatter12"].set_xlabel(names[1])
        axes["hist2"].set_xlabel(names[2])
        axes["scatter01"].set_ylabel(names[1])
        axes["scatter02"].set_ylabel(names[2])
        # Make the plots; ranges are +/- 3 sigma, clipped to the data.
        mu = results.mean_shape.getParameterVector()
        sigma = np.diagonal(results.shape_covariance) ** 0.5
        lower_bounds = [max(mu[i] - 3 * sigma[i], good_shapes[:, i].min()) for i in range(3)]
        upper_bounds = [min(mu[i] + 3 * sigma[i], good_shapes[:, i].max()) for i in range(3)]
        grids = [np.linspace(lower_bounds[i], upper_bounds[i], 50) for i in range(3)]
        radius_min = radii[good_mask].min()
        axes["radius"].axvline(results.radius_cut, color="k")
        for color, mask, alpha in [
            ("grey", measured_mask, 0.5),
            ("blue", candidate_mask, 0.75),
            ("green", used_mask, 1.0),
        ]:
            axes["radius"].hist(
                radii[mask],
                color=color,
                alpha=alpha,
                bins=16,
                histtype="step",
                range=(radius_min, 15.0),
                density=True,
            )
            for i in range(3):
                axes[f"hist{i}"].hist(
                    shape_data[mask, i],
                    bins=16,
                    range=(lower_bounds[i], upper_bounds[i]),
                    density=True,
                    color=color,
                    histtype="step",
                    alpha=alpha,
                )
                for j in range(3):
                    if (ax := axes.get(f"scatter{i}{j}")) is not None:
                        ax.scatter(
                            shape_data[mask, i],
                            shape_data[mask, j],
                            c=color,
                            s=4,
                            alpha=alpha,
                            edgecolors=None,
                        )
        for i in range(3):
            axes[f"hist{i}"].plot(
                grids[i], scipy.stats.norm.pdf(grids[i], loc=mu[i], scale=sigma[i]), "k", alpha=0.5
            )
            axes[f"hist{i}"].set_xlim(lower_bounds[i], upper_bounds[i])
            for j in range(3):
                if (ax := axes.get(f"scatter{i}{j}")) is not None:
                    # The 2x2 sub-covariance drawn as 1, 2, and 3 sigma
                    # ellipses around the robust mean.
                    sigma_ellipse = ellipses.Quadrupole(
                        results.shape_covariance[i, i],
                        results.shape_covariance[j, j],
                        results.shape_covariance[i, j],
                    )
                    for factor in [1, 2, 3]:
                        self._draw_ellipse(
                            ax,
                            sigma_ellipse,
                            x=mu[i],
                            y=mu[j],
                            scale=factor,
                            fill=False,
                            edgecolor="k",
                            alpha=0.5,
                        )
                    ax.set_xlim(lower_bounds[i], upper_bounds[i])
                    ax.set_ylim(lower_bounds[j], upper_bounds[j])
        figure.legend(
            [
                Line2D([], [], color="green", alpha=1.0),
                Line2D([], [], color="blue", alpha=0.75),
                Line2D([], [], color="gray", alpha=0.5),
            ],
            [
                f"RoughPsfShapelets_used ({used_mask.sum()})",
                f"RoughPsfShapelets_candidate & ~RoughPsfShapelets_used ({candidate_mask.sum()})",
                f"~RoughPsfShapelets_flag & ~RoughPsfShapelets_candidate ({measured_mask.sum()})",
            ],
            loc="lower center",
        )
        return figure

    def plot_shapelets(
        self,
        figure: matplotlib.figure.Figure,
        *,
        image: ImageF,
        catalog: SourceCatalog,
        results: Struct,
        n_stars: int = 3,
        stamp_size: float = 2.0,
    ) -> matplotlib.figure.Figure:
        """Create data/model/residual plots of stars and the shapelet model.

        Parameters
        ----------
        figure
            Matplotlib figure to plot to.
        image
            The image the stars were measured on.
        catalog
            Catalog of sources with columns populated by the `run` method.
        results
            Result struct returned by `run`.
        n_stars
            Maximum number of stars to include.
        stamp_size
            Stamp size in inches.

        Returns
        -------
        `matplotlib.figure.Figure`
            The same figure that was passed in (left empty if there are no
            stars to plot).
        """
        from matplotlib.colors import Normalize

        star_ids = results.star_ids[:n_stars]
        if len(star_ids) == 0:
            self.log.warning("No stars available for shapelet plots.")
            return figure
        width = stamp_size * 3 + 1.5
        # Size the figure for the rows actually drawn, which may be fewer
        # than n_stars.
        figure.set_size_inches(w=width, h=stamp_size * len(star_ids))
        axes = figure.subplot_mosaic(
            [["image_cbar", f"d{star_id}", f"m{star_id}", f"r{star_id}", "res_cbar"] for star_id in star_ids],
            gridspec_kw=dict(
                wspace=0.01, hspace=0.01, left=0.5 / width, right=1.0 - 0.5 / width, bottom=0.01, top=0.99
            ),
            width_ratios=[0.25, stamp_size, stamp_size, stamp_size, 0.25],
        )
        for name, ax in axes.items():
            if not name.endswith("cbar"):
                ax.axis("off")
        # Color scales are fixed by the first star and shared by the rest.
        norm: Normalize | None = None
        res_norm: Normalize | None = None
        for star_id in star_ids:
            record = catalog.find(star_id)
            star_bbox = record.getFootprint().getBBox()
            star_model = ImageD(star_bbox)
            star_center = record[self._center_key]
            star_ellipse = ellipses.Ellipse(ellipses.Axes(record[self._shape_key]), star_center)
            # Evaluate the shared shapelet model on this star's own ellipse.
            star_shapelet = ShapeletFunction(results.shapelet)
            star_shapelet.setEllipse(star_ellipse)
            star_shapelet.evaluate().addToImage(star_model)
            if norm is None:
                norm = Normalize(vmin=star_model.array.min(), vmax=star_model.array.max())
            # Data normalized by the unweighted flux to match the model.
            star_image = image[star_bbox].clone()
            star_image /= record[self._flux_key]
            self._draw_image(axes[f"d{star_id}"], star_image, norm=norm, cmap="YlGnBu")
            self._draw_ellipse(axes[f"d{star_id}"], star_ellipse, fill=False, edgecolor="blue", alpha=0.5)
            image_plot = self._draw_image(axes[f"m{star_id}"], star_model, norm=norm, cmap="YlGnBu")
            self._draw_ellipse(axes[f"m{star_id}"], star_ellipse, fill=False, edgecolor="blue", alpha=0.5)
            star_image -= star_model.convertF()
            amax = np.abs(star_image.array).max()
            if res_norm is None:
                res_norm = Normalize(vmin=-amax, vmax=amax)
            res_plot = self._draw_image(axes[f"r{star_id}"], star_image, norm=res_norm, cmap="RdBu")
            self._draw_ellipse(axes[f"r{star_id}"], star_ellipse, fill=False, edgecolor="blue", alpha=0.5)
        figure.colorbar(image_plot, cax=axes["image_cbar"], location="left")
        figure.colorbar(res_plot, cax=axes["res_cbar"], location="right")
        return figure

    def _threshold_with_bounds(
        self,
        values: np.ndarray,
        threshold: float,
        min_count: int,
        max_count: int,
        name: str,
        kind: Literal["<", ">"],
    ) -> np.ndarray:
        """Return the indices of an array that satisfy a strict inequality,
        subject to lower and upper bounds on the number of indices returned.

        Parameters
        ----------
        values
            Array of values to threshold on.
        threshold
            Threshold value that selected elements must be strictly above or
            below.
        min_count
            The minimum number of indices returned.  When thresholding would
            yield fewer than this number, the threshold is ignored and the
            ``min_count`` best-ranked values are returned.
        max_count
            The maximum number of indices returned.
        name
            Name of the quantity being thresholded, for log messages.
        kind
            Whether the threshold is an upper bound (``<``) or lower bound
            (``>``).  This also sets how values are ranked when they are added
            or dropped to satisfy the count constraints.

        Returns
        -------
        indices
            Indices into ``values``, ordered from best- to worst-ranked.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if ``values`` has fewer than ``min_count`` elements.
        ValueError
            Raised if ``kind`` is not ``"<"`` or ``">"``.
        """
        if kind not in ("<", ">"):
            raise ValueError(f"Invalid threshold kind {kind!r}; expected '<' or '>'.")
        if min_count > len(values):
            raise NoStarsForShapeletsError(
                f"Not enough sources ({len(values)}) for {name} cut that must yield at least {min_count}."
            )
        sorter = values.argsort()
        if kind == ">":
            # Count values <= threshold (side="right"), so the remainder are
            # strictly greater; then rank from largest to smallest.
            n = len(sorter) - int(np.searchsorted(values[sorter], threshold, side="right"))
            sorter = sorter[::-1]
        else:
            # side="left" counts values strictly less than the threshold.
            n = int(np.searchsorted(values[sorter], threshold, side="left"))
        if n < min_count:
            self.log.verbose(
                "Applying a %s %s %f cut yields only %d sources; keeping the top %d (%s %s %f) instead.",
                name,
                kind,
                threshold,
                n,
                min_count,
                name,
                kind,
                values[sorter[min_count - 1]],
            )
            n = min_count
        elif n > max_count:
            self.log.verbose(
                "%d sources have %s %s %f; keeping only the top %d (%s %s %f) instead.",
                n,
                name,
                kind,
                threshold,
                max_count,
                name,
                kind,
                values[sorter[max_count - 1]],
            )
            n = max_count
        else:
            self.log.verbose("Keeping %d sources with %s %s %f.", n, name, kind, threshold)
        return sorter[:n]

    def _find_first_radius_mode(self, radii: np.ndarray) -> tuple[float, KernelDensity]:
        """Find the first peak in a 1-d distribution of radii.

        A Gaussian KDE is fit to all radii, then evaluated at the sorted
        sample radii at or above ``radius_mode_min``; the smallest such
        radius at a local maximum of the density is returned.

        Parameters
        ----------
        radii
            1-d array of trace radii.

        Returns
        -------
        mode : `float`
            Radius of the first mode.
        kde : `sklearn.neighbors.KernelDensity`
            The fitted kernel density estimator.

        Raises
        ------
        NoStarsForShapeletsError
            Raised if no radii are at or above ``radius_mode_min`` or if the
            evaluated density has no local maximum.
        """
        sorted_radii = np.sort(radii)
        sorted_radii = sorted_radii[sorted_radii.searchsorted(self.config.radius_mode_min):]
        if not sorted_radii.size:
            raise NoStarsForShapeletsError(
                f"No source radii are at or above radius_mode_min={self.config.radius_mode_min}."
            )
        kde = KernelDensity(bandwidth=self.config.radius_kde_bandwidth).fit(radii.reshape(-1, 1))
        scores = kde.score_samples(sorted_radii.reshape(-1, 1))
        peaks, _ = scipy.signal.find_peaks(scores)
        if not peaks.size:
            raise NoStarsForShapeletsError("Radius distribution has no mode.")
        return sorted_radii[peaks.min()], kde

    @staticmethod
    def _draw_ellipse(
        axes: matplotlib.axes.Axes,
        ellipse: ellipses.BaseCore | ellipses.Ellipse,
        *,
        x: float | None = None,
        y: float | None = None,
        scale: float = 1.0,
        **kwargs: Any,
    ) -> matplotlib.patches.Ellipse:
        """Add an ellipse patch to matplotlib axes.

        If ``ellipse`` is a full `~lsst.afw.geom.ellipses.Ellipse`, its center
        is used unless ``x``/``y`` are given; a bare core is centered at
        ``(x, y)``, defaulting to the origin.  The semi-axes are multiplied by
        ``scale``, and extra keyword arguments go to the matplotlib patch.
        """
        from matplotlib.patches import Ellipse as EllipsePatch

        if isinstance(ellipse, ellipses.Ellipse):
            if x is None:
                x = ellipse.getCenter().getX()
            if y is None:
                y = ellipse.getCenter().getY()
            ellipse = ellipse.getCore()
        else:
            if x is None:
                x = 0.0
            if y is None:
                y = 0.0
            ellipse = ellipses.Axes(ellipse)
        patch = EllipsePatch(
            (x, y),
            ellipse.getA() * 2 * scale,  # factor of 2 for radius->diameter
            ellipse.getB() * 2 * scale,
            angle=ellipse.getTheta() * 180.0 / np.pi,
            **kwargs,
        )
        axes.add_patch(patch)
        return patch

    @staticmethod
    def _draw_image(
        axes: matplotlib.axes.Axes, image: ImageF | ImageD, **kwargs: Any
    ) -> matplotlib.image.AxesImage:
        """Show an afw image on matplotlib axes in its own pixel coordinates.

        Extra keyword arguments are passed to ``imshow``.
        """
        fp_bbox = Box2D(image.getBBox())
        return axes.imshow(
            image.array,
            interpolation="nearest",
            origin="lower",
            aspect="equal",
            extent=(fp_bbox.x.min, fp_bbox.x.max, fp_bbox.y.min, fp_bbox.y.max),
            **kwargs,
        )