Coverage for python / lsst / scarlet / lite / component.py: 31%

206 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:28 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from copy import deepcopy 

24 

# Names exported by ``from lsst.scarlet.lite.component import *``.
__all__ = [
    "Component",
    "CubeComponent",
    "FactorizedComponent",
    "default_fista_parameterization",
    "default_adaprox_parameterization",
]

32 

33from abc import ABC, abstractmethod 

34from functools import partial 

35from typing import TYPE_CHECKING, Any, Callable, cast 

36 

37import numpy as np 

38 

39from .bbox import Box 

40from .image import Image 

41from .operators import Monotonicity, prox_uncentered_symmetry 

42from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step 

43from .utils import convert_indices 

44 

45if TYPE_CHECKING: 

46 from .io import ScarletComponentBaseData, ScarletCubeComponentData 

47 

48import logging 

49 

# Module-level logger, named after this module.
Logger: logging.Logger = logging.getLogger(__name__)

51 

52 

class Component(ABC):
    """Abstract base class for every component in scarlet lite.

    A component owns a set of bands and a bounding box. Subclasses provide
    the model, the optimization update, resizing, slicing, copying and
    persistence.

    Parameters
    ----------
    bands:
        The bands used when the component model is created.
    bbox: Box
        The bounding box for this component.
    """

    def __init__(
        self,
        bands: tuple,
        bbox: Box,
    ):
        self._bbox = bbox
        self._bands = bands

    @property
    def bbox(self) -> Box:
        """Bounding box of the component within the full image."""
        return self._bbox

    @property
    def bands(self) -> tuple:
        """Bands contained in the component model."""
        return self._bands

    @abstractmethod
    def resize(self, model_box: Box) -> bool:
        """Check whether the component must be resized.

        Subclasses override this and return `True` when the component
        needs to be resized.
        """

    @abstractmethod
    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Update the component parameters from an input gradient.

        Parameters
        ----------
        it:
            Current iteration of the optimizer.
        input_grad:
            Gradient of the likelihood with respect to the component model.
        """

    @abstractmethod
    def get_model(self) -> Image:
        """Build the model image of the component.

        Subclasses must implement this.

        Returns
        -------
        model: Image
            Image of the component model.
        """

    @abstractmethod
    def parameterize(self, parameterization: Callable) -> None:
        """Turn the component parameter arrays into `Parameter` instances.

        Parameters
        ----------
        parameterization: Callable
            Function that converts parameters of a given type into a
            `Parameter` in place. It takes a single argument, the
            `Component` or `Source` being parameterized.
        """

    @abstractmethod
    def to_data(self) -> ScarletComponentBaseData:
        """Convert the component into persistable ScarletComponentBaseData.

        Returns
        -------
        component_data: ScarletComponentBaseData
            Data object holding the component information.
        """

    @abstractmethod
    def __getitem__(self, indices: Any) -> Component:
        """Return the sub-component selected by ``indices``.

        Parameters
        ----------
        indices: Any
            Indices used to slice the component model.

        Returns
        -------
        sub_component: Component
            New component that is a sub-component of this one.

        Raises
        ------
        IndexError :
            If the index includes a ``Box`` or spatial indices.
        """

    @abstractmethod
    def __copy__(self) -> Component:
        """Return a shallow copy of this component.

        Returns
        -------
        component : Component
            New component that is a copy of this one.
        """

    @abstractmethod
    def __deepcopy__(self, memo: dict[int, Any]) -> Component:
        """Return a deep copy of this component.

        Returns
        -------
        component : Component
            New component that is a deep copy of this one.
        """

    def copy(self, deep: bool = False) -> Component:
        """Copy this component.

        Parameters
        ----------
        deep : bool, optional
            `True` requests a deep copy (via ``__deepcopy__`` with a fresh
            memo dict); `False` (the default) requests a shallow copy
            (via ``__copy__``).

        Returns
        -------
        component : Component
            New component that is a copy of this one.
        """
        return self.__deepcopy__({}) if deep else self.__copy__()

194 

class FactorizedComponent(Component):
    """A component that can be factorized into spectrum and morphology
    parameters.

    The model of the component is the outer product
    ``spectrum[:, None, None] * morph[None, :, :]``.

    Parameters
    ----------
    bands:
        The bands of the spectral dimension, in order.
    spectrum:
        The parameter to store and update the spectrum.
        The value is passed through `parameter` on construction.
    morph:
        The parameter to store and update the morphology.
        The value is passed through `parameter` on construction.
    bbox:
        The `Box` in the `model_bbox` that contains the source.
    peak:
        Location of the peak for the source, in the coordinates of the
        full model (not relative to ``bbox``).
    bg_rms:
        The RMS of the background used to threshold, grow,
        and shrink the component.
    bg_thresh:
        The threshold to use for the background RMS.
        If `None`, no background thresholding is applied, otherwise
        a sparsity constraint is applied to the morphology that
        requires flux in at least one band to be ``bg_thresh`` multiplied
        by `bg_rms` in that band.
    floor:
        Minimum value of the spectrum or center morphology pixel
        (depending on which is normalized).
    monotonicity:
        The monotonicity operator to use for making the source monotonic.
        If this parameter is `None`, the source will not be made monotonic.
    padding:
        The amount of padding to add to the component bounding box
        when resizing the component.
    is_symmetric:
        Whether the component is symmetric or not.
        If `True`, the morphology will be symmetrized using
        `prox_uncentered_symmetry`.
        If `False`, the morphology will not be symmetrized.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: Parameter | np.ndarray,
        morph: Parameter | np.ndarray,
        bbox: Box,
        peak: tuple[int, int] | None = None,
        bg_rms: np.ndarray | None = None,
        bg_thresh: float | None = 0.25,
        floor: float = 1e-20,
        monotonicity: Monotonicity | None = None,
        padding: int = 5,
        is_symmetric: bool = False,
    ):
        # Initialize all of the base attributes
        super().__init__(
            bands=bands,
            bbox=bbox,
        )
        self._spectrum = parameter(spectrum)
        self._morph = parameter(morph)
        self._peak = peak
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh

        self.floor = floor
        self.monotonicity = monotonicity
        self.padding = padding
        self.is_symmetric = is_symmetric

    @property
    def peak(self) -> tuple[int, int] | None:
        """The peak of the component

        Returns
        -------
        peak:
            The peak of the component, in full-model coordinates,
            or `None` if no peak was given.
        """
        return self._peak

    @property
    def component_center(self) -> tuple[int, int] | None:
        """The center of the component in its bounding box

        This differs from `peak`, which is the position of the component
        in the full model. `component_center` is the same position
        expressed relative to the origin of the component's bounding box.

        Returns
        -------
        center:
            The center of the component in its bounding box,
            or `None` if the component has no peak.
        """
        _center = self.peak
        if _center is None:
            return None
        center = (
            _center[0] - self.bbox.origin[-2],
            _center[1] - self.bbox.origin[-1],
        )
        return center

    @property
    def spectrum(self) -> np.ndarray:
        """The array of spectrum values"""
        return self._spectrum.x

    @property
    def morph(self) -> np.ndarray:
        """The array of morphology values"""
        return self._morph.x

    @property
    def shape(self) -> tuple:
        """Shape of the resulting model image"""
        return self.spectrum.shape + self.morph.shape

    def get_model(self) -> Image:
        """Build the model from the spectrum and morphology"""
        # The spectrum and morph might be Parameters,
        # so cast them as arrays in the model.
        spectrum = self.spectrum
        morph = self.morph
        model = spectrum[:, None, None] * morph[None, :, :]
        return Image(model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    def grad_spectrum(self, input_grad: np.ndarray, spectrum: np.ndarray, morph: np.ndarray):
        """Gradient of the spectrum wrt. the component model"""
        # Contract the two trailing (spatial) axes of the gradient with morph.
        return np.einsum("...jk,jk", input_grad, morph)

    def grad_morph(self, input_grad: np.ndarray, morph: np.ndarray, spectrum: np.ndarray):
        """Gradient of the morph wrt. the component model"""
        # Contract the leading (band) axis of the gradient with the spectrum.
        return np.einsum("i,i...", spectrum, input_grad)

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        Values below ``floor`` and non-finite values are replaced by
        ``floor``. The input array is modified in place and returned.
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        spectrum[~np.isfinite(spectrum)] = self.floor
        return spectrum

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        In order:

        1. Apply the optional monotonicity and symmetry operators.
        2. Apply background thresholding, or positivity if no threshold
           is available.
        3. Zero non-finite pixels.
        4. Floor the peak pixel.
        5. Normalize to a maximum of 1.

        The peak position is taken relative to the bounding box. If the
        component has no peak, the center of ``morph`` is used.
        """
        # Peak position in the current bbox, with a fallback to the center
        # of the morphology array when the component has no peak. The same
        # position is used by every operator below, so a missing peak never
        # reaches the monotonicity operator as `None`.
        peak = self.component_center
        if peak is None:
            shape = morph.shape
            peak = (shape[0] // 2, shape[1] // 2)

        # monotonicity
        if self.monotonicity is not None:
            morph = self.monotonicity(morph, peak)

        # symmetry
        if self.is_symmetric:
            # Apply the symmetry operator
            morph = prox_uncentered_symmetry(morph, peak, fill=0.0)

        if self.bg_thresh is not None and self.bg_rms is not None:
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero pixels that are below
            # the threshold in every band.
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        # Ensure that the morphology is finite. This runs before the peak
        # floor so that a NaN/inf peak pixel cannot undo the floor below.
        morph[~np.isfinite(morph)] = 0

        # prevent divergent morphology
        morph[peak] = np.max([morph[peak], self.floor])

        # Normalize the morphology
        max_value = np.max(morph)
        if max_value > 0:
            morph[:] = morph / max_value
        return morph

    def resize(self, model_box: Box) -> bool:
        """Test whether or not the component needs to be resized

        If it does, update the bounding box and resize the morphology.

        Parameters
        ----------
        model_box:
            The box of the full model. The new box is clipped to it.

        Returns
        -------
        resized:
            `True` if the bounding box changed, `False` otherwise
            (including when no thresholding is configured).
        """
        # No need to resize if there is no size threshold.
        # To allow box sizing but no thresholding use `bg_thresh=0`.
        if self.bg_thresh is None or self.bg_rms is None:
            return False

        model = self.spectrum[:, None, None] * self.morph[None, :, :]
        bg_thresh = self.bg_rms * self.bg_thresh
        significant = np.any(model >= bg_thresh[:, None, None], axis=0)
        if np.sum(significant) == 0:
            # There are no significant pixels,
            # so make a small box around the center
            # (falls back to (0, 0) when there is no peak).
            center = self.peak
            if center is None:
                center = (0, 0)
            new_box = Box((1, 1), center).grow(self.padding) & model_box
        else:
            # Box around the significant pixels, shifted from bbox-relative
            # to full-model coordinates, then padded and clipped.
            new_box = (
                Box.from_data(significant, threshold=0).grow(self.padding) + self.bbox.origin  # type: ignore
            ) & model_box
        if new_box == self.bbox:
            return False

        old_box = self.bbox
        self._bbox = new_box
        self._morph.resize(old_box, new_box)
        return True

    def update(self, it: int, input_grad: np.ndarray):
        """Update the spectrum and morphology parameters"""
        # Store the input spectrum so that the morphology can
        # have a consistent update
        spectrum = self.spectrum.copy()
        self._spectrum.update(it, input_grad, self.morph)
        self._morph.update(it, input_grad, spectrum)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the spectrum and morph in place
        parameterization(self)
        # Attach this component's gradient and prox functions to the
        # newly created parameters.
        self._spectrum.grad = self.grad_spectrum
        self._spectrum.prox = self.prox_spectrum
        self._morph.grad = self.grad_morph
        self._morph.prox = self.prox_morph

    def to_data(self) -> ScarletComponentBaseData:
        """Convert the component to persistable ScarletComponentBaseData

        Returns
        -------
        component_data: ScarletComponentBaseData
            The data object containing the component information
        """
        from .io import ScarletFactorizedComponentData

        return ScarletFactorizedComponentData(
            origin=self.bbox.origin,  # type: ignore
            peak=self.peak,  # type: ignore
            spectrum=self.spectrum,
            morph=self.morph,
        )

    def __str__(self):
        result = (
            f"FactorizedComponent<\n bands={self.bands},\n center={self.peak},\n "
            f"spectrum={self.spectrum},\n morph_shape={self.morph.shape}\n>"
        )
        return result

    def __repr__(self):
        return self.__str__()

    def __getitem__(self, indices: Any) -> FactorizedComponent:
        """Get a sub-component corresponding to the given indices.

        Only the spectrum is sliced. The morphology array and all other
        settings are passed to the new component as-is (not copied here).

        Parameters
        ----------
        indices: Any
            The indices to use to slice the component model.

        Returns
        -------
        component: FactorizedComponent
            A new component that is a sub-component of this one.

        Raises
        ------
        IndexError :
            If the index includes a ``Box`` or spatial indices.
        """
        # Convert the band indices into numerical indices
        band_indices = convert_indices(self.bands, indices)
        if isinstance(band_indices, slice):
            bands = self.bands[band_indices]
        else:
            bands = tuple(self.bands[i] for i in band_indices)

        # Slice the spectrum
        spectrum = self._spectrum.x[band_indices,]

        return FactorizedComponent(
            bands=bands,
            spectrum=spectrum,
            morph=self.morph,
            bbox=self.bbox,
            peak=self.peak,
            bg_rms=self.bg_rms,
            bg_thresh=self.bg_thresh,
            floor=self.floor,
            monotonicity=self.monotonicity,
            padding=self.padding,
            is_symmetric=self.is_symmetric,
        )

    def __deepcopy__(self, memo: dict[int, Any]) -> FactorizedComponent:
        """Create a deep copy of this component.

        Only the spectrum and morphology arrays are copied (not the
        `Parameter` objects that wrap them). NOTE(review): any optimizer
        parameterization is therefore presumably not preserved; confirm
        this is intended.

        Parameters
        ----------
        memo: dict[int, Any]
            The memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        component : FactorizedComponent
            A new component that is a deep copy of this one.
        """
        # Check if already copied
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST so that cyclic
        # references resolve to the new object.
        component = FactorizedComponent.__new__(FactorizedComponent)
        memo[id(self)] = component

        # Now safely initialize the placeholder with deepcopied arguments
        component.__init__(  # type: ignore[misc]
            bands=deepcopy(self.bands, memo),
            spectrum=deepcopy(self.spectrum, memo),
            morph=deepcopy(self.morph, memo),
            bbox=deepcopy(self.bbox, memo),
            peak=deepcopy(self.peak, memo),
            bg_rms=deepcopy(self.bg_rms, memo),
            bg_thresh=self.bg_thresh,
            floor=self.floor,
            monotonicity=deepcopy(self.monotonicity, memo),
            padding=self.padding,
            is_symmetric=self.is_symmetric,
        )
        return component

    def __copy__(self) -> FactorizedComponent:
        """Create a copy of this component.

        Returns
        -------
        component : FactorizedComponent
            A new component that is a shallow copy of this one.
        """
        return FactorizedComponent(
            bands=self.bands,
            spectrum=self.spectrum,
            morph=self.morph,
            bbox=self.bbox,
            peak=self.peak,
            bg_rms=self.bg_rms,
            bg_thresh=self.bg_thresh,
            floor=self.floor,
            monotonicity=self.monotonicity,
            padding=self.padding,
            is_symmetric=self.is_symmetric,
        )

563 

564 

class CubeComponent(Component):
    """Dummy component for a component cube.

    This is duck-typed to a `lsst.scarlet.lite.Component` in order to
    generate a model from the component but it is currently not functional
    in that it cannot be optimized, only persisted and loaded.

    If scarlet lite ever implements a component as a data cube,
    this class can be removed.
    """

    def __init__(self, model: Image, peak: tuple[int, int]):
        """Initialization

        Parameters
        ----------
        model :
            The 3D (bands, y, x) model of the component. Its ``bands`` and
            ``bbox`` are used as the bands and bounding box of the component.
        peak :
            The `(y, x)` peak of the component.
        """
        super().__init__(model.bands, model.bbox)
        self._model = model
        self.peak = peak

    def get_model(self) -> Image:
        """Generate the model for the source

        Returns
        -------
        model :
            The stored model image, as a 3D `(band, y, x)` array.
        """
        return self._model

    def resize(self, model_box: Box) -> bool:
        """Resizing is not supported: log a warning and return `False`."""
        Logger.warning("CubeComponent does not support resizing")
        return False

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Implementation of unused abstract method (logs a warning only)."""
        Logger.warning("CubeComponent does not support updates")

    def parameterize(self, parameterization: Callable) -> None:
        """Implementation of unused abstract method (logs a warning only)."""
        Logger.warning("CubeComponent does not support parameterization")

    def to_data(self) -> ScarletCubeComponentData:
        """Convert the component to persistable ScarletCubeComponentData

        Returns
        -------
        component_data: ScarletCubeComponentData
            The data object containing the component information
        """
        from .io import ScarletCubeComponentData

        return ScarletCubeComponentData(
            origin=self.bbox.origin,  # type: ignore
            peak=self.peak,  # type: ignore
            model=self.get_model().data,
        )

    def __getitem__(self, indices: Any) -> CubeComponent:
        """Get a sub-component corresponding to the given indices.

        Parameters
        ----------
        indices :
            The band indices to select.

        Returns
        -------
        sub_component :
            A new component that is a sub-component of this one.
        """
        band_indices = convert_indices(self.bands, indices)
        if isinstance(band_indices, slice):
            bands = self.bands[band_indices]
        else:
            bands = tuple(self.bands[i] for i in band_indices)

        # Use the public `data` accessor (as `to_data` does) rather than
        # reaching into the private `Image._data` attribute.
        data = self.get_model().data[band_indices,]
        model = Image(data=data, bands=bands, yx0=cast(tuple[int, int], self.bbox.origin))
        return CubeComponent(model=model, peak=self.peak)

    def __copy__(self) -> CubeComponent:
        """Create a copy of this component.

        Returns
        -------
        component : CubeComponent
            A new component that is a shallow copy of this one
            (it shares the same model image).
        """
        return CubeComponent(model=self._model, peak=self.peak)

    def __deepcopy__(self, memo: dict[int, Any]) -> CubeComponent:
        """Create a deep copy of this component.

        Parameters
        ----------
        memo: dict[int, Any]
            The memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        component : CubeComponent
            A new component built from ``self._model.copy()``.
        """
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST
        component = CubeComponent.__new__(CubeComponent)
        memo[id(self)] = component

        # Now safely initialize the placeholder with a copy of the model
        component.__init__(  # type: ignore[misc]
            model=self._model.copy(),
            peak=self.peak,
        )
        return component

690 

691 

def default_fista_parameterization(component: Component):
    """Set up a factorized component for FISTA PGM optimization.

    The spectrum and morphology of ``component`` are each replaced, in
    place, by a `FistaParameter` with a fixed step of 0.5.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")

    fista_step = 0.5
    component._spectrum = FistaParameter(component.spectrum, step=fista_step)
    component._morph = FistaParameter(component.morph, step=fista_step)

699 

700 

def default_adaprox_parameterization(component: Component, noise_rms: float | None = None):
    """Set up a factorized component for Proximal ADAM optimization.

    The spectrum gets an `AdaproxParameter` whose step is
    `relative_step` with ``factor=1e-2`` and a minimum of ``noise_rms``.
    The morphology gets an `AdaproxParameter` with a fixed step of 1e-2.

    Parameters
    ----------
    component:
        The component to parameterize in place.
    noise_rms:
        Minimum step for the spectrum. `None` means 1e-16.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    minimum_step = 1e-16 if noise_rms is None else noise_rms
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")

    spectrum_step = partial(relative_step, factor=1e-2, minimum=minimum_step)
    component._spectrum = AdaproxParameter(component.spectrum, step=spectrum_step)
    component._morph = AdaproxParameter(component.morph, step=1e-2)