Coverage for python / lsst / meas / extensions / scarlet / utils.py: 10%

134 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 08:32 +0000

1import lsst.geom as geom 

2import lsst.scarlet.lite as scl 

3import numpy as np 

4from scipy.signal import convolve 

5from lsst.afw.detection import InvalidPsfError, Footprint as afwFootprint 

6from lsst.afw.image import ( 

7 IncompleteDataError, 

8 MultibandExposure, 

9 MultibandImage, 

10 Exposure, 

11) 

12from lsst.afw.image.utils import projectImage 

13from lsst.afw.table import SourceCatalog 

14from lsst.geom import Box2I, Point2D, Point2I 

15from lsst.pipe.base import NoWorkFound 

16 

# Mask plane names used by `buildObservation` when the caller does not
# supply its own ``badPixelMasks`` list; pixels with any of these bits set
# receive zero weight there.
defaultBadPixelMasks = ["BAD", "NO_DATA", "SAT", "SUSPECT", "EDGE"]

18 

19 

def scarletBoxToBBox(box: scl.Box, xy0: geom.Point2I = geom.Point2I()) -> geom.Box2I:
    """Build a `Box2I` from a scarlet_lite `Box`.

    The last two entries of ``box.origin`` and ``box.shape`` are read as
    ``(y, x)`` and ``(height, width)`` respectively, so any leading
    dimensions of the scarlet box are ignored.

    Parameters
    ----------
    box:
        The scarlet bounding box to convert.
    xy0:
        An extra ``(x, y)`` offset added to the origin of ``box``.
        Scarlet sources are typically positioned relative to the lower
        left corner of their blend, while the blend itself has an offset
        inside the `Exposure`.

    Returns
    -------
    bbox:
        The converted bounding box.
    """
    # scarlet stores coordinates as (..., y, x); Box2I expects (x, y).
    minX = box.origin[-1] + xy0.x
    minY = box.origin[-2] + xy0.y
    width, height = box.shape[-1], box.shape[-2]
    return geom.Box2I(geom.Point2I(minX, minY), geom.Extent2I(width, height))

42 

43 

def bboxToScarletBox(bbox: geom.Box2I, xy0: geom.Point2I = geom.Point2I()) -> scl.Box:
    """Build a scarlet_lite `Box` from a `Box2I`.

    Parameters
    ----------
    bbox:
        The `Box2I` to convert into a scarlet `Box`.
    xy0:
        An overall ``(x, y)`` offset subtracted from the minimum corner
        of ``bbox``. For a source inside a blend this is usually the
        minimum pixel location of the blend.

    Returns
    -------
    box:
        A scarlet `Box` with ``(height, width)`` shape and ``(y, x)``
        origin, which is convenient for slicing numpy image arrays.
    """
    # scarlet uses numpy ordering, so y comes before x.
    originY = bbox.getMinY() - xy0.y
    originX = bbox.getMinX() - xy0.x
    shape = (bbox.getHeight(), bbox.getWidth())
    return scl.Box(shape, (originY, originX))

65 

66 

67def multiband_convolve(images: np.ndarray, psfs: np.ndarray) -> np.ndarray: 

68 """Convolve a multi-band image with the PSF in each band. 

69 

70 `images` and `psfs` should have dimensions `(bands, height, width)`. 

71 

72 Parameters 

73 ---------- 

74 images : 

75 The multi-band images to convolve. 

76 psfs : 

77 The PSF for each band. 

78 

79 Returns 

80 ------- 

81 result : 

82 The convolved images. 

83 """ 

84 result = np.zeros(images.shape, dtype=images.dtype) 

85 for bidx, (image, psf) in enumerate(zip(images, psfs, strict=True)): 

86 result[bidx] = convolve(image, psf, mode="same") 

87 return result 

88 

89 

