Coverage for python/lsst/scarlet/lite/operators.py: 13%

156 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-07 11:26 +0000

1from typing import Callable, Sequence, cast 

2 

3import numpy as np 

4import numpy.typing as npt 

5from lsst.scarlet.lite.detect_pybind11 import get_connected_pixels # type: ignore 

6from lsst.scarlet.lite.operators_pybind11 import new_monotonicity # type: ignore 

7 

8from .bbox import Box 

9 

10 

def prox_connected(morph: np.ndarray, centers: Sequence[Sequence[int]]) -> np.ndarray:
    """Zero out every pixel that is not connected to a source center.

    Parameters
    ----------
    morph:
        The morphology that is being constrained.
    centers:
        The `(cy, cx)` center of each source that pixels must be
        connected to in order to be kept.

    Returns
    -------
    result:
        The morphology with all pixels that are not connected to a center
        position set to zero.
    """
    # Accumulates the connected pixels over all of the centers
    connected = np.zeros(morph.shape, dtype=bool)

    for cy, cx in centers:
        row, col = int(cy), int(cx)
        # Every center starts its search with a fresh "unchecked" mask
        remaining = np.ones(morph.shape, dtype=bool)
        footprint = np.array([row, row, col, col]).astype(np.int32)
        # ``connected`` is updated in place with the pixels connected
        # to this center
        get_connected_pixels(row, col, morph, remaining, connected, footprint, 0)

    return connected * morph

40 

41 

