Coverage for python / lsst / scarlet / lite / blend.py: 20%

172 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Blend"] 

25 

26from abc import ABC, abstractmethod 

27from copy import deepcopy 

28from typing import TYPE_CHECKING, Any, Callable, Self, Sequence, cast 

29 

30import numpy as np 

31 

32from .bbox import Box 

33from .component import Component, FactorizedComponent 

34from .image import Image 

35from .observation import Observation 

36from .source import Source, SourceBase 

37 

38if TYPE_CHECKING: 

39 from .io import ScarletBlendData, ScarletSourceBaseData 

40 

41 

42class BlendBase(ABC): 

43 """A base class for blends that can be extended to add additional 

44 functionality. 

45 

46 This class holds all of the sources and observation that are to be fit, 

47 as well as performing fitting and joint initialization of the 

48 spectral components (when applicable). 

49 

50 Parameters 

51 ---------- 

52 sources: 

53 The sources to fit. 

54 observation: 

55 The observation that contains the images, 

56 PSF, etc. that are being fit. 

57 metadata: 

58 Additional metadata to store with the blend. 

59 """ 

60 

61 sources: Sequence[SourceBase] 

62 observation: Observation 

63 metadata: dict | None 

64 

65 @property 

66 def shape(self) -> tuple[int, int, int]: 

67 """Shape of the model for the entire `Blend`.""" 

68 return self.observation.shape 

69 

70 @property 

71 def bbox(self) -> Box: 

72 """The bounding box of the entire blend.""" 

73 return self.observation.bbox 

74 

75 @property 

76 def components(self) -> list[Component]: 

77 """The list of all components in the blend. 

78 

79 Since the list of sources might change, 

80 this is always built on the fly. 

81 """ 

82 return [c for src in self.sources for c in src.components] 

83 

84 @abstractmethod 

85 def __getitem__(self, indices: Any) -> Self: 

86 """Get a sub-blend corresponding to the given indices. 

87 

88 Parameters 

89 ---------- 

90 indices : 

91 The indices to use to slice the blend. 

92 

93 Returns 

94 ------- 

95 sub_blend : 

96 A new `BlendBase` instance containing only data from the 

97 specified bands in the specified order. 

98 

99 Raises 

100 ------ 

101 IndexError : 

102 If the indices contain bands not included in the original 

103 blend or any spatial indices are given. 

104 """ 

105 

106 @abstractmethod 

107 def __copy__(self) -> Self: 

108 """Create a copy of this blend. 

109 

110 Returns 

111 ------- 

112 blend : BlendBase 

113 A new blend that is a copy of this one. 

114 """ 

115 

116 @abstractmethod 

117 def __deepcopy__(self, memo: dict[int, Any]) -> Self: 

118 """Create a deep copy of this blend. 

119 

120 Parameters 

121 ---------- 

122 memo : dict[int, Any] 

123 A memoization dictionary used by `copy.deepcopy`. 

124 

125 Returns 

126 ------- 

127 blend : BlendBase 

128 A new blend that is a deep copy of this one. 

129 """ 

130 

131 def copy(self, deep: bool = False) -> Self: 

132 """Create a copy of this blend. 

133 

134 Parameters 

135 ---------- 

136 deep : 

137 If `True`, a deep copy is made. If `False`, a shallow copy is made. 

138 Default is `False`. 

139 

140 Returns 

141 ------- 

142 blend : Self 

143 A new blend that is a copy of this one. 

144 """ 

145 if deep: 

146 return self.__deepcopy__({}) 

147 else: 

148 return self.__copy__() 

149 

150 @abstractmethod 

151 def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image: 

152 """Generate a model of the entire blend. 

153 

154 Parameters 

155 ---------- 

156 convolve: 

157 Whether to convolve the model with the observed PSF in each band. 

158 use_flux: 

159 Whether to use the re-distributed flux associated with the sources 

160 instead of the component models. 

161 

162 Returns 

163 ------- 

164 model: 

165 The model created by combining all of the source models. 

166 """ 

167 

168 @abstractmethod 

169 def to_data(self) -> ScarletBlendData: 

170 """Convert the blend into a serializable dictionary format. 

171 

172 Returns 

173 ------- 

174 data: 

175 A dictionary containing all of the information needed to 

176 reconstruct the blend. 

177 """ 

