Coverage for python/lsst/meas/extensions/scarlet/utils.py: 17%

69 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-12 10:45 +0000

1import numpy as np 

2 

3from lsst.afw.detection import Footprint as afwFootprint 

4from lsst.afw.detection import HeavyFootprintF, makeHeavyFootprint, PeakCatalog 

5from lsst.afw.detection.multiband import MultibandFootprint 

6from lsst.afw.geom import SpanSet 

7from lsst.afw.image import Mask, MaskedImage, Image as afwImage, MultibandExposure, MultibandImage 

8from lsst.afw.table import SourceCatalog 

9import lsst.geom as geom 

10import lsst.scarlet.lite as scl 

11 

12 

# Mask planes whose pixels get zero weight in `buildObservation` when the
# caller does not supply its own list of bad pixel masks.
defaultBadPixelMasks = [
    "BAD",
    "CR",
    "NO_DATA",
    "SAT",
    "SUSPECT",
    "EDGE",
]

14 

15 

def footprintsToNumpy(
    catalog: SourceCatalog,
    shape: tuple[int, int],
    xy0: tuple[int, int] | None = None,
) -> np.ndarray:
    """Convert all of the footprints in a catalog into a boolean array.

    Parameters
    ----------
    catalog:
        The source catalog containing the footprints.
        This is typically a mergeDet catalog, or a full source catalog
        with the parents removed.
    shape:
        The final `(height, width)` shape of the output array.
    xy0:
        The `(x, y)` lower-left corner of the array that will contain the
        spans. If `None` the array is assumed to start at `(0, 0)`.

    Returns
    -------
    result:
        The array with every pixel contained in any footprint's spans
        marked as `True`. Footprint pixels that fall outside of the
        output array are ignored.
    """
    if xy0 is None:
        offset = (0, 0)
    else:
        offset = (-xy0[0], -xy0[1])

    height, width = shape
    result = np.zeros(shape, dtype=bool)
    for src in catalog:
        spans = src.getFootprint().spans
        yidx, xidx = spans.shiftedBy(*offset).indices()
        # Force integer index arrays so that an empty span set still
        # indexes cleanly.
        yidx = np.asarray(yidx, dtype=int)
        xidx = np.asarray(xidx, dtype=int)
        # Negative indices would silently wrap around to the opposite edge
        # of the array, and indices past the end would raise, so only keep
        # pixels that actually land inside the output array.
        inside = (yidx >= 0) & (yidx < height) & (xidx >= 0) & (xidx < width)
        result[yidx[inside], xidx[inside]] = True
    return result

50 

51 

def scarletBoxToBBox(box: scl.Box, xy0: geom.Point2I = geom.Point2I()) -> geom.Box2I:
    """Build the `Box2I` equivalent of a scarlet_lite `Box`.

    Parameters
    ----------
    box:
        The scarlet bounding box to convert. The last two entries of its
        `origin` and `shape` are read as `(y, x)` and `(height, width)`.
    xy0:
        Offset added to the origin of `box`. This is common since scarlet
        sources have an origin of `(0,0)` at the lower left corner of the
        blend, while the blend itself is likely to be offset within the
        `Exposure`.

    Returns
    -------
    bbox:
        The `Box2I` covering the same pixels as `box`, shifted by `xy0`.
    """
    # Scarlet stores coordinates in (y, x) order; afw expects (x, y).
    height = box.shape[-2]
    width = box.shape[-1]
    minX = box.origin[-1] + xy0.x
    minY = box.origin[-2] + xy0.y
    corner = geom.Point2I(minX, minY)
    return geom.Box2I(corner, geom.Extent2I(width, height))

74 

75 

def bboxToScarletBox(bbox: geom.Box2I, xy0: geom.Point2I = geom.Point2I()) -> scl.Box:
    """Build the scarlet_lite `Box` equivalent of a `Box2I`.

    Parameters
    ----------
    bbox:
        The `Box2I` to convert into a scarlet `Box`.
    xy0:
        Offset subtracted from the minimum corner of `bbox`. This is
        common in blends, where `xy0` is the minimum pixel location of the
        blend and `bbox` is the box containing a source in the blend.

    Returns
    -------
    box:
        A scarlet `Box` with `(height, width)` shape and `(y, x)` origin,
        which is more convenient for slicing image data as a numpy array.
    """
    # afw reports coordinates as (x, y); scarlet wants (y, x).
    boxShape = (bbox.getHeight(), bbox.getWidth())
    originY = bbox.getMinY() - xy0.y
    originX = bbox.getMinX() - xy0.x
    return scl.Box(boxShape, (originY, originX))

97 

98 

