Coverage for python/lsst/scarlet/lite/initialization.py: 12%

194 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-07 11:26 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23from abc import ABC, abstractmethod 

24from typing import Sequence, cast 

25 

26import numpy as np 

27 

28from .bbox import Box 

29from .component import FactorizedComponent 

30from .detect import bounds_to_bbox, get_detect_wavelets 

31from .image import Image 

32from .measure import calculate_snr 

33from .observation import Observation 

34from .operators import Monotonicity, prox_monotonic_mask, prox_uncentered_symmetry 

35from .source import Source 

36 

# Module-level logger; used below to report components that are dropped
# during wavelet initialization.
logger = logging.getLogger("scarlet.lite.initialization")

38 

39 

def trim_morphology(
    morph: np.ndarray,
    bg_thresh: float = 0,
    padding: int = 5,
) -> tuple[np.ndarray, Box]:
    """Zero out faint pixels of a morphology and find its padded bounding box

    Parameters
    ----------
    morph:
        The morphology to be trimmed. It is modified in place: every pixel
        that is not strictly greater than `bg_thresh` is set to zero.
    bg_thresh:
        Pixels must be above this value to be kept.
    padding:
        The number of pixels added to each side of the box so that the
        source has room to grow.

    Returns
    -------
    morph:
        The trimmed morphology (the same array object that was passed in).
    box:
        The box that `Box.from_data` finds for the trimmed morphology with a
        zero threshold, grown by `padding` on every side.
    """
    # Written as ``~(morph > bg_thresh)`` rather than ``morph <= bg_thresh``
    # so that NaN pixels are zeroed as well.
    faint = ~(morph > bg_thresh)
    morph[faint] = 0
    box = Box.from_data(morph, threshold=0)
    return morph, box.grow(padding)

68 

69 

def init_monotonic_morph(
    detect: np.ndarray,
    center: tuple[int, int],
    full_box: Box,
    padding: int = 5,
    normalize: bool = True,
    monotonicity: Monotonicity | None = None,
    thresh: float = 0,
) -> tuple[Box, np.ndarray | None]:
    """Initialize a morphology for a monotonic source

    Parameters
    ----------
    detect:
        The 2D detection image contained in `full_box`.
    center:
        The center of the monotonic source, in the coordinates of the full
        image (the origin of `full_box` is subtracted before the operators
        are applied).
    full_box:
        The bounding box of `detect`.
    padding:
        The number of pixels to grow the morphology in each direction.
        This can be useful if initializing a source with a kernel that
        is known to be narrower than the expected value of the source.
    normalize:
        Whether or not to normalize the morphology to a peak of one.
    monotonicity:
        When `monotonicity` is `None`, the component is initialized with
        only the monotonic pixels (`prox_monotonic_mask`), otherwise the
        monotonicity operator is used to project the morphology to a
        monotonic solution.
    thresh:
        Pixels that are not above this (absolute) value are zeroed by
        `trim_morphology`. When `monotonicity` is `None` trimming is only
        performed if `thresh > 0`.

    Returns
    -------
    bbox:
        The bounding box of the morphology, clipped to `full_box`.
    morph:
        The initialized morphology, with the same shape as `detect`,
        or `None` if no morphology could be found.
    """
    origin = full_box.origin
    # The operators work in the local frame of ``detect``.
    local_center = (center[0] - origin[0], center[1] - origin[1])

    if monotonicity is not None:
        morph = monotonicity(detect, local_center)
        # Truncate the projected morphology at the threshold.
        # NOTE(review): trim_morphology already grows the box by
        # ``padding`` and the box is grown by ``padding`` again below,
        # so the padding is effectively applied twice here. Confirm intent.
        morph, bbox = trim_morphology(morph, bg_thresh=thresh, padding=padding)
    else:
        _, morph, bounds = prox_monotonic_mask(detect, local_center, max_iter=0)
        bbox = bounds_to_bbox(bounds)
        # A single empty pixel means there is no monotonic region at all.
        if bbox.shape == (1, 1) and morph[bbox.slices][0, 0] == 0:
            return Box((0, 0)), None
        if thresh > 0:
            morph, bbox = trim_morphology(morph, bg_thresh=thresh, padding=padding)

    # Move the box from the local frame back into full-image coordinates.
    bbox += origin

    peak = np.max(morph)
    if peak == 0:
        return Box((0, 0), origin=origin), None

    if normalize:
        morph /= peak

    if padding is not None and padding > 0:
        # Leave room for the morphology to grow during fitting.
        bbox = bbox.grow(padding)

    # Keep the (possibly padded) box inside the full box.
    return bbox & full_box, morph