class Monotonicity:
    """Callable operator that makes an image radially monotonic

    Monotonicity is applied as a pseudo proximal operator
    (actually a projection operator) to *a* radially monotonic solution.

    Notes
    -----
    This differs from monotonicity in the main scarlet branch because
    a single operator stores the weights for all of the pixels up to
    the size of the largest shape expected, so it only needs to be
    created once _per blend_, as opposed to once _per source_.
    The instance is then called with the source morphology to make
    monotonic and the location of the "center" of the image,
    and the full weight matrix is sliced accordingly.

    Parameters
    ----------
    shape:
        The shape of the full operator.
        This must be larger than the largest possible object size
        in the blend.
    dtype:
        The numpy ``dtype`` of the output image.
    auto_update:
        If ``True`` the operator will update its shape if an image is
        too big to fit in the current operator.
    fit_radius:
        Pixels within `fit_radius` of the center of the array to make
        monotonic are checked to see if they have more flux than the center
        pixel. If they do, the pixel with larger flux is used as the center.
    """

    def __init__(
        self,
        shape: tuple[int, int],
        dtype: npt.DTypeLike = float,
        auto_update: bool = True,
        fit_radius: int = 1,
    ):
        # Declare the attributes that ``update`` fills in
        self.weights: np.ndarray | None = None
        self.distance: np.ndarray | None = None
        self.sizes: tuple[int, int, int, int] | None = None
        self.dtype = dtype
        self.auto_update = auto_update
        self.fit_radius = fit_radius
        self.update(shape)

    @property
    def shape(self) -> tuple[int, int]:
        """The 2D shape of the largest component that can be made monotonic

        Returns
        -------
        result:
            The shape of the operator.
        """
        weights = cast(np.ndarray, self.weights)
        # The leading axis indexes the 9 neighbor planes, so drop it
        return cast(tuple[int, int], weights.shape[1:])

    @property
    def center(self) -> tuple[int, int]:
        """The center of the full operator

        Returns
        -------
        result:
            The ``(cy, cx)`` center pixel of the full operator.
        """
        height, width = self.shape
        return (height - 1) // 2, (width - 1) // 2

    @staticmethod
    def _neighbor_differences(values: np.ndarray, dtype: npt.DTypeLike) -> np.ndarray:
        """Subtract each of the 8 nearest neighbors from every pixel

        Parameters
        ----------
        values:
            The 2D array to difference.
        dtype:
            The numpy ``dtype`` of the result.

        Returns
        -------
        result:
            Array with shape ``(9,) + values.shape``. Plane
            ``3 * (dy + 1) + (dx + 1)`` contains ``values[y, x]`` minus
            ``values[y + dy, x + dx]`` for ``dy, dx`` in ``(-1, 0, 1)``.
            Entries whose neighbor falls outside of the array,
            and the entire central plane (``dy == dx == 0``), are zero.
        """
        # For each offset: (slice selecting the pixels that have such a
        # neighbor, slice selecting those neighbors)
        offset_slices = {
            -1: (slice(1, None), slice(None, -1)),
            0: (slice(None), slice(None)),
            1: (slice(None, -1), slice(1, None)),
        }
        differences = np.zeros((9,) + values.shape, dtype=dtype)
        for dy in (-1, 0, 1):
            pixel_y, neighbor_y = offset_slices[dy]
            for dx in (-1, 0, 1):
                if dy == 0 and dx == 0:
                    # The central plane is filled in by the caller
                    continue
                pixel_x, neighbor_x = offset_slices[dx]
                plane = 3 * (dy + 1) + (dx + 1)
                differences[plane, pixel_y, pixel_x] = values[pixel_y, pixel_x] - values[neighbor_y, neighbor_x]
        return differences

    def update(self, shape: tuple[int, int]):
        """Update the operator with a new shape

        Parameters
        ----------
        shape:
            The new shape. Both dimensions must be odd.

        Raises
        ------
        ValueError:
            Raised when `shape` is not 2D or has an even dimension.
        """
        if len(shape) != 2:
            msg = f"Monotonicity is a 2D operator but received shape with {len(shape)} dimensions"
            raise ValueError(msg)
        if shape[0] % 2 == 0 or shape[1] % 2 == 0:
            raise ValueError(f"The shape must be odd, got {shape}")

        # Use the center of the operator as the center
        # and calculate the distance to each pixel from the center
        cy = (shape[0] - 1) // 2
        cx = (shape[1] - 1) // 2
        x_offsets = np.arange(shape[1], dtype=self.dtype) - cx
        y_offsets = np.arange(shape[0], dtype=self.dtype) - cy
        x_grid, y_grid = np.meshgrid(x_offsets, y_offsets)
        distance = np.sqrt(x_grid**2 + y_grid**2)

        # Difference in distance to the center between each pixel
        # and its 8 nearest neighbors
        neighbor_dist = self._neighbor_differences(distance, self.dtype)
        # For the center pixel, set the distance to 1 just so that it is
        # non-zero
        neighbor_dist[4, cy, cx] = 1

        # Difference in angle to the center between each pixel
        # and its 8 nearest neighbors
        angles = np.arctan2(y_grid, x_grid)
        angle_diff = self._neighbor_differences(angles, self.dtype)
        # In the central plane only the center pixel gets a zero angle
        # (a cosine of one), which is used as its weight.
        angle_diff[4] = 1
        angle_diff[4, cy, cx] = 0

        # Use cos(theta) to set the weights, then normalize.
        # This gives more weight to neighboring pixels that are more closely
        # aligned with the vector pointing toward the center.
        weights = np.cos(angle_diff)
        # Neighbors that are not closer to the center get no weight
        weights[neighbor_dist <= 0] = 0
        # Adjust for the discontinuity at theta = 2pi
        negative = weights < 0
        weights[negative] = -weights[negative]
        weights = weights / np.sum(weights, axis=0)[None, :, :]

        # Store the parameters needed later
        self.weights = weights
        self.distance = distance
        self.sizes = (cy, cx, shape[0] - cy, shape[1] - cx)

    def check_size(self, shape: tuple[int, int], center: tuple[int, int], update: bool = True):
        """Check to see if the operator can be applied

        Parameters
        ----------
        shape:
            The shape of the image to apply monotonicity.
        center:
            The location (in `shape`) of the point where the monotonicity will
            be taken from.
        update:
            When ``True`` the operator will update itself so that an image
            with shape `shape` can be made monotonic about the `center`.

        Raises
        ------
        ValueError:
            Raised when an array with shape `shape` does not fit in the
            current operator and `update` is `False`.
        """
        # Extent of the image on each side of ``center``, in the same order
        # as ``self.sizes``
        required = np.array((*center, shape[0] - center[0], shape[1] - center[1]))
        if not np.any(required > self.sizes):
            return
        if not update:
            raise ValueError(f"Cannot apply monotonicity to image with shape {shape} at {center}")
        # Grow to a square, odd-sized operator that covers the largest extent
        size = 2 * np.max(required) + 1
        self.update((size, size))

    def __call__(self, image: np.ndarray, center: tuple[int, int]) -> np.ndarray:
        """Make an input image monotonic about a center pixel

        Parameters
        ----------
        image:
            The image to make monotonic.
        center:
            The ``(y, x)`` location _in image coordinates_ to make the
            center of the monotonic region.

        Returns
        -------
        result:
            The input image is updated in place, but also returned from this
            method.
        """
        # Check for a better center
        center = get_peak(image, center, self.fit_radius)

        # Check that the operator can fit the image
        self.check_size(cast(tuple[int, int], image.shape), center, self.auto_update)

        # Create the bounding box to slice the weights and distance as needed
        op_y, op_x = self.center
        peak_y, peak_x = center
        bbox = Box((9,) + image.shape, origin=(0, op_y - peak_y, op_x - peak_x))
        weights = cast(np.ndarray, self.weights)[bbox.slices]
        distance = cast(np.ndarray, self.distance)[bbox.slices[1:]]
        # Pixel coordinates ordered by increasing distance from the center
        order = np.argsort(distance.flatten())
        coords = np.unravel_index(order, image.shape)

        # Pad the image by 1 so that we don't have to worry about
        # weights on the edges.
        padded = np.zeros((image.shape[0] + 2, image.shape[1] + 2), dtype=image.dtype)
        padded[1:-1, 1:-1] = image
        new_monotonicity(coords[0], coords[1], list(weights), padded)
        image[:] = padded[1:-1, 1:-1]
        return image