def buildObservation(
    modelPsf: np.ndarray,
    psfCenter: tuple[int, int] | geom.Point2I | geom.Point2D,
    mExposure: MultibandExposure,
    badPixelMasks: list[str] | None = None,
    footprint: afwFootprint | None = None,
    useWeights: bool = True,
    convolutionType: str = "real",
) -> scl.Observation:
    """Generate an Observation from a set of arguments.

    Make the generation and reconstruction of a scarlet model consistent
    by building an `Observation` from a set of arguments.

    Parameters
    ----------
    modelPsf:
        The 2D model of the PSF in the partially deconvolved space.
    psfCenter:
        The location `(x, y)` used as the center of the PSF.
    mExposure:
        The multi-band exposure that the model represents.
        Its image, variance, mask, PSF, filters and bounding box are all
        used to build the observation, so it must not be `None`.
    badPixelMasks:
        The keys from the bit mask plane used to mask out pixels
        during the fit.
        If `badPixelMasks` is `None` then the module-level
        `defaultBadPixelMasks` are used.
    footprint:
        The footprint that is being fit.
        If `footprint` is `None` then the weights are not updated to mask
        out pixels not contained in the footprint.
        NOTE(review): `footprint.spans.asArray()` is multiplied directly
        into the weights, so the footprint's bounding box presumably has
        to match the bounding box of `mExposure` — confirm with callers.
    useWeights:
        Whether or not fitting should use inverse variance weights to
        calculate the log-likelihood. Pixels whose variance is not
        strictly positive (including NaN) are given zero weight.
    convolutionType:
        The type of convolution to use (either "real" or "fft").
        When reconstructing an image it is advised to use "real" to avoid
        polluting the footprint with artifacts from the fft.

    Returns
    -------
    observation:
        The observation constructed from the input parameters.
    """
    # Initialize the observed PSFs
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)
    psfModels = mExposure.computePsfKernelImage(psfCenter)

    if useWeights:
        # Use the inverse variance as the weights. A plain `1/variance`
        # gives inf (or NaN/negative) weights for zero, NaN or negative
        # variance, and inf * 0 = NaN once the footprint mask is applied
        # below, so such pixels get zero weight instead.
        variance = mExposure.variance.array
        weights = np.zeros_like(variance)
        np.divide(1, variance, out=weights, where=variance > 0)
    else:
        weights = np.ones_like(mExposure.image.array)

    # Mask out bad pixels
    if badPixelMasks is None:
        badPixelMasks = defaultBadPixelMasks
    badPixels = mExposure.mask.getPlaneBitMask(badPixelMasks)
    mask = mExposure.mask.array & badPixels
    weights[mask > 0] = 0

    if footprint is not None:
        # Mask out the pixels outside the footprint
        weights *= footprint.spans.asArray()

    return scl.Observation(
        images=mExposure.image.array,
        variance=mExposure.variance.array,
        weights=weights,
        psfs=psfModels.array,
        model_psf=modelPsf[None, :, :],
        convolution_mode=convolutionType,
        bands=mExposure.filters,
        bbox=bboxToScarletBox(mExposure.getBBox()),
    )

177 

178 

def scarletModelToHeavy(
    source: scl.Source, blend: scl.Blend, useFlux=False,
) -> HeavyFootprintF | MultibandFootprint:
    """Convert a scarlet_lite model to a `HeavyFootprintF`
    or `MultibandFootprint`.

    Parameters
    ----------
    source:
        The source to convert to a `HeavyFootprint`.
    blend:
        The `Blend` object that contains information about
        the observation, PSF, etc, used to convolve the
        scarlet model to the observed seeing in each band.
    useFlux:
        Whether or not to re-distribute the flux from the image
        to conserve flux.

    Returns
    -------
    heavy:
        The footprint containing the model for the source: a
        `HeavyFootprintF` when the model has a single band, otherwise a
        `MultibandFootprint`.
    """
    # We want to convolve the model with the observed PSF,
    # which means we need to grow the model box by the PSF to
    # account for all of the flux after convolution.

    # Get the PSF size and radii to grow the box
    py, px = blend.observation.psfs.shape[1:]
    dh = py // 2
    dw = px // 2

    if useFlux:
        bbox = source.flux_weighted_image.bbox
    else:
        bbox = source.bbox.grow((dh, dw))
    # Only use the portion of the convolved model that fits in the image
    overlap = bbox & blend.observation.bbox
    # Load the full multiband model in the larger box
    if useFlux:
        # The flux weighted model is already convolved, so we just load it
        model = source.get_model(use_flux=True).project(bbox=overlap)
    else:
        model = source.get_model().project(bbox=overlap)
        # Convolve the model with the PSF in each band
        # Always use a real space convolution to limit artifacts
        model = blend.observation.convolve(model, mode="real")

    # Update xy0 with the origin of the sources box
    xy0 = geom.Point2I(model.yx0[-1], model.yx0[-2])
    # Create the spans for the footprint: a pixel belongs to the footprint
    # if the model is non-zero in *any* band. Testing `max != 0` instead
    # would drop pixels where one band is exactly zero and another band is
    # negative (e.g. if the convolved model has negative values).
    valid = np.any(model.data != 0, axis=0)
    valid = Mask(valid.astype(np.int32), xy0=xy0)
    spans = SpanSet.fromMask(valid)

    # Add the location of the source to the peak catalog
    peakCat = PeakCatalog(source.detectedPeak.table)
    peakCat.append(source.detectedPeak)

    # Create the footprint and attach the peak, then fill it with the model
    foot = afwFootprint(spans)
    foot.setPeakCatalog(peakCat)
    if model.n_bands == 1:
        # Single band: a plain HeavyFootprintF
        image = afwImage(
            array=model.data[0],
            xy0=valid.getBBox().getMin(),
            dtype=model.dtype
        )
        maskedImage = MaskedImage(image, dtype=model.dtype)
        heavy = makeHeavyFootprint(foot, maskedImage)
    else:
        # Multiple bands: a MultibandFootprint
        model = MultibandImage(blend.bands, model.data, valid.getBBox())
        heavy = MultibandFootprint.fromImages(blend.bands, model, footprint=foot)
    return heavy