147 

148 

def multifit_spectra(
    observation: Observation,
    morphs: Sequence[Image],
    model: Image | None = None,
) -> np.ndarray:
    """Fit the spectra of several components at once by linear least squares

    Each morphology is inserted into an image covering the full observation,
    repeated in every band and passed through `Observation.convolve`. The
    amplitude of every component is then solved independently in each band.

    Parameters
    ----------
    observation:
        The class containing the observation data.
    morphs:
        The morphology of each component.
    model:
        An optional model for sources that are not factorized, and thus
        will not have their spectra fit. It is subtracted from the observed
        images before the other spectra are fit.

    Returns
    -------
    spectra:
        Array of shape ``(len(morphs), n_bands)`` with the spectrum of each
        component, in the same order as `morphs`. Negative amplitudes are
        clipped to zero.
    """
    bands = observation.bands
    n_bands = len(bands)
    n_morphs = len(morphs)
    dtype = observation.images.dtype

    # The data to fit, with any fixed model removed.
    if model is None:
        image = observation.images.copy()
    else:
        image = observation.images - model
    n_pixels = image.data[0].size

    # design[b, k] holds the flattened, convolved model of component k in band b.
    design = np.zeros((n_bands, n_morphs, n_pixels), dtype=dtype)
    for k, morph in enumerate(morphs):
        full_image = Image.from_box(image.bbox, bands=image.bands).insert(morph.repeat(bands))
        design[:, k] = observation.convolve(full_image).data.reshape(n_bands, -1)

    spectra = np.zeros((n_morphs, n_bands), dtype=dtype)
    for b, band in enumerate(bands):
        # One column per component.
        columns = np.vstack(design[b]).T
        target = image[band].data.flatten()
        spectra[:, b] = np.linalg.lstsq(columns, target, rcond=None)[0]

    spectra[spectra < 0] = 0
    return spectra

195 

196 