def computePsfKernelImage(mExposure, psfCenter, catalog=None):
    """Compute the multiband PSF kernel image, dropping bands without a PSF.

    If the exposure raises `IncompleteDataError`, the partial PSF attached
    to the exception is used and the exposure is restricted to the bands
    present in that partial PSF.

    Parameters
    ----------
    mExposure : `MultibandExposure`
        The multi-band exposure whose PSF is evaluated.
    psfCenter : `tuple` or `Point2I` or `Point2D`
        The location `(x, y)` used as the center of the PSF.
    catalog :
        Not used by this function.

    Returns
    -------
    psfModels : `np.ndarray`
        The multiband PSF image, or `None` if no band produced a PSF.
    mExposure : `MultibandExposure`
        The exposure, updated to only use bands that
        successfully generated a PSF image, or `None` if no band
        produced a PSF.
    """
    center = psfCenter if isinstance(psfCenter, geom.Point2D) else geom.Point2D(*psfCenter)

    try:
        psfModels = mExposure.computePsfKernelImage(center)
    except IncompleteDataError as error:
        psfModels = error.partialPsf
    else:
        # Every band produced a PSF, so the exposure is returned untouched.
        return psfModels.array, mExposure

    if psfModels is None:
        return None, None

    # Keep only the bands that successfully generated a PSF image.
    validBands = psfModels.bands
    subset = mExposure[validBands,]
    if len(validBands) == 1:
        # Selecting a single band yields a single band ExposureF,
        # so wrap it back into a MultibandExposure.
        subset = MultibandExposure.fromExposures(validBands, [subset])
    return psfModels.array, subset

125 

126 

def computeNearestPsf(
    calexp: Exposure,
    catalog: SourceCatalog,
    band: str | None = None,
    psfCenter: Point2D | None = None,
) -> tuple[np.ndarray, Point2I, float]:
    """Create a PSF image at the nearest location where one can be built.

    The PSF is first evaluated at ``psfCenter``. If that raises
    `InvalidPsfError`, the footprint peaks of the catalog sources are
    tried in order of increasing distance from ``psfCenter`` until one
    of them yields a PSF image.

    Parameters
    ----------
    calexp :
        The exposure.
    catalog :
        The catalog.
    band :
        The band of the exposure. When given, only sources flagged in
        ``merge_footprint_{band}`` and peaks flagged in
        ``merge_peak_{band}`` are considered.
        If band is ``None`` then the full catalog is used.
    psfCenter :
        The requested location of the PSF image.
        If no location is provided, the center of the exposure is used.

    Returns
    -------
    psf :
        The PSF image, or ``None`` if no candidate location produced one.
    location :
        The location where the PSF image was evaluated: ``psfCenter``
        itself when it was valid, otherwise the position of the chosen
        peak. ``None`` if no PSF could be generated.
    diff :
        The distance between the requested location and the location
        that was used (``0`` when ``psfCenter`` was valid), or ``None``
        if no PSF could be generated.
    """
    if psfCenter is None:
        psfCenter = calexp.getBBox().getCenter()
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)

    # Fast path: the requested location itself is usable.
    try:
        psf = calexp.getPsf().computeKernelImage(psfCenter)
    except InvalidPsfError:
        pass
    else:
        return psf, psfCenter, 0

    xc, yc = psfCenter

    # Only select records that have detections in this band
    sources = catalog if band is None else catalog[catalog[f'merge_footprint_{band}']]

    # Gather the positions of all of the (band-selected) peaks
    coords = [
        (peak['i_x'], peak['i_y'])
        for src in sources
        for peak in src.getFootprint().peaks
        if band is None or peak[f'merge_peak_{band}']
    ]
    x = np.array([peakX for peakX, _ in coords])
    y = np.array([peakY for _, peakY in coords])

    # Squared distance of each peak from the requested location
    diff_x = x - xc
    diff_y = y - yc
    dist2 = diff_x**2 + diff_y**2

    # Walk the peaks from nearest to farthest until one can generate a PSF
    psf = None
    nearest = None
    for nearest in np.argsort(dist2):
        try:
            psf = calexp.getPsf().computeKernelImage(Point2D(x[nearest], y[nearest]))
        except InvalidPsfError:
            continue
        break
    if psf is None:
        return None, None, None

    return psf, Point2I(x[nearest], y[nearest]), np.sqrt(dist2[nearest])