253 

254 

def get_peak(image: np.ndarray, center: tuple[int, int], radius: int = 1) -> tuple[int, int]:
    """Search around a location for the maximum flux

    For monotonicity it is important to start at the brightest pixel
    in the center of the source. This may be off by a pixel or two,
    so we search for the correct center before applying monotonicity.

    Parameters
    ----------
    image:
        The image of the source.
    center:
        The suggested ``(y, x)`` center of the source.
        Each coordinate is truncated to an integer.
    radius:
        The number of pixels around the `center` to search
        for a higher flux value.

    Returns
    -------
    new_center:
        The ``(y, x)`` location, in `image` coordinates and as python
        integers, of the maximum pixel in the search window.
        When several pixels share the maximum, the first one in
        row-major order is returned.
    """
    cy, cx = int(center[0]), int(center[1])
    # Clamp the lower edge of the window to the image; the upper edge is
    # clamped implicitly by slicing.
    y0 = max(cy - radius, 0)
    x0 = max(cx - radius, 0)
    window = image[y0 : cy + radius + 1, x0 : cx + radius + 1]
    peak = np.unravel_index(np.argmax(window), window.shape)
    # Shift back to image coordinates and convert the numpy scalars
    # to the python ints promised by the signature.
    return int(peak[0] + y0), int(peak[1] + x0)

286 

287 

def prox_monotonic_mask(
    x: np.ndarray,
    center: tuple[int, int],
    center_radius: int = 1,
    variance: float = 0.0,
    max_iter: int = 3,
) -> tuple[np.ndarray, np.ndarray, tuple[int, int, int, int]]:
    """Apply monotonicity from any path from the center

    Parameters
    ----------
    x:
        The input image that the mask is created for.
    center:
        The location of the center of the mask.
    center_radius:
        Radius from the center pixel to search for a better center
        (ie. a pixel in `x` with higher flux than the pixel given by
        `center`).
        If `center_radius == 0` then the `center` pixel is assumed
        to be correct.
    variance:
        The average variance in the image.
        This is used to allow pixels to be non-monotonic up to `variance`,
        so setting `variance=0` will force strict monotonicity in the mask.
    max_iter:
        Maximum number of iterations to interpolate non-monotonic pixels.

    Returns
    -------
    valid:
        Boolean array of pixels that are monotonic.
    model:
        The model with invalid pixels masked out.
    bounds:
        The bounds of the valid monotonic pixels.
    """
    from lsst.scarlet.lite.operators_pybind11 import (
        get_valid_monotonic_pixels,
        linear_interpolate_invalid_pixels,
    )

    # Pick the seed pixel: the brightest one near ``center``,
    # or ``center`` itself (rounded) when no search was requested.
    if center_radius > 0:
        seed_y, seed_x = get_peak(x, center, center_radius)
    else:
        seed_y, seed_x = int(np.round(center[0])), int(np.round(center[1]))

    unchecked = np.ones(x.shape, dtype=bool)
    unchecked[seed_y, seed_x] = False
    orphans = np.zeros(x.shape, dtype=bool)
    # This is the bounding box of the result
    bounds = np.array([seed_y, seed_y, seed_x, seed_x], dtype=np.int32)
    # Get all of the monotonic pixels
    get_valid_monotonic_pixels(seed_y, seed_x, x, unchecked, orphans, variance, bounds, 0)
    # Set the initial model to the exact input in the valid pixels
    model = x.copy()

    # Interpolate while some orphans are still unchecked,
    # for at most ``max_iter`` passes.
    passes = 0
    while np.any(orphans & unchecked) and passes < max_iter:
        passes += 1
        orphan_rows, orphan_cols = np.where(orphans)
        linear_interpolate_invalid_pixels(
            orphan_rows, orphan_cols, unchecked, model, orphans, variance, True, bounds
        )

    # A pixel is valid when it has been checked and is not an orphan
    valid = ~unchecked & ~orphans
    # Clear all of the invalid pixels from the input image
    model = model * valid
    return valid, model, tuple(bounds)  # type: ignore