class FactorizedInitialization(ABC):
    """Common variables and methods for both Factorized Component schemes

    Subclasses implement `init_source`, which is called once for every
    center while the base class is being constructed. Any attributes that
    `init_source` reads must therefore be set before calling
    ``super().__init__``.

    Parameters
    ----------
    observation:
        The observation containing the blend
    convolved:
        A multiband image that the subclasses build from their detection
        image with `Observation.convolve`. Its value at a source center is
        used to normalize the spectrum of single component sources.
    centers:
        The center of each source to initialize.
    min_snr:
        The minimum SNR required per component.
        So a 2-component source requires at least `2*min_snr` while sources
        with SNR < `min_snr` will be initialized with the PSF.
    monotonicity:
        When `monotonicity` is `None`,
        the component is initialized with only the
        monotonic pixels, otherwise the monotonicity operator is used to
        project the morphology to a monotonic solution.
    use_sparse_init:
        Use a monotonic mask to prevent initial source models from growing
        too large.

    Attributes
    ----------
    sources:
        The result of `init_source` for each center, in the same order as
        `centers`. Entries may be `None`.
    """

    def __init__(
        self,
        observation: Observation,
        convolved: Image,
        centers: Sequence[tuple[int, int]],
        min_snr: float = 50,
        monotonicity: Monotonicity | None = None,
        use_sparse_init: bool = True,
    ):
        self.observation = observation
        self.convolved = convolved
        self.centers = centers
        self.min_snr = min_snr
        self.monotonicity = monotonicity
        self.use_sparse_init = use_sparse_init

        # Get the model PSF and convolve it in every band in order to set
        # the spectrum of a point source correctly.
        model_psf = Image(cast(np.ndarray, observation.model_psf)[0])
        psf_cube = model_psf.repeat(observation.bands)
        self.convolved_psf = observation.convolve(psf_cube, mode="real").data
        # The "spectrum" of the PSF is its value at the central pixel.
        self.py = model_psf.shape[0] // 2
        self.px = model_psf.shape[1] // 2
        self.psf_spectrum = self.convolved_psf[:, self.py, self.px]

        # Initialize all of the sources
        self.sources = [self.init_source((int(center[0]), int(center[1]))) for center in centers]

    def get_snr(self, center: tuple[int, int]) -> float:
        """Get the SNR at the center of a source in units of `min_snr`

        Parameters
        ----------
        center:
            The location of the center of the source.

        Returns
        -------
        result:
            ``floor(snr) / min_snr``, where ``snr`` is computed by
            `calculate_snr` from the observed images, variance and PSFs.
            Subclasses compare this value against 1 and 2 to choose how many
            components a source is initialized with.
        """
        snr = np.floor(
            calculate_snr(
                self.observation.images,
                self.observation.variance,
                self.observation.psfs,
                center,
            )
        )
        return snr / self.min_snr

    def get_psf_component(self, center: tuple[int, int]) -> FactorizedComponent:
        """Create a factorized component with a PSF morphology

        Parameters
        ----------
        center:
            The center of the component.

        Returns
        -------
        component:
            A `FactorizedComponent` with a PSF-like morphology. Its box is
            the PSF box centered on `center`, clipped to the observation.
        """
        local_center = (
            center[0] - self.observation.bbox.origin[0],
            center[1] - self.observation.bbox.origin[1],
        )
        # There wasn't sufficient flux for an extended source,
        # so create a PSF source.
        spectrum_center = (slice(None), local_center[0], local_center[1])
        spectrum = self.observation.images.data[spectrum_center] / self.psf_spectrum
        spectrum[spectrum < 0] = 0

        psf = cast(np.ndarray, self.observation.model_psf)[0].copy()
        py = psf.shape[0] // 2
        px = psf.shape[1] // 2
        # The full PSF box, centered on the source.
        psf_box = Box(psf.shape, origin=(center[0] - py, center[1] - px))
        # The component box must lie inside the observation.
        bbox = self.observation.bbox & psf_box
        # Anchor the PSF image at its *unclipped* origin so that cropping to
        # ``bbox`` selects the correct pixels even when the PSF overlaps the
        # lower edge of the observation (where the clipped origin differs).
        morph = Image(psf, yx0=cast(tuple[int, int], psf_box.origin))[bbox].data
        component = FactorizedComponent(
            self.observation.bands,
            spectrum,
            morph,
            bbox,
            center,
            self.observation.noise_rms,
            monotonicity=self.monotonicity,
        )
        return component

    def get_single_component(
        self,
        center: tuple[int, int],
        detect: np.ndarray,
        thresh: float,
        padding: int,
    ) -> FactorizedComponent | None:
        """Initialize parameters for a `FactorizedComponent`

        Parameters
        ----------
        center:
            The location of the center of the source to detect in the
            full image.
        detect:
            The image used for detection of the morphology.
        thresh:
            The lower cutoff threshold to use for the morphology.
        padding:
            The amount to pad the morphology to allow for extra flux
            in the first few iterations before resizing.

        Returns
        -------
        component:
            A `FactorizedComponent` created from the detection image, with
            its morphology normalized to a peak of one, or `None` if
            `init_monotonic_morph` could not find a morphology.
        """
        local_center = (
            center[0] - self.observation.bbox.origin[0],
            center[1] - self.observation.bbox.origin[1],
        )

        # With sparse initialization only the monotonic pixels are used.
        if self.use_sparse_init:
            monotonicity = None
        else:
            monotonicity = self.monotonicity
        bbox, morph = init_monotonic_morph(
            detect,
            center,
            self.observation.bbox,
            padding=padding,
            normalize=False,
            monotonicity=monotonicity,
            thresh=thresh,
        )

        if morph is None:
            return None
        # Crop the full-frame morphology to its (full-image) bounding box.
        morph = morph[(bbox - self.observation.bbox.origin).slices]

        # The spectrum is the ratio of the observed flux to the convolved
        # detection image at the center, rescaled so that the morphology
        # can be normalized to a peak of one.
        spectrum_center = (slice(None), local_center[0], local_center[1])
        images = self.observation.images

        convolved = self.convolved
        spectrum = images.data[spectrum_center] / convolved.data[spectrum_center]
        spectrum[spectrum < 0] = 0
        morph_max = np.max(morph)
        spectrum *= morph_max
        morph /= morph_max

        return FactorizedComponent(
            self.observation.bands,
            spectrum,
            morph,
            bbox,
            center,
            self.observation.noise_rms,
            monotonicity=self.monotonicity,
        )

    @abstractmethod
    def init_source(self, center: tuple[int, int]) -> Source | None:
        """Initialize a source

        Parameters
        ----------
        center:
            The center of the source.

        Returns
        -------
        source:
            The initialized source, or `None` if it could not be initialized.
        """

