Coverage for python/lsst/meas/extensions/scarlet/utils.py: 16%

81 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-26 11:02 +0000

1import numpy as np 

2 

3from lsst.afw.detection import Footprint as afwFootprint 

4from lsst.afw.detection import HeavyFootprintF, makeHeavyFootprint, PeakCatalog 

5from lsst.afw.detection.multiband import MultibandFootprint 

6from lsst.afw.geom import SpanSet 

7from lsst.afw.image import ( 

8 Mask, 

9 MaskedImage, 

10 Image as afwImage, 

11 IncompleteDataError, 

12 MultibandExposure, 

13 MultibandImage 

14) 

15from lsst.afw.table import SourceCatalog 

16import lsst.geom as geom 

17import lsst.scarlet.lite as scl 

18 

19 

# Names of the mask planes whose flagged pixels are given zero weight by
# `buildObservation` when the caller passes ``badPixelMasks=None``.
defaultBadPixelMasks = ["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"]

21 

22 

def footprintsToNumpy(
    catalog: SourceCatalog,
    shape: tuple[int, int],
    xy0: tuple[int, int] | None = None,
) -> np.ndarray:
    """Convert all of the footprints in a catalog into a boolean array.

    Parameters
    ----------
    catalog:
        The source catalog containing the footprints.
        This is typically a mergeDet catalog, or a full source catalog
        with the parents removed.
    shape:
        The final shape ``(height, width)`` of the output array.
    xy0:
        The ``(x, y)`` coordinates of the lower-left corner of the output
        array. If `None`, the array origin is taken to be ``(0, 0)``.

    Returns
    -------
    result:
        Boolean array of shape ``shape`` in which every pixel covered by the
        footprint of at least one record in ``catalog`` is `True`.
        Footprint pixels that fall outside the array are ignored.
    """
    if xy0 is None:
        offset = (0, 0)
    else:
        offset = (-xy0[0], -xy0[1])

    height, width = shape
    result = np.zeros(shape, dtype=bool)
    for src in catalog:
        spans = src.getFootprint().spans
        yidx, xidx = spans.shiftedBy(*offset).indices()
        yidx = np.asarray(yidx, dtype=int)
        xidx = np.asarray(xidx, dtype=int)
        # Negative indices would silently wrap around to the opposite edge of
        # the array, and indices >= shape would raise, so only keep the
        # footprint pixels that actually land inside the output array.
        inside = (yidx >= 0) & (yidx < height) & (xidx >= 0) & (xidx < width)
        result[yidx[inside], xidx[inside]] = True
    return result

57 

58 

def scarletBoxToBBox(box: scl.Box, xy0: geom.Point2I = geom.Point2I()) -> geom.Box2I:
    """Build a `Box2I` from a scarlet_lite `Box`.

    Scarlet boxes store their shape and origin in ``(..., y, x)`` order,
    while `Box2I` works in ``(x, y)`` order; this swaps the axes and
    applies an extra pixel offset.

    Parameters
    ----------
    box:
        The scarlet bounding box to translate.
    xy0:
        Offset added to the scarlet origin. Scarlet sources are usually
        positioned relative to the lower-left corner `(0,0)` of their
        blend, whereas the blend itself normally sits at some offset
        inside the `Exposure`.

    Returns
    -------
    bbox:
        The equivalent `Box2I`, shifted by ``xy0``.
    """
    cornerX = box.origin[-1] + xy0.x
    cornerY = box.origin[-2] + xy0.y
    corner = geom.Point2I(cornerX, cornerY)
    size = geom.Extent2I(box.shape[-1], box.shape[-2])
    return geom.Box2I(corner, size)

81 

82 

