Coverage for python / lsst / scarlet / lite / operators.py: 13%
162 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
1from __future__ import annotations
3from typing import Any, Callable, Sequence, cast
5import numpy as np
6import numpy.typing as npt
7from lsst.scarlet.lite.detect_pybind11 import get_connected_pixels # type: ignore
8from lsst.scarlet.lite.operators_pybind11 import new_monotonicity # type: ignore
10from .bbox import Box
def prox_connected(morph: np.ndarray, centers: Sequence[Sequence[int]]) -> np.ndarray:
    """Zero every pixel of `morph` that is not connected to a source center.

    Parameters
    ----------
    morph:
        The morphology that is being constrained.
    centers:
        The `(cy, cx)` center of every source that the kept pixels must be
        connected to. Coordinates are truncated to integers.

    Returns
    -------
    result:
        The morphology with all pixels that are not connected to a center
        position set to zero.
    """
    connected = np.zeros(morph.shape, dtype=bool)

    for y_center, x_center in centers:
        y_pix = int(y_center)
        x_pix = int(x_center)
        # Fresh visitation mask for each center.
        not_visited = np.ones(morph.shape, dtype=bool)
        # Bounds seeded with the center pixel, ordered (y, y, x, x).
        bounds = np.array([y_pix, y_pix, x_pix, x_pix], dtype=np.int32)
        # The extension accumulates connected pixels into `connected`.
        get_connected_pixels(y_pix, x_pix, morph, not_visited, connected, bounds, 0)

    return connected * morph