397 

398 

class FactorizedChi2Initialization(FactorizedInitialization):
    """Initialize all sources with chi^2 detections

    There are a large number of parameters that are universal for all of the
    sources being initialized from the same set of observed images.
    To simplify the API those parameters are all initialized by this class
    and used by `init_source` for each source.
    It also creates temporary objects that only need to be created once for
    all of the sources in a blend.

    Parameters
    ----------
    observation:
        The observation containing the blend
    centers:
        The center of each source to initialize.
    detect:
        The array that contains a 2D image used for detection.
        If `None`, the sum over bands of the images divided by
        ``noise_rms**2`` is used.
    min_snr:
        The minimum SNR required per component.
        So a 2-component source requires at least `2*min_snr` while sources
        with SNR < `min_snr` will be initialized with the PSF.
    monotonicity:
        When `monotonicity` is `None`,
        the component is initialized with only the
        monotonic pixels, otherwise the monotonicity operator is used to
        project the morphology to a monotonic solution.
    disk_percentile:
        The level, as a percentage (0-100) of the peak of the normalized
        single component morphology, at which that morphology is split
        into a disk (the morphology capped at this level) and a bulge
        (the excess above this level).
    thresh:
        The threshold used to trim the morphology, so all pixels that are
        not above ``thresh * mean(noise_rms)`` are set to zero.
    padding:
        The amount to pad the morphology to allow for extra flux
        in the first few iterations before resizing.
    """

    def __init__(
        self,
        observation: Observation,
        centers: Sequence[tuple[int, int]],
        detect: np.ndarray | None = None,
        min_snr: float = 50,
        monotonicity: Monotonicity | None = None,
        disk_percentile: float = 25,
        thresh: float = 0.5,
        padding: int = 2,
    ):
        if detect is None:
            # Build the morphology detection image
            detect = np.sum(
                observation.images.data / (observation.noise_rms**2)[:, None, None],
                axis=0,
            )
        self.detect = detect
        _detect = Image(detect)
        # Convolve the detection image.
        # This may seem counter-intuitive,
        # since this is effectively growing the model,
        # but this is exactly what convolution will do to the model
        # in each iteration.
        # So we create the convolved model in order
        # to correctly set the spectrum.
        convolved = observation.convolve(_detect.repeat(observation.bands), mode="real")

        # Set the input parameters. These must be set before calling
        # ``super().__init__`` because the base initializer calls
        # `init_source`, which reads them.
        self.disk_percentile = disk_percentile
        self.thresh = thresh
        self.padding = padding

        # Initialize the sources
        super().__init__(observation, convolved, centers, min_snr, monotonicity)

    def init_source(self, center: tuple[int, int]) -> Source | None:
        """Initialize a source from a chi^2 detection.

        Parameters
        ----------
        center:
            The center of the source in the full image.

        Returns
        -------
        source:
            A `Source` containing a PSF component if no morphology could be
            extracted, a single component if `get_snr` is below 2, and a
            bulge and a disk component otherwise.
        """
        # Some operators need the local center, not center in the full image
        local_center = (
            center[0] - self.observation.bbox.origin[0],
            center[1] - self.observation.bbox.origin[1],
        )

        # Calculate the signal to noise at the center of this source
        component_snr = self.get_snr(center)

        # Initialize the bbox, morph, and spectrum
        # for a single component source
        detect = prox_uncentered_symmetry(self.detect.copy(), local_center, fill=0)
        thresh = np.mean(self.observation.noise_rms) * self.thresh
        component = self.get_single_component(center, detect, thresh, self.padding)

        if component is None:
            components = [self.get_psf_component(center)]
        elif component_snr < 2:
            components = [component]
        else:
            # There was enough flux for a 2-component source,
            # so split the single component model into two components,
            # using the same algorithm as scarlet main.
            bulge_morph = component.morph.copy()
            disk_morph = component.morph
            # Set the threshold for the bulge.
            # Since the morphology is monotonic, this selects the inner
            # of the single component morphology and assigns it to the bulge.
            flux_thresh = self.disk_percentile / 100
            mask = disk_morph > flux_thresh
            # Remove the flux above the threshold so that the disk will have
            # a flat center.
            disk_morph[mask] = flux_thresh
            # Subtract off the thresholded flux (since we're normalizing the
            # morphology anyway) so that it does not have a sharp
            # discontinuity at the edge.
            bulge_morph -= flux_thresh
            bulge_morph[bulge_morph < 0] = 0

            bulge_morph /= np.max(bulge_morph)
            disk_morph /= np.max(disk_morph)

            # Fit the spectra assuming that all of the flux in the image
            # is due to both components. This is not true, but for the
            # vast majority of sources this is a good approximation.
            bulge_spectrum, disk_spectrum = multifit_spectra(
                self.observation,
                [
                    Image(bulge_morph, yx0=cast(tuple[int, int], component.bbox.origin)),
                    Image(disk_morph, yx0=cast(tuple[int, int], component.bbox.origin)),
                ],
            )

            components = [
                FactorizedComponent(
                    self.observation.bands,
                    bulge_spectrum,
                    bulge_morph,
                    component.bbox.copy(),
                    center,
                    self.observation.noise_rms,
                    monotonicity=self.monotonicity,
                ),
                FactorizedComponent(
                    self.observation.bands,
                    disk_spectrum,
                    disk_morph,
                    component.bbox.copy(),
                    center,
                    self.observation.noise_rms,
                    monotonicity=self.monotonicity,
                ),
            ]

        return Source(components)  # type: ignore