def bboxToScarletBox(bbox: geom.Box2I, xy0: geom.Point2I = geom.Point2I()) -> scl.Box:
    """Build a scarlet_lite `Box` from a `Box2I`.

    The resulting box uses ``(y, x)`` ordering for both its shape and its
    origin, which matches numpy-style slicing of image arrays.

    Parameters
    ----------
    bbox:
        The `Box2I` to translate.
    xy0:
        Offset subtracted from the minimum corner of ``bbox``. In blends
        this is typically the minimum pixel of the blend, while ``bbox``
        encloses one of the blend's sources.

    Returns
    -------
    box:
        The scarlet `Box` with shape ``(height, width)`` and origin
        ``(minY - xy0.y, minX - xy0.x)``.
    """
    originY = bbox.getMinY() - xy0.y
    originX = bbox.getMinX() - xy0.x
    boxShape = (bbox.getHeight(), bbox.getWidth())
    return scl.Box(boxShape, (originY, originX))

104 

105 

def computePsfKernelImage(mExposure, psfCenter):
    """Compute the PSF kernel image and update the multiband exposure
    if not all of the PSF images could be computed.

    Parameters
    ----------
    mExposure : `MultibandExposure`
        The multi-band exposure whose PSF is evaluated in every band.
    psfCenter : `tuple` or `Point2I` or `Point2D`
        The location `(x, y)` used as the center of the PSF.
        Anything that is not already a `Point2D` is converted to one.

    Returns
    -------
    psfModels : `np.ndarray`
        The multiband PSF image
    mExposure : `MultibandExposure`
        The exposure, updated to only use bands that
        successfully generated a PSF image. If every band succeeded this
        is the input exposure unchanged.
    """
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)

    try:
        psfModels = mExposure.computePsfKernelImage(psfCenter)
    except IncompleteDataError as e:
        # Some bands could not produce a PSF: keep the partial result and
        # restrict the exposure to the bands that did.
        psfModels = e.partialPsf
        bands = psfModels.filters
        mExposure = mExposure[bands,]
        if len(bands) == 1:
            # Selecting a single band collapses the MultibandExposure into
            # a single-band ExposureF, so wrap it back into a
            # MultibandExposure for a consistent return type.
            mExposure = MultibandExposure.fromExposures(bands, [mExposure])
    return psfModels.array, mExposure

139 

140 

def buildObservation(
    modelPsf: np.ndarray,
    psfCenter: tuple[int, int] | geom.Point2I | geom.Point2D,
    mExposure: MultibandExposure,
    badPixelMasks: list[str] | None = None,
    footprint: afwFootprint | None = None,
    useWeights: bool = True,
    convolutionType: str = "real",
) -> scl.Observation:
    """Generate an Observation from a set of arguments.

    Make the generation and reconstruction of a scarlet model consistent
    by building an `Observation` from a set of arguments.

    Parameters
    ----------
    modelPsf:
        The 2D model of the PSF in the partially deconvolved space.
    psfCenter:
        The location `(x, y)` used as the center of the PSF.
    mExposure:
        The multi-band exposure that the model represents.
        Bands whose PSF image cannot be computed at ``psfCenter`` are
        dropped (see `computePsfKernelImage`), so the observation may
        contain fewer bands than ``mExposure``.
    badPixelMasks:
        The names of the mask planes used to mask out pixels
        during the fit.
        If `badPixelMasks` is `None` then the module-level
        `defaultBadPixelMasks` are used.
    footprint:
        The footprint that is being fit.
        If `footprint` is `None` then the weights are not updated to mask
        out pixels not contained in the footprint.
    useWeights:
        Whether or not fitting should use inverse variance weights to
        calculate the log-likelihood. Pixels whose variance is not
        strictly positive (including NaN) are given zero weight.
        If `False`, every pixel starts with unit weight.
    convolutionType:
        The type of convolution to use (either "real" or "fft").
        When reconstructing an image it is advised to use "real" to avoid
        polluting the footprint with artifacts from the fft.

    Returns
    -------
    observation:
        The observation constructed from the input parameters.
    """
    # Initialize the observed PSFs
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)
    psfModels, mExposure = computePsfKernelImage(mExposure, psfCenter)

    if useWeights:
        # Use the inverse variance as the weights. A plain ``1/variance``
        # turns zero variance into inf (and negative/NaN variance into
        # negative/NaN weights), which would poison the log-likelihood,
        # so those pixels get zero weight instead.
        variance = mExposure.variance.array
        weights = np.zeros_like(variance)
        np.divide(1, variance, out=weights, where=variance > 0)
    else:
        weights = np.ones_like(mExposure.image.array)

    # Mask out bad pixels
    if badPixelMasks is None:
        badPixelMasks = defaultBadPixelMasks
    badPixels = mExposure.mask.getPlaneBitMask(badPixelMasks)
    mask = mExposure.mask.array & badPixels
    weights[mask > 0] = 0

    if footprint is not None:
        # Mask out the pixels outside the footprint.
        # NOTE(review): this assumes the array from ``spans.asArray()`` is
        # aligned with (and broadcastable to) the exposure arrays, i.e. that
        # the caller cropped ``mExposure`` to the footprint's bbox — confirm.
        weights *= footprint.spans.asArray()

    return scl.Observation(
        images=mExposure.image.array,
        variance=mExposure.variance.array,
        weights=weights,
        psfs=psfModels,
        model_psf=modelPsf[None, :, :],
        convolution_mode=convolutionType,
        bands=mExposure.filters,
        bbox=bboxToScarletBox(mExposure.getBBox()),
    )