178 

179 

class Blend(BlendBase):
    """A single blend.

    This class holds all of the sources and observation that are to be fit,
    as well as performing fitting and joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend.
        An empty mapping is stored as `None`.
    """

    sources: list[Source]

    def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None):
        # Copy into a list so that the blend owns its own container.
        self.sources = list(sources)
        self.observation = observation
        # Normalize empty metadata to `None`.
        if metadata is not None and len(metadata) == 0:
            metadata = None
        self.metadata = metadata

        # Initialize the iteration count and loss function
        self.it = 0
        self.loss: list[float] = []

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.

        Raises
        ------
        ValueError
            If ``use_flux`` is set and any source has no
            ``flux_weighted_image``.
        """
        # Start from an empty image covering the full observation.
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
        if convolve:
            return self.observation.convolve(model)
        return model

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the likelihood wrt the unconvolved model

        As a side effect the log-likelihood of the current model is
        appended to ``self.loss``.

        Returns
        -------
        result:
            The gradient of the likelihood wrt the model
        model_data:
            The convolved model data used to calculate the gradient.
            This can be useful for debugging but is not used in
            production.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        result = self.observation.weights * (model - self.observation.images)
        result = self.observation.convolve(result, grad=True)
        return result, model.data

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Only components exposing both ``morph`` and ``spectrum`` are fit;
        the models of all other components are accumulated into a fixed
        model image that is passed to the fitter.

        Parameters
        ----------
        clip:
            Whether or not to clip components that were not
            assigned any flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        from .initialization import multifit_spectra

        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )

        # Separate the factorized components (fit below) from the rest
        # (held fixed and accumulated into `model`).
        factorized: list[FactorizedComponent] = []
        for component in self.components:
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                factorized.append(cast(FactorizedComponent, component))
            else:
                model.insert(component.get_model())
        model = self.observation.convolve(model, mode="real")

        # Pair each morphology with the bounding box of the component it
        # belongs to. Using the boxes of *all* components here would shift
        # the pairing whenever a non-factorized component is present.
        morph_images = [
            Image(component.morph, yx0=cast(tuple[int, int], component.bbox.origin))
            for component in factorized
        ]
        fitted_spectra = multifit_spectra(self.observation, morph_images, model)
        for component, spectrum in zip(factorized, fitted_spectra):
            component.spectrum[:] = spectrum
            component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for src in self.sources:
            for component in src.components:
                if (
                    hasattr(component, "spectrum")
                    and hasattr(component, "prox_spectrum")
                    and component.prox_spectrum is not None  # type: ignore
                ):
                    component.prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux
            for src in self.sources:
                _components = []
                for component in src.components:
                    component_model = component.get_model()
                    component_model.data[component_model.data < 0] = 0
                    if np.sum(component_model.data) > 0:
                        _components.append(component)
                src.components = _components

        return self

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Iteration continues from the blend's current iteration count
        (``self.it``), so calling `fit` again resumes a previous fit.

        Parameters
        ----------
        max_iter:
            The maximum number of iterations
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` (or ``0``) then
            no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution. If no iteration has ever been run,
            the current log-likelihood is returned instead.
        """
        while self.it < max_iter:
            # Calculate the gradient wrt the unconvolved model
            grad_log_likelihood, _ = self._grad_log_likelihood()
            # `resize == 0` is treated like `None` to avoid a modulo by zero.
            do_resize = (
                resize is not None and resize != 0 and self.it > 0 and self.it % resize == 0
            )
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(self.it, grad_log_likelihood[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)
            # Stopping criteria
            self.it += 1
            # At least two loss values are needed to measure a change.
            if self.it > min_iter and len(self.loss) > 1:
                change = np.abs(self.loss[-1] - self.loss[-2])
                if change < e_rel * np.abs(self.loss[-1]):
                    break
        if not self.loss:
            # The loop never ran (e.g. ``max_iter <= self.it`` on a blend
            # that has not been fit), so there is no recorded loss.
            return self.it, self.log_likelihood
        return self.it, self.loss[-1]

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)

    def conserve_flux(self, mask_footprint: bool = True, weight_image: Image | None = None) -> None:
        """Use the source models as templates to re-distribute flux
        from the data

        The source models are used as approximations to the data,
        which redistribute the flux in the data according to the
        ratio of the models for each source.
        There is no return value for this function,
        instead it adds (or modifies) a ``flux_weighted_image``
        attribute to each of the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to apply a mask for pixels with zero weight.
        weight_image:
            The weight image to use for the redistribution.
            If `None` then the blend model, convolved in real space with
            the observation PSF, is used.
        """
        observation = self.observation
        # Half-widths of the PSF, used to grow each source box below.
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        # Work on a copy so the observed images are never modified.
        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0

        # NOTE(review): a caller-supplied ``weight_image`` is used as-is
        # (presumably already convolved and non-negative) -- confirm.
        if weight_image is None:
            weight_image = self.get_model()
            # Always convolve in real space to avoid FFT artifacts
            weight_image = observation.convolve(weight_image, mode="real")

            # Due to ringing in the PSF, the convolved model can have
            # negative values. We take the absolute value to avoid
            # negative fluxes in the flux weighted images.
            weight_image.data[:] = np.abs(weight_image.data)

        for src in self.sources:
            if src.is_null:
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue
            src_model = src.get_model()

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src_model.project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            src_model.data[:] = np.abs(src_model.data)
            numerator = src_model.data
            denominator = weight_image[overlap].data
            # Pixels with a zero denominator keep the initial ratio of 0.
            cuts = denominator != 0
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            ratio[cuts] = numerator[cuts] / denominator[cuts]
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]

    def to_data(self) -> ScarletBlendData:
        """Convert the Blend into a persistable data object

        Each source is stored under ``metadata["id"]`` when its metadata
        defines an ``"id"``, otherwise under its position in
        ``self.sources``.

        Returns
        -------
        blend_data :
            The data model for a single blend.
        """
        from .io import ScarletBlendData

        sources: dict[Any, ScarletSourceBaseData] = {}
        for sidx, source in enumerate(self.sources):
            metadata = source.metadata or {}
            if "id" in metadata:
                sources[metadata["id"]] = source.to_data()
            else:
                sources[sidx] = source.to_data()

        blend_data = ScarletBlendData(
            origin=self.bbox.origin,  # type: ignore
            shape=self.bbox.shape,  # type: ignore
            sources=sources,
            metadata=self.metadata,
        )

        return blend_data

    def __getitem__(self, indices: Any) -> Blend:
        """Get a sub-blend corresponding to the given indices.

        The same ``indices`` are applied to every source and to the
        observation. The result is built through ``__init__``, so its
        iteration counter and loss history start fresh.

        Parameters
        ----------
        indices :
            The indices to use to slice the blend.

        Returns
        -------
        blend :
            A new `Blend` instance containing only data from the
            specified bands in the specified order.

        Raises
        ------
        IndexError :
            If the indices contain bands not included in the original
            blend or a bounding box is given.
        """
        return Blend(
            sources=[src[indices] for src in self.sources],
            observation=self.observation[indices],
            metadata=self.metadata,
        )

    def __copy__(self) -> Blend:
        """Create a copy of this blend.

        The copy gets its own list of sources, but the source objects,
        the observation and the metadata are shared with the original.
        The copy is built through ``__init__``, so its iteration counter
        and loss history start fresh.

        Returns
        -------
        blend : Blend
            A new blend that is a copy of this one.
        """
        return Blend(sources=self.sources, observation=self.observation, metadata=self.metadata)

    def __deepcopy__(self, memo: dict[int, Any]) -> Blend:
        """Create a deep copy of this blend.

        The copy is initialized through ``__init__``, so its iteration
        counter and loss history start fresh.

        Parameters
        ----------
        memo : dict[int, Any]
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        blend : Blend
            A new blend that is a deep copy of this one.
        """
        # Check if already copied
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST so that any object
        # referring back to this blend resolves to the same copy.
        blend = Blend.__new__(Blend)
        memo[id(self)] = blend

        # Now safely initialize the placeholder with deepcopied arguments
        blend.__init__(  # type: ignore[misc]
            sources=[deepcopy(src, memo) for src in self.sources],
            observation=deepcopy(self.observation, memo),
            metadata=deepcopy(self.metadata, memo),
        )

        return blend