class Monotonicity:
    """Class to implement Monotonicity

    Callable class that applies monotonicity as a pseudo proximal
    operator (actually a projection operator) to *a* radially
    monotonic solution.

    Notes
    -----
    This differs from monotonicity in the main scarlet branch because
    this stores a single monotonicity operator to set the weights for all
    of the pixels up to the size of the largest shape expected,
    and only needs to be created once _per blend_, as opposed to
    once _per source_.
    This class is then called with the source morphology
    to make monotonic and the location of the "center" of the image,
    and the full weight matrix is sliced accordingly.

    Parameters
    ----------
    shape:
        The shape of the full operator.
        This must be larger than the largest possible object size
        in the blend. Both dimensions must be positive and odd.
    dtype:
        The numpy ``dtype`` used for the internal distance and weight
        arrays of the operator.
    auto_update:
        If ``True`` the operator will update its shape if an image is
        too big to fit in the current operator.
    fit_radius:
        Pixels within `fit_radius` of the center of the array to make
        monotonic are checked to see if they have more flux than the center
        pixel. If they do, the pixel with larger flux is used as the center.
    """

    def __init__(
        self,
        shape: tuple[int, int],
        dtype: npt.DTypeLike = float,
        auto_update: bool = True,
        fit_radius: int = 1,
    ):
        # Initialize defined variables
        self.weights: np.ndarray | None = None
        self.distance: np.ndarray | None = None
        self.sizes: tuple[int, int, int, int] | None = None
        self.dtype = dtype
        self.auto_update = auto_update
        self.fit_radius = fit_radius
        self.update(shape)

    @property
    def shape(self) -> tuple[int, int]:
        """The 2D shape of the largest component that can be made monotonic

        Returns
        -------
        result:
            The shape of the operator.
        """
        # weights has shape (9, ny, nx): one plane per neighbor slot.
        return cast(tuple[int, int], cast(np.ndarray, self.weights).shape[1:])

    @property
    def center(self) -> tuple[int, int]:
        """The center of the full operator

        Returns
        -------
        result:
            The ``(cy, cx)`` center of the full operator.
        """
        shape = self.shape
        cx = (shape[1] - 1) // 2
        cy = (shape[0] - 1) // 2
        return cy, cx

    def update(self, shape: tuple[int, int]):
        """Update the operator with a new shape

        Parameters
        ----------
        shape:
            The new shape. It must be 2D with positive, odd dimensions.

        Raises
        ------
        ValueError:
            Raised when `shape` is not 2D, or when either dimension is
            even or not positive.
        """
        if len(shape) != 2:
            msg = f"Monotonicity is a 2D operator but received shape with {len(shape)} dimensions"
            raise ValueError(msg)
        if shape[0] % 2 == 0 or shape[1] % 2 == 0:
            raise ValueError(f"The shape must be odd, got {shape}")
        # Negative odd values pass the parity test above (-3 % 2 == 1) but
        # cannot describe an array, so reject them explicitly.
        if shape[0] < 1 or shape[1] < 1:
            raise ValueError(f"The shape must be positive, got {shape}")
        # Use the center of the operator as the center
        # and calculate the distance to each pixel from the center
        cx = (shape[1] - 1) // 2
        cy = (shape[0] - 1) // 2
        _x = np.arange(shape[1], dtype=self.dtype) - cx
        _y = np.arange(shape[0], dtype=self.dtype) - cy
        x, y = np.meshgrid(_x, _y)
        distance = np.sqrt(x**2 + y**2)

        # Calculate the distance from each pixel to its 8 nearest neighbors.
        # Slot k holds distance[pixel] - distance[neighbor]; the offsets are
        # (-1,-1), (-1,0), (-1,1), (0,-1), (0,0), (0,1), (1,-1), (1,0), (1,1).
        neighbor_dist = np.zeros((9,) + distance.shape, dtype=self.dtype)
        neighbor_dist[0, 1:, 1:] = distance[1:, 1:] - distance[:-1, :-1]
        neighbor_dist[1, 1:, :] = distance[1:, :] - distance[:-1, :]
        neighbor_dist[2, 1:, :-1] = distance[1:, :-1] - distance[:-1, 1:]
        neighbor_dist[3, :, 1:] = distance[:, 1:] - distance[:, :-1]

        # For the center pixel, set the distance to 1 just so that it is
        # non-zero
        neighbor_dist[4, cy, cx] = 1
        neighbor_dist[5, :, :-1] = distance[:, :-1] - distance[:, 1:]
        neighbor_dist[6, :-1, 1:] = distance[:-1, 1:] - distance[1:, :-1]
        neighbor_dist[7, :-1, :] = distance[:-1, :] - distance[1:, :]
        neighbor_dist[8, :-1, :-1] = distance[:-1, :-1] - distance[1:, 1:]

        # Calculate the difference in angle to the center
        # from each pixel to its 8 nearest neighbors
        angles = np.arctan2(y, x)
        angle_diff = np.zeros((9,) + angles.shape, dtype=self.dtype)
        angle_diff[0, 1:, 1:] = angles[1:, 1:] - angles[:-1, :-1]
        angle_diff[1, 1:, :] = angles[1:, :] - angles[:-1, :]
        angle_diff[2, 1:, :-1] = angles[1:, :-1] - angles[:-1, 1:]
        angle_diff[3, :, 1:] = angles[:, 1:] - angles[:, :-1]
        # For the center pixel, only the center will have a non-zero cosine,
        # which is used as the weight.
        angle_diff[4] = 1
        angle_diff[4, cy, cx] = 0
        angle_diff[5, :, :-1] = angles[:, :-1] - angles[:, 1:]
        angle_diff[6, :-1, 1:] = angles[:-1, 1:] - angles[1:, :-1]
        angle_diff[7, :-1, :] = angles[:-1, :] - angles[1:, :]
        angle_diff[8, :-1, :-1] = angles[:-1, :-1] - angles[1:, 1:]

        # Use cos(theta) to set the weights, then normalize.
        # This gives more weight to neighboring pixels that are more closely
        # aligned with the vector pointing toward the center.
        weights = np.cos(angle_diff)
        # Only neighbors strictly closer to the center contribute.
        weights[neighbor_dist <= 0] = 0
        # Adjust for the discontinuity in the angle at theta = +/-pi by
        # keeping every weight non-negative.
        weights = np.abs(weights)
        weights = weights / np.sum(weights, axis=0)[None, :, :]

        # Store the parameters needed later.
        # sizes = pixels above/left of the center, and from the center to the
        # bottom/right edge (inclusive of the center).
        self.weights = weights
        self.distance = distance
        self.sizes = (cy, cx, shape[0] - cy, shape[1] - cx)

    def check_size(self, shape: tuple[int, int], center: tuple[int, int], update: bool = True):
        """Check to see if the operator can be applied

        Parameters
        ----------
        shape:
            The shape of the image to apply monotonicity.
        center:
            The location (in `shape`) of the point where the monotonicity will
            be taken from.
        update:
            When ``True`` the operator will update itself so that an image
            with shape `shape` can be made monotonic about the `center`.

        Raises
        ------
        ValueError:
            Raised when an array with shape `shape` does not fit in the
            current operator and `update` is `False`.
        """
        # Extent needed on each side of `center`, in the same layout as
        # `self.sizes`.
        sizes = np.array(tuple(center) + (shape[0] - center[0], shape[1] - center[1]))
        if np.any(sizes > self.sizes):
            if update:
                # Grow to a square, odd-sized operator large enough for the
                # largest required extent.
                size = 2 * np.max(sizes) + 1
                self.update((size, size))
            else:
                raise ValueError(f"Cannot apply monotonicity to image with shape {shape} at {center}")

    def __call__(self, image: np.ndarray, center: tuple[int, int]) -> np.ndarray:
        """Make an input image monotonic about a center pixel

        Parameters
        ----------
        image:
            The image to make monotonic.
        center:
            The ``(y, x)`` location _in image coordinates_ to make the
            center of the monotonic region.

        Returns
        -------
        result:
            The input image is updated in place, but also returned from this
            method.
        """
        # Check for a better center
        center = get_peak(image, center, self.fit_radius)

        # Check that the operator can fit the image
        self.check_size(cast(tuple[int, int], image.shape), center, self.auto_update)

        # Create the bounding box to slice the weights and distance as needed,
        # aligning the image's center pixel with the operator's center.
        cy, cx = self.center
        py, px = center
        bbox = Box((9,) + image.shape, origin=(0, cy - py, cx - px))
        weights = cast(np.ndarray, self.weights)[bbox.slices]
        # Visit pixels in order of increasing distance from the center.
        indices = np.argsort(cast(np.ndarray, self.distance)[bbox.slices[1:]].flatten())
        coords = np.unravel_index(indices, image.shape)

        # Pad the image by 1 so that we don't have to worry about
        # weights on the edges.
        result_shape = (image.shape[0] + 2, image.shape[1] + 2)
        result = np.zeros(result_shape, dtype=image.dtype)
        result[1:-1, 1:-1] = image
        # The extension's output is read back from `result` below, so it
        # presumably updates `result` in place.
        new_monotonicity(coords[0], coords[1], weights, result)
        image[:] = result[1:-1, 1:-1]
        return image

    def __copy__(self) -> Monotonicity:
        """Create a shallow copy of the operator

        Returns
        -------
        result:
            A new operator built with the same shape and settings.
        """
        new = Monotonicity(self.shape, self.dtype, self.auto_update, self.fit_radius)
        return new

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Monotonicity:
        """Create a deep copy of the operator

        Parameters
        ----------
        memo:
            The memoization dictionary for deep copies.

        Returns
        -------
        result:
            A copy of the operator.
        """
        return self.__copy__()