219 

220 

def scarletModelToHeavy(
    source: scl.Source, blend: scl.Blend, useFlux=False,
) -> HeavyFootprintF | MultibandFootprint:
    """Turn a scarlet_lite source model into a heavy footprint.

    Parameters
    ----------
    source:
        The scarlet_lite source whose model becomes the footprint.
    blend:
        The `Blend` whose observation (PSFs, bounding box) is used to bring
        the scarlet model to the observed seeing in each band, and whose
        bands label the multiband result.
    useFlux:
        If `True`, use the flux re-distributed model
        (``source.get_model(use_flux=True)``), which is already convolved.
        Otherwise the intrinsic model is convolved with the observed PSFs.

    Returns
    -------
    heavy:
        For a single-band model, the footprint returned by
        `makeHeavyFootprint`; otherwise a `MultibandFootprint`.
    """
    observation = blend.observation
    # The model box is padded by half the PSF size in each dimension so
    # that flux spread out by the convolution is not clipped.
    psfHeight, psfWidth = observation.psfs.shape[1:]
    padding = (psfHeight // 2, psfWidth // 2)

    if useFlux:
        # The flux weighted model is already convolved; only clip it to the
        # part that overlaps the observed image.
        overlap = source.flux_weighted_image.bbox & observation.bbox
        model = source.get_model(use_flux=True).project(bbox=overlap)
    else:
        overlap = source.bbox.grow(padding) & observation.bbox
        # A real-space convolution is used to limit artifacts.
        model = observation.convolve(source.get_model().project(bbox=overlap), mode="real")

    # scarlet stores the origin as (..., y, x); afw wants (x, y).
    origin = geom.Point2I(model.yx0[-1], model.yx0[-2])
    # A pixel belongs to the footprint if the model is non-zero in any band
    # (as measured by the per-pixel maximum over bands).
    nonZero = np.max(model.data, axis=0) != 0
    footprintMask = Mask(nonZero.astype(np.int32), xy0=origin)
    footprint = afwFootprint(SpanSet.fromMask(footprintMask))

    # The footprint carries the source's detected peak.
    peaks = PeakCatalog(source.detectedPeak.table)
    peaks.append(source.detectedPeak)
    footprint.setPeakCatalog(peaks)

    modelBBox = footprintMask.getBBox()
    if model.n_bands != 1:
        multibandModel = MultibandImage(blend.bands, model.data, modelBBox)
        return MultibandFootprint.fromImages(blend.bands, multibandModel, footprint=footprint)

    singleBandImage = afwImage(
        array=model.data[0],
        xy0=modelBBox.getMin(),
        dtype=model.dtype
    )
    return makeHeavyFootprint(footprint, MaskedImage(singleBandImage, dtype=model.dtype))