559 

560 

class FactorizedWaveletInitialization(FactorizedInitialization):
    """Parameters used to initialize all sources with wavelet detections

    There are a large number of parameters that are universal for all of the
    sources being initialized from the same set of wavelet coefficients.
    To simplify the API those parameters are all initialized by this class
    and used by `init_source` for each source.

    Parameters
    ----------
    observation:
        The multiband observation of the blend.
    centers:
        The center of each source to initialize.
    bulge_slice, disk_slice:
        The slice used to select the wavelet scales used for the
        bulge/disk.
    bulge_padding, disk_padding:
        The number of pixels to grow the bounding box of the bulge/disk
        to leave extra room for growth in the first few iterations.
    use_psf:
        Whether or not to use the PSF for single component sources.
        If `use_psf` is `False` then only sources with low signal
        at all scales are initialized with the PSF morphology.
    scales:
        Number of wavelet scales to use.
    wavelets:
        The array of wavelet coefficients `(scale, y, x)`
        used for detection. The array is not modified; negative
        coefficients are clipped to zero in a copy.
    monotonicity:
        When `monotonicity` is `None`,
        the component is initialized with only the
        monotonic pixels, otherwise the monotonicity operator is used to
        project the morphology to a monotonic solution.
    min_snr:
        The minimum SNR required per component.
        So a 2-component source requires at least `2*min_snr` while sources
        with SNR < `min_snr` will be initialized with the PSF.
    """

    def __init__(
        self,
        observation: Observation,
        centers: Sequence[tuple[int, int]],
        bulge_slice: slice = slice(None, 2),
        disk_slice: slice = slice(2, -1),
        bulge_padding: int = 5,
        disk_padding: int = 5,
        use_psf: bool = True,
        scales: int = 5,
        wavelets: np.ndarray | None = None,
        monotonicity: Monotonicity | None = None,
        min_snr: float = 50,
    ):
        if wavelets is None:
            wavelets = get_detect_wavelets(
                observation.images.data,
                observation.variance.data,
                scales=scales,
            )
        else:
            # Work on a copy so that the caller's coefficients are not
            # modified by the clipping below.
            wavelets = wavelets.copy()
        wavelets[wavelets < 0] = 0
        # The detection coadd for single component sources
        # (every scale except the last one)
        detectlets = np.sum(wavelets[:-1], axis=0)
        # The detection coadd for the bulge
        bulgelets = np.sum(wavelets[bulge_slice], axis=0)
        # The detection coadd for the disk
        disklets = np.sum(wavelets[disk_slice], axis=0)

        # The convolved image, used to initialize the spectrum
        detect = Image(detectlets)
        convolved = observation.convolve(detect.repeat(observation.bands), mode="real")

        # These must be set before ``super().__init__`` because the base
        # initializer calls `init_source`, which reads them.
        self.detectlets = detectlets
        self.bulgelets = bulgelets
        self.disklets = disklets
        self.bulge_grow = bulge_padding
        self.disk_grow = disk_padding
        self.use_psf = use_psf

        # Initialize the sources
        super().__init__(observation, convolved, centers, min_snr, monotonicity)

    def _get_single_or_psf_component(self, center: tuple[int, int]) -> FactorizedComponent:
        """Initialize a single component, falling back to a PSF component.

        Parameters
        ----------
        center:
            The center of the source in the full image.

        Returns
        -------
        component:
            The single component built from `detectlets`, or a PSF component
            if no morphology could be extracted.
        """
        component = self.get_single_component(center, self.detectlets, 0, self.disk_grow)
        if component is None:
            # Same fallback used by the chi^2 initialization.
            return self.get_psf_component(center)
        return component

    def init_source(self, center: tuple[int, int]) -> Source | None:
        """Initialize a source from the wavelet detection coadds.

        Parameters
        ----------
        center:
            The center of the source in the full image.

        Returns
        -------
        source:
            The initialized source, or `None` if neither the bulge nor the
            disk coadd yields a morphology at `center`.
        """
        observation = self.observation
        origin = observation.bbox.origin
        local_center = (center[0] - origin[0], center[1] - origin[1])
        nbr_components = self.get_snr(center)

        if (nbr_components < 1 and self.use_psf) or self.detectlets[local_center[0], local_center[1]] <= 0:
            # Initialize the source as a PSF source
            return Source([self.get_psf_component(center)])  # type: ignore

        if nbr_components < 2:
            # Initialize with a single component
            return Source([self._get_single_or_psf_component(center)])  # type: ignore

        # Initialize with a 2 component model
        bulge_box, bulge_morph = init_monotonic_morph(self.bulgelets, center, observation.bbox, self.bulge_grow)
        disk_box, disk_morph = init_monotonic_morph(self.disklets, center, observation.bbox, self.disk_grow)
        if bulge_morph is None and disk_morph is None:
            return None
        if bulge_morph is None or disk_morph is None:
            # One of the components was null,
            # so initialize as a single component
            return Source([self._get_single_or_psf_component(center)])  # type: ignore

        # Crop the full-frame morphologies to their bounding boxes
        bulge_morph = bulge_morph[(bulge_box - origin).slices]
        disk_morph = disk_morph[(disk_box - origin).slices]

        bulge_spectrum, disk_spectrum = multifit_spectra(
            observation,
            [
                Image(bulge_morph, yx0=cast(tuple[int, int], bulge_box.origin)),
                Image(disk_morph, yx0=cast(tuple[int, int], disk_box.origin)),
            ],
        )

        # NOTE(review): unlike the other components created in this module,
        # no ``noise_rms`` is passed to these two components; confirm that
        # this is intentional.
        components = []
        # multifit_spectra clips negative amplitudes to zero, so a component
        # is kept only if it has flux in at least one band.
        if np.any(bulge_spectrum != 0):
            components.append(
                FactorizedComponent(
                    observation.bands,
                    bulge_spectrum,
                    bulge_morph,
                    bulge_box,
                    center,
                    monotonicity=self.monotonicity,
                )
            )
        else:
            logger.debug("cut bulge")
        if np.any(disk_spectrum != 0):
            components.append(
                FactorizedComponent(
                    observation.bands,
                    disk_spectrum,
                    disk_morph,
                    disk_box,
                    center,
                    monotonicity=self.monotonicity,
                )
            )
        else:
            logger.debug("cut disk")
        return Source(components)  # type: ignore
724 return Source(components) # type: ignore