212 

213 

def computeNearestPsfMultiBand(
    mExposure: MultibandExposure,
    psfCenter: tuple[int, int] | geom.Point2I | geom.Point2D,
    catalog: SourceCatalog,
) -> tuple[np.ndarray | None, MultibandExposure | None]:
    """Compute the image in each band at the location nearest to the PSF Center

    If the PSF cannot be generated in all bands then `mExposure` is updated
    to use only the bands that successfully generated a PSF image.

    Parameters
    ----------
    mExposure :
        The multi-band exposure.
    psfCenter :
        The location `(x, y)` used as the center of the PSF.
    catalog :
        The source catalog.

    Returns
    -------
    psfModels :
        The PSF images of the successful bands, projected onto the union
        of their bounding boxes and stacked into a single array.
        ``None`` if no band produced a PSF.
    mExposure :
        The exposure restricted to the bands that produced a PSF.
        ``None`` if no band produced a PSF.
    """
    psfs = {}
    incomplete = False
    for band in mExposure.bands:
        psf, location, _ = computeNearestPsf(
            mExposure[band,],
            catalog,
            band,
            psfCenter,
        )
        if psf is None:
            # A failed band returns a location of None. Do not let that
            # replace the requested center: passing None to the next band
            # would make it silently fall back to the exposure center.
            incomplete = True
            continue
        psfs[band] = psf
        # As before, later bands start their search from the location
        # that worked for the most recent successful band.
        psfCenter = location

    if not psfs:
        return None, None

    # The per-band PSF images may have different bounding boxes, so
    # project all of them onto the smallest box containing every one.
    bboxes = [psf.getBBox() for psf in psfs.values()]
    left = min(bbox.getMinX() for bbox in bboxes)
    bottom = min(bbox.getMinY() for bbox in bboxes)
    right = max(bbox.getMaxX() for bbox in bboxes)
    top = max(bbox.getMaxY() for bbox in bboxes)
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))

    psf_images = [projectImage(psf, bbox) for psf in psfs.values()]
    mPsf = MultibandImage.fromImages(list(psfs), psf_images)

    if incomplete:
        # Use only the bands that successfully generated a PSF image.
        bands = mPsf.bands
        mExposure = mExposure[bands,]

        if len(bands) == 1:
            # Only a single band generated a PSF, so the MultibandExposure
            # became a single band ExposureF.
            # Convert the result back into a MultibandExposure.
            mExposure = MultibandExposure.fromExposures(bands, [mExposure])

    return mPsf.array, mExposure

271 

272 