def get_peak(image: np.ndarray, center: tuple[int, int], radius: int = 1) -> tuple[int, int]:
    """Search around a location for the maximum flux

    For monotonicity it is important to start at the brightest pixel
    in the center of the source. This may be off by a pixel or two,
    so we search for the correct center before applying the
    monotonicity operator.

    Parameters
    ----------
    image:
        The image of the source.
    center:
        The suggested ``(y, x)`` center of the source. Non-integer
        coordinates are rounded to the nearest pixel.
    radius:
        The number of pixels around the `center` to search
        for a higher flux value.

    Returns
    -------
    new_center:
        The ``(y, x)`` location of the brightest pixel within `radius`
        of `center`, as plain Python integers.
    """
    cy, cx = int(round(center[0])), int(round(center[1]))
    # Clamp the lower edge of the search window to the image; slicing
    # already truncates the upper edge at the image boundary.
    y0 = max(cy - radius, 0)
    x0 = max(cx - radius, 0)
    subset = image[y0 : cy + radius + 1, x0 : cx + radius + 1]
    py, px = np.unravel_index(np.argmax(subset), subset.shape)
    # Convert from NumPy integers so the result matches the annotation.
    return int(py + y0), int(px + x0)
def prox_monotonic_mask(
    x: np.ndarray,
    center: tuple[int, int],
    center_radius: int = 1,
    variance: float = 0.0,
    max_iter: int = 3,
) -> tuple[np.ndarray, np.ndarray, tuple[int, int, int, int]]:
    """Apply monotonicity from any path from the center

    Parameters
    ----------
    x:
        The input image that the mask is created for.
    center:
        The ``(y, x)`` location of the center of the mask.
    center_radius:
        Radius from the center pixel to search for a better center
        (ie. a pixel in `x` with higher flux than the pixel given by
        `center`).
        If `center_radius == 0` then the `center` pixel is assumed
        to be correct.
    variance:
        The average variance in the image.
        This is used to allow pixels to be non-monotonic up to `variance`,
        so setting `variance=0` will force strict monotonicity in the mask.
    max_iter:
        Maximum number of iterations to interpolate non-monotonic pixels.

    Returns
    -------
    valid:
        Boolean array of pixels that are monotonic.
    model:
        The model with invalid pixels masked out.
    bounds:
        The bounds of the valid monotonic pixels.
    """
    from lsst.scarlet.lite.operators_pybind11 import (
        get_valid_monotonic_pixels,
        linear_interpolate_invalid_pixels,
    )

    # Choose the seed pixel: round the given center, or search nearby for
    # a brighter pixel when a search radius is requested.
    if center_radius <= 0:
        row, col = int(np.round(center[0])), int(np.round(center[1]))
    else:
        row, col = get_peak(x, center, center_radius)

    # Every pixel except the seed starts out unchecked.
    unchecked = np.ones(x.shape, dtype=bool)
    unchecked[row, col] = False
    orphans = np.zeros(x.shape, dtype=bool)
    # Bounding box of the result, seeded with the seed pixel as (y, y, x, x).
    bounds = np.array([row, row, col, col], dtype=np.int32)
    # Flag all of the monotonic pixels reachable from the seed.
    get_valid_monotonic_pixels(row, col, x, unchecked, orphans, variance, bounds, 0)

    # Start from an exact copy of the input; interpolation edits this copy.
    model = x.copy()

    n_iter = 0
    while np.any(orphans & unchecked) and n_iter < max_iter:
        n_iter += 1
        orphan_rows, orphan_cols = np.nonzero(orphans)
        linear_interpolate_invalid_pixels(
            orphan_rows, orphan_cols, unchecked, model, orphans, variance, True, bounds
        )

    valid = ~(unchecked | orphans)
    # Zero out every pixel that did not end up valid.
    model = model * valid
    return valid, model, tuple(bounds)  # type: ignore
