Coverage for python / lsst / meas / extensions / scarlet / deconvolveExposureTask.py: 25%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 09:00 +0000

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23 

24import lsst.afw.detection as afwDet 

25import lsst.afw.image as afwImage 

26import lsst.afw.table as afwTable 

27import lsst.pex.config as pexConfig 

28import lsst.pipe.base as pipeBase 

29import lsst.pipe.base.connectionTypes as cT 

30import lsst.scarlet.lite as scl 

31import numpy as np 

32 

33from . import utils 

34 

# Public API of this module.
__all__ = [
    "DeconvolveExposureTask",
    "DeconvolveExposureConfig",
    "DeconvolveExposureConnections",
]

# Module-level logger. Note that the task itself logs through ``self.log``;
# this logger is not referenced elsewhere in this module.
log = logging.getLogger(__name__)

42 

43 

def calculate_update_step(
    observation: scl.Observation,
    min_scale: float = 0.01,
    default_scale: float = 0.1,
    signal_sigma: float = 3.0,
    reference_sparsity: float = 0.1,
) -> float:
    """Calculate the scale factor for the update step in deconvolution.

    For most images this will be 1.0 but for images with low SNR
    and/or high sparsity (for example LSST u-band images) the scale
    factor will be less than 1.0.

    The scale is ``sparsity * sqrt(snr) / reference_sparsity`` clipped to
    the range ``[min_scale, 1.0]``, where ``sparsity`` is the fraction of
    pixels brighter than ``signal_sigma`` times the noise RMS and ``snr``
    is the median of those pixels divided by the noise RMS (1.0 when no
    pixel passes the threshold).

    Parameters
    ----------
    observation :
        Scarlet lite Observation. Only the first entry of
        ``observation.noise_rms`` is used as the noise level.
    min_scale :
        Minimum allowed scale factor.
    default_scale :
        Scale factor to return if the noise level is non-finite or
        non-positive, or if the image contains no pixels.
    signal_sigma :
        Number of noise RMS a pixel must exceed to count as signal.
    reference_sparsity :
        Sparsity at which an image with ``snr == 1`` reaches a scale of 1.0.

    Returns
    -------
    scale : float
        Scale factor for the update step.
    """
    data = observation.images.data
    noise_level = observation.noise_rms[0]
    # Guard against non-finite or non-positive noise levels, and against an
    # empty image (the sparsity below would otherwise be 0/0).
    if noise_level <= 0 or not np.isfinite(noise_level) or data.size == 0:
        return default_scale

    # Sparsity: fraction of pixels significantly above the noise
    signal_mask = data > signal_sigma * noise_level
    sparsity = np.count_nonzero(signal_mask) / data.size

    if np.any(signal_mask):
        snr = np.median(data[signal_mask]) / noise_level
    else:
        snr = 1.0

    # Scale factor that decreases with sparsity and increases with SNR
    scale = min(1.0, (sparsity * np.sqrt(snr)) / reference_sparsity)
    return max(min_scale, scale)

90 

91 

class DeconvolveExposureConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"inputCoaddName": "deep"},
):
    """Connections for DeconvolveExposureTask.

    Either ``coadd`` (regular coadd) or ``coadd_cell`` together with
    ``background`` (cell-based coadd) is kept, depending on
    ``config.useCellCoadds``. ``catalog`` is only kept when
    ``config.useFootprints`` is set.
    """

    coadd = cT.Input(
        doc="Exposure to deconvolve",
        name="{inputCoaddName}Coadd_calexp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    coadd_cell = cT.Input(
        doc="Cell-based coadd to stitch together and deconvolve",
        name="{inputCoaddName}CoaddCell",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    background = cT.Input(
        doc="Background model to subtract from the cell-based coadd",
        name="{inputCoaddName}Coadd_calexp_background",
        storageClass="Background",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    catalog = cT.Input(
        doc="Merged detection catalog whose footprints are used to "
        "constrain the deconvolved model",
        name="{inputCoaddName}Coadd_mergeDet",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )

    deconvolved = cT.Output(
        doc="Deconvolved exposure",
        name="deconvolved_{inputCoaddName}_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        if not config.useFootprints:
            # Deconvolution will not use the input catalog if
            # footprints are not used. Remove it the same way as the
            # other optional connections below.
            del self.catalog

        if config.useCellCoadds:
            # The stitched cell-based coadd replaces the regular coadd.
            del self.coadd
        else:
            del self.coadd_cell
            del self.background

145 

146 

class DeconvolveExposureConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DeconvolveExposureConnections,
):
    """Configuration options for `DeconvolveExposureTask`."""

    # Iteration limits and convergence tolerance of the deconvolution loop.
    maxIter = pexConfig.Field[int](default=100, doc="Maximum number of iterations")
    minIter = pexConfig.Field[int](default=10, doc="Minimum number of iterations")
    eRel = pexConfig.Field[float](default=1e-3, doc="Relative error threshold")

    # NOTE(review): this field is not read anywhere in this module; confirm
    # whether it is still used before relying on it.
    backgroundThreshold = pexConfig.Field[float](
        default=0,
        doc="Threshold for background subtraction. "
        "Pixels in the fit below this threshold will be set to zero",
    )

    # Input selection: footprint constraint and coadd flavor.
    useFootprints = pexConfig.Field[bool](
        default=True,
        doc="Use footprints to constrain the deconvolved model",
    )
    useCellCoadds = pexConfig.Field[bool](
        default=False,
        doc="Use cell-based coadd instead of regular coadd?",
    )

178 

179 

class DeconvolveExposureTask(pipeBase.PipelineTask):
    """Deconvolve an Exposure using scarlet lite."""

    ConfigClass = DeconvolveExposureConfig
    _DefaultName = "deconvolveExposure"

    def __init__(self, initInputs=None, **kwargs):
        if initInputs is None:
            initInputs = {}
        super().__init__(initInputs=initInputs, **kwargs)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        # Stitch together cell-based coadds (if necessary)
        if self.config.useCellCoadds:
            band = inputRefs.coadd_cell.dataId['band']
            cellCoadd = inputs.pop('coadd_cell')
            background = inputs.pop('background')
            coadd = cellCoadd.stitch().asExposure()
            # Subtract the background model supplied by the
            # ``background`` connection from the stitched coadd.
            coadd.image -= background.getImage()
        else:
            coadd = inputs.pop("coadd")
            band = inputRefs.coadd.dataId['band']

        # The catalog connection is removed when useFootprints is False.
        catalog = inputs.pop('catalog', None)

        assert not inputs, "runQuantum got more inputs than expected."
        outputs = self.run(
            coadd=coadd,
            catalog=catalog,
            band=band,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        coadd: afwImage.Exposure,
        catalog: afwTable.SourceCatalog | None = None,
        band: str = 'dummy'
    ) -> pipeBase.Struct:
        """Deconvolve an Exposure.

        Parameters
        ----------
        coadd :
            Coadd image to deconvolve.
        catalog :
            Merged detection catalog. When provided, pixels outside of its
            footprints are set to zero in the deconvolved model, which
            suppresses noise in regions with no significant flux above the
            noise in the coadds. It is also used to find a location for the
            PSF.
        band :
            Band of the coadd image.
            Since this is a single band task the band isn't really necessary
            but can be useful for debugging so we keep it as a parameter.

        Returns
        -------
        result : `pipeBase.Struct`
            Struct with a single attribute, ``deconvolved``, containing the
            deconvolved exposure.

        Raises
        ------
        pipeBase.NoWorkFound
            If ``catalog`` is given and no valid PSF could be obtained.
        """
        # ``_deconvolve`` reads ``self.bbox`` to rasterize the footprints,
        # so record it before any of the helpers run.
        self.bbox = coadd.getBBox()
        observation = self._buildObservation(coadd, catalog, band)

        # Deconvolve.
        # Store the loss history on the task for debugging purposes.
        model, self.loss = self._deconvolve(observation, catalog)

        # Store the (single band) model in an Exposure
        exposure = self._modelToExposure(model.data[0], coadd)
        return pipeBase.Struct(deconvolved=exposure)

    def _buildObservation(
        self,
        coadd: afwImage.Exposure,
        catalog: afwTable.SourceCatalog | None = None,
        band: str = 'dummy'
    ) -> scl.Observation:
        """Build a single band scarlet lite Observation from an Exposure.

        We don't actually use scarlet, but the optimized convolutions
        using scarlet data products are still useful.

        Parameters
        ----------
        coadd :
            Coadd image to deconvolve.
        catalog :
            Catalog of sources.
            This is used to find a location for the PSF if it cannot be
            generated at the center of the coadd.
        band :
            Band of the coadd image.

        Returns
        -------
        observation : scl.Observation
            Observation using FFT convolution and a circular Gaussian model
            PSF (sigma=0.8). Non-finite pixels are set to zero and, together
            with pixels in the default bad mask planes, given zero weight.

        Raises
        ------
        pipeBase.NoWorkFound
            If ``catalog`` is given and no valid PSF location was found.
        """
        bands = (band,)
        model_psf = scl.utils.integrated_circular_gaussian(sigma=0.8)

        image = coadd.image.array.copy()
        # Give zero weight to non-finite pixels and set them to zero
        nonFinite = ~np.isfinite(image)
        weights = np.ones_like(image)
        weights[nonFinite] = 0
        image[nonFinite] = 0.0

        psfCenter = coadd.getBBox().getCenter()
        if catalog is not None:
            psf, _, _ = utils.computeNearestPsf(coadd, catalog, band, psfCenter)
            if psf is None:
                # There were no valid locations from
                # which a PSF could be obtained
                raise pipeBase.NoWorkFound("No valid PSF could be obtained for deconvolution")
            psf = psf.array
        else:
            psf = coadd.getPsf().computeKernelImage(psfCenter).array

        # Give zero weight to pixels flagged in any of the bad mask planes
        badPixels = coadd.mask.getPlaneBitMask(utils.defaultBadPixelMasks)
        weights[(coadd.mask.array & badPixels) > 0] = 0

        # Add a leading band axis to every array (single band observation)
        observation = scl.Observation(
            images=image[None],
            variance=coadd.variance.array.copy()[None],
            weights=weights[None],
            psfs=psf[None],
            model_psf=model_psf[None],
            convolution_mode="fft",
            bands=bands,
            bbox=utils.bboxToScarletBox(coadd.getBBox()),
        )
        return observation

    def _deconvolve(
        self,
        observation: scl.Observation,
        catalog: afwTable.SourceCatalog | None = None,
    ) -> tuple[scl.Image, list[float]]:
        """Deconvolve the observed image with a scaled gradient iteration.

        Parameters
        ----------
        observation :
            Scarlet lite Observation.
        catalog :
            Merged detection catalog from the original coadds.
            This is used to mask the deconvolved image so that
            the deconvolved footprints detected downstream will always
            fit inside of the original footprints. Requires ``self.bbox``
            to have been set to the bounding box of the coadd.

        Returns
        -------
        model : scl.Image
            Deconvolved model, with negative and non-finite pixels set to
            zero.
        loss : list[float]
            ``-0.5 * sum(residual**2)`` evaluated at the start of each
            iteration (larger values mean a better fit).
        """
        model = observation.images.copy()
        loss = []
        step = calculate_update_step(observation)
        if catalog is not None:
            width, height = self.bbox.getDimensions()
            x0, y0 = self.bbox.getMin()
            # NOTE(review): presumably a boolean / 0-1 mask of pixels inside
            # any footprint; confirm against afwDet.footprintsToNumpy.
            footprintImage = afwDet.footprintsToNumpy(catalog, shape=(height, width), xy0=(x0, y0))
        for n in range(self.config.maxIter):
            residual = observation.images - observation.convolve(model)
            loss.append(-0.5 * np.sum(residual.data**2))
            update = observation.convolve(residual, grad=True)
            update.data[:] *= step
            model += update
            model.data[(model.data < 0) | ~np.isfinite(model.data)] = 0
            if catalog is not None:
                # Ensure that the deconvolved model footprints fit
                # inside of the original footprints by setting regions
                # outside of the original footprints to zero.
                model.data[:] *= footprintImage

            # Check for a diverging model: the (negative) loss got worse,
            # i.e. the squared residual grew, so shrink the step.
            if len(loss) > 1 and loss[-1] < loss[-2]:
                step = step / 2
                self.log.warning("Loss increased at iteration %d, decreasing scale to %s", n, step)

            # Check for convergence. Two loss values are required, which is
            # not guaranteed by ``n > minIter`` when minIter is negative.
            if (
                n > self.config.minIter
                and len(loss) > 1
                and np.abs(loss[-1] - loss[-2]) < self.config.eRel * np.abs(loss[-1])
            ):
                break

        return model, loss

    def _modelToExposure(self, model: np.ndarray, coadd: afwImage.Exposure) -> afwImage.Exposure:
        """Convert a deconvolved model array to an Exposure.

        Parameters
        ----------
        model :
            2D array with the single band deconvolved model.
        coadd :
            Original coadd. Its bounding box origin, mask, variance,
            exposure info and image dtype are used for the output.

        Returns
        -------
        exposure : afwImage.Exposure
            Exposure whose image is ``model``. The array is wrapped with
            ``deep=False`` and the coadd's mask, variance and exposure
            info objects are passed through; presumably they are shared
            rather than copied (confirm against afw).
        """
        image = afwImage.Image(
            array=model,
            xy0=coadd.getBBox().getMin(),
            deep=False,
            dtype=coadd.image.array.dtype,
        )
        maskedImage = afwImage.MaskedImage(
            image=image,
            mask=coadd.mask,
            variance=coadd.variance,
            dtype=coadd.image.array.dtype,
        )
        exposure = afwImage.Exposure(
            maskedImage=maskedImage,
            exposureInfo=coadd.getInfo(),
            dtype=coadd.image.array.dtype,
        )
        return exposure