354 

355 

356def uncentered_operator( 

357 x: np.ndarray, 

358 func: Callable, 

359 center: tuple[int, int] | None = None, 

360 fill: float | None = None, 

361 **kwargs, 

362) -> np.ndarray: 

363 """Only apply the operator on a centered patch 

364 

365 In some cases, for example symmetry, an operator might not make 

366 sense outside of a centered box. This operator only updates 

367 the portion of `X` inside the centered region. 

368 

369 Parameters 

370 ---------- 

371 x: 

372 The parameter to update. 

373 func: 

374 The function (or operator) to apply to `x`. 

375 center: 

376 The location of the center of the sub-region to 

377 apply `func` to `x`. 

378 fill: 

379 The value to fill the region outside of centered 

380 `sub-region`, for example `0`. If `fill` is `None` 

381 then only the subregion is updated and the rest of 

382 `x` remains unchanged. 

383 

384 Returns 

385 ------- 

386 result: 

387 `x`, with an operator applied based on the shifted center. 

388 """ 

389 if center is None: 

390 py, px = cast(tuple[int, int], np.unravel_index(np.argmax(x), x.shape)) 

391 else: 

392 py, px = center 

393 cy, cx = np.array(x.shape) // 2 

394 

395 if py == cy and px == cx: 

396 return func(x, **kwargs) 

397 

398 dy = int(2 * (py - cy)) 

399 dx = int(2 * (px - cx)) 

400 if not x.shape[0] % 2: 

401 dy += 1 

402 if not x.shape[1] % 2: 

403 dx += 1 

404 if dx < 0: 

405 xslice = slice(None, dx) 

406 else: 

407 xslice = slice(dx, None) 

408 if dy < 0: 

409 yslice = slice(None, dy) 

410 else: 

411 yslice = slice(dy, None) 

412 

413 if fill is not None: 

414 _x = np.ones(x.shape, x.dtype) * fill 

415 _x[yslice, xslice] = func(x[yslice, xslice], **kwargs) 

416 x[:] = _x 

417 else: 

418 x[yslice, xslice] = func(x[yslice, xslice], **kwargs) 

419 

420 return x 

421 

422 

def prox_sdss_symmetry(x: np.ndarray):
    """SDSS/HSC symmetry operator

    Each pixel is replaced with the *minimum* of itself and the pixel
    obtained by flipping `x` along both of its first two axes.

    Parameters
    ----------
    x:
        The array to make symmetric. It is modified in place.

    Returns
    -------
    result:
        The updated `x`.
    """
    mirrored = np.flip(x, axis=(0, 1))
    # ``mirrored`` is a view of ``x``, so build the minimum in a new array
    # before writing it back.
    x[:] = np.minimum(x, mirrored)
    return x

442 

443 

def prox_uncentered_symmetry(
    x: np.ndarray,
    center: tuple[int, int] | None = None,
    fill: float | None = None,
) -> np.ndarray:
    """Symmetry with off-center peak

    Symmetrize `x` for all pixels with a symmetric partner by applying
    `prox_sdss_symmetry` through `uncentered_operator`.

    Parameters
    ----------
    x:
        The parameter to update.
    center:
        The center pixel coordinates to apply the symmetry operator.
    fill:
        The value to fill the region that cannot be made symmetric.
        When `fill` is `None` then the region of `x` that is not symmetric
        is not constrained.

    Returns
    -------
    result:
        The array returned by `uncentered_operator`.
    """
    return uncentered_operator(x, func=prox_sdss_symmetry, center=center, fill=fill)