def uncentered_operator(
    x: np.ndarray,
    func: Callable,
    center: tuple[int, int] | None = None,
    fill: float | None = None,
    **kwargs,
) -> np.ndarray:
    """Only apply the operator on a centered patch

    In some cases, for example symmetry, an operator might not make
    sense outside of a centered box. This operator only updates
    the portion of `x` inside the largest sub-region whose central
    pixel is `center`.

    Parameters
    ----------
    x:
        The parameter to update.
    func:
        The function (or operator) to apply to `x`.
    center:
        The ``(y, x)`` location of the center of the sub-region to
        apply `func` to `x`. When ``None`` the location of the maximum
        of `x` is used.
    fill:
        The value to fill the region outside of centered
        `sub-region`, for example `0`. If `fill` is `None`
        then only the subregion is updated and the rest of
        `x` remains unchanged.
    kwargs:
        Additional keyword arguments passed to `func`.

    Returns
    -------
    result:
        `x`, with an operator applied based on the shifted center.
        When `center` is exactly the central pixel of an odd-sized `x`,
        this is the value returned by ``func(x, **kwargs)``.
    """
    if center is None:
        py, px = cast(tuple[int, int], np.unravel_index(np.argmax(x), x.shape))
    else:
        py, px = center
    cy, cx = np.array(x.shape) // 2

    # Only an odd-sized dimension has a single central pixel. For an even
    # dimension the pixel at ``shape // 2`` is off-center (the array's flip
    # axis falls between pixels), so the trimmed sub-region computed below
    # must still be used.
    is_odd = x.shape[0] % 2 == 1 and x.shape[1] % 2 == 1
    if is_odd and py == cy and px == cx:
        return func(x, **kwargs)

    # Trim 2 * offset pixels from the side opposite the peak so that the
    # peak becomes the central pixel of the remaining sub-region; even
    # dimensions need one extra pixel removed to become odd.
    dy = int(round(2 * (py - cy)))
    dx = int(round(2 * (px - cx)))
    if not x.shape[0] % 2:
        dy += 1
    if not x.shape[1] % 2:
        dx += 1
    if dx < 0:
        xslice = slice(None, dx)
    else:
        xslice = slice(dx, None)
    if dy < 0:
        yslice = slice(None, dy)
    else:
        yslice = slice(dy, None)

    if fill is not None:
        _x = np.ones(x.shape, x.dtype) * fill
        _x[yslice, xslice] = func(x[yslice, xslice], **kwargs)
        x[:] = _x
    else:
        x[yslice, xslice] = func(x[yslice, xslice], **kwargs)

    return x
def prox_sdss_symmetry(x: np.ndarray):
    """SDSS/HSC symmetry operator

    Each pixel is replaced by the *minimum* of itself and its partner
    obtained by flipping the array along both of its first two axes.

    Parameters
    ----------
    x:
        The array to make symmetric. It is modified in place.

    Returns
    -------
    result:
        The same `x` object, after being made symmetric.
    """
    mirrored = np.flipud(np.fliplr(x))
    x[:] = np.minimum(x, mirrored)
    return x
def prox_uncentered_symmetry(
    x: np.ndarray,
    center: tuple[int, int] | None = None,
    fill: float | None = None,
) -> np.ndarray:
    """Symmetry with off-center peak

    Symmetrize `x` for all pixels with a symmetric partner about `center`,
    using `prox_sdss_symmetry` on the centered sub-region.

    Parameters
    ----------
    x:
        The parameter to update.
    center:
        The center pixel coordinates to apply the symmetry operator.
        When ``None``, `uncentered_operator` chooses the location of the
        maximum of `x`.
    fill:
        The value to fill the region that cannot be made symmetric.
        When `fill` is `None` then the region of `x` that is not symmetric
        is not constrained.

    Returns
    -------
    result:
        The updated array returned by `uncentered_operator`.
    """
    return uncentered_operator(x, func=prox_sdss_symmetry, center=center, fill=fill)