def buildObservation(
    modelPsf: np.ndarray,
    psfCenter: tuple[int, int] | geom.Point2I | geom.Point2D,
    mExposure: MultibandExposure,
    badPixelMasks: list[str] | None = None,
    footprint: afwFootprint | None = None,
    useWeights: bool = True,
    convolutionType: str = "real",
    catalog: SourceCatalog | None = None,
) -> scl.Observation:
    """Generate an Observation from a set of arguments.

    Make the generation and reconstruction of a scarlet model consistent
    by building an `Observation` from a set of arguments.

    Parameters
    ----------
    modelPsf :
        The 2D model of the PSF in the partially deconvolved space.
    psfCenter :
        The location `(x, y)` used as the center of the PSF.
    mExposure :
        The multi-band exposure that the model represents.
        It is required: its image, variance and mask planes are always
        read. If some bands cannot produce a PSF, only the remaining
        bands are used in the observation.
    badPixelMasks :
        The keys from the bit mask plane used to mask out pixels
        during the fit.
        If `badPixelMasks` is `None` then `defaultBadPixelMasks`
        is used.
    footprint :
        The footprint that is being fit.
        If `footprint` is `None` then the weights are not updated to mask
        out pixels not contained in the footprint.
    useWeights :
        Whether or not fitting should use inverse variance weights to
        calculate the log-likelihood.
    convolutionType :
        The type of convolution to use (either "real" or "fft").
        When reconstructing an image it is advised to use "real" to avoid
        polluting the footprint with artifacts from the fft.
    catalog :
        A source catalog to use for PSFs that cannot be determined at
        the center of the image. If `None`, the PSF is only evaluated
        at ``psfCenter``.

    Returns
    -------
    observation:
        The observation constructed from the input parameters.

    Raises
    ------
    NoWorkFound
        Raised if no band could produce a PSF image.
    """
    # Initialize the observed PSFs
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)
    if catalog is None:
        psfModels, mExposure = computePsfKernelImage(mExposure, psfCenter)
    else:
        psfModels, mExposure = computeNearestPsfMultiBand(mExposure, psfCenter, catalog)

    if psfModels is None:
        raise NoWorkFound("No valid PSF could be obtained for building the observation")

    # Use the inverse variance as the weights
    if useWeights:
        # Zero or NaN variance gives inf/nan weights, which are reset to 0
        # right below, so silence the floating point warnings they trigger.
        with np.errstate(divide="ignore", invalid="ignore"):
            weights = 1 / mExposure.variance.array
        weights[~np.isfinite(weights)] = 0
    else:
        weights = np.ones_like(mExposure.image.array)

    # Mask out bad pixels
    if badPixelMasks is None:
        badPixelMasks = defaultBadPixelMasks
    badPixels = mExposure.mask.getPlaneBitMask(badPixelMasks)
    mask = mExposure.mask.array & badPixels
    weights[mask > 0] = 0

    if footprint is not None:
        # Mask out the pixels outside the footprint
        # NOTE(review): assumes the span array broadcasts against the
        # exposure arrays, i.e. the footprint covers the same bounding
        # box as ``mExposure`` -- confirm with callers.
        weights *= footprint.spans.asArray()

    # Mask out non-finite pixels; work on a copy so the exposure's own
    # image array is left untouched.
    image = mExposure.image.array.copy()
    nonFinite = ~np.isfinite(image)
    weights[nonFinite] = 0
    image[nonFinite] = 0

    return scl.Observation(
        images=image,
        variance=mExposure.variance.array,
        weights=weights,
        psfs=psfModels,
        # The model PSF is a single 2D image; add a leading band axis.
        model_psf=modelPsf[None, :, :],
        convolution_mode=convolutionType,
        bands=mExposure.bands,
        bbox=bboxToScarletBox(mExposure.getBBox()),
    )

367 

368 

def calcChi2(
    model: scl.Image,
    observation: scl.Observation,
    footprint: np.ndarray | None = None,
    doConvolve: bool = True,
) -> scl.Image:
    """Calculate the chi2 image for a model.

    The residual between the observed images and the model is squared,
    divided by the variance and by the number of bands. Pixels with
    zero variance are assigned a chi2 of zero.

    Parameters
    ----------
    model :
        The model used to calculate the chi2.
    observation :
        The observation used to calculate the chi2.
    footprint :
        The footprint to use when calculating the chi2.
        If `footprint` is `None` then the footprint is calculated
        to be the pixels where the (optionally convolved) model is
        greater than 0.
    doConvolve :
        Whether or not to convolve the model with the PSF.

    Returns
    -------
    chi2 :
        The chi2/pixel image for the model.
    """
    if doConvolve:
        model = observation.convolve(model)

    pixelMask = (model.data > 0) if footprint is None else footprint

    # Compare against the part of the observation covered by the model.
    modelBox = model.bbox
    variance = observation.variance[:, modelBox].data
    residual = (observation.images[:, modelBox].data - model.data) * pixelMask

    # Skip zero-variance pixels to avoid dividing by zero.
    nonZero = variance != 0
    nBands = len(observation.images.bands)
    chi2Data = np.zeros(residual.shape, dtype=residual.dtype)
    chi2Data[nonZero] = residual[nonZero]**2 / variance[nonZero] / nBands

    return scl.Image(chi2Data, bands=model.bands, yx0=model.yx0)