Coverage for python / lsst / scarlet / lite / detect.py: 31%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import logging 

25from typing import Sequence 

26 

27import numpy as np 

28from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints # type: ignore 

29 

30from .bbox import Box, overlapped_slices 

31from .image import Image 

32from .utils import continue_class 

33from .wavelet import ( 

34 get_multiresolution_support, 

35 get_starlet_scales, 

36 multiband_starlet_reconstruction, 

37 starlet_transform, 

38) 

39 

# Module-level logger shared by the detection helpers in this file.
logger: logging.Logger = logging.getLogger("scarlet.detect")

41 

42 

def bounds_to_bbox(bounds: tuple[int, int, int, int]) -> Box:
    """Convert the bounds of a Footprint into a Box

    Parameters
    ----------
    bounds:
        The bounds of the `Footprint` as a `tuple` of
        ``(bottom, top, left, right)``.

    Returns
    -------
    result:
        The `Box` whose origin is ``(bottom, left)`` and whose shape
        spans the bounds.

    Notes
    -----
    Unlike slices, the bounds are _inclusive_ of the end points,
    so the extent along each axis is ``end - start + 1``.
    """
    bottom = bounds[0]
    top = bounds[1]
    left = bounds[2]
    right = bounds[3]
    # Inclusive end points: add one to get the number of pixels.
    height = top + 1 - bottom
    width = right + 1 - left
    return Box((height, width), origin=(bottom, left))

64 

65 

def bbox_to_bounds(bbox: Box) -> tuple[int, int, int, int]:
    """Convert a Box into the bounds of a Footprint

    Parameters
    ----------
    bbox:
        The `Box` to convert into bounds. Only the first two axes of
        its ``origin`` and ``shape`` are used.

    Returns
    -------
    result:
        The bounds of the `Footprint` as a `tuple` of
        ``(bottom, top, left, right)``.

    Notes
    -----
    Unlike slices, the bounds are _inclusive_ of the end points,
    so the far edge along each axis is ``start + size - 1``.
    """
    y_start, x_start = bbox.origin[0], bbox.origin[1]
    n_rows, n_cols = bbox.shape[0], bbox.shape[1]
    y_end = y_start + n_rows - 1
    x_end = x_start + n_cols - 1
    return (y_start, y_end, x_start, x_end)

91 

92 

@continue_class
class Footprint:  # type: ignore # noqa
    # Python-side extensions attached to the compiled (pybind11) ``Footprint``
    # class via ``continue_class``. The methods below rely on the ``bounds``
    # and ``data`` attributes supplied by the compiled class.

    @property
    def bbox(self) -> Box:
        """Bounding box for the Footprint

        Returns
        -------
        bbox:
            The minimal `Box` that contains the entire `Footprint`.
        """
        return bounds_to_bbox(self.bounds)  # type: ignore

    @property
    def yx0(self) -> tuple[int, int]:
        """Origin in y, x of the lower left corner of the footprint

        Returns
        -------
        yx0:
            ``(bounds[0], bounds[2])``, i.e. the ``bottom`` and ``left``
            entries of the ``(bottom, top, left, right)`` bounds.
        """
        return self.bounds[0], self.bounds[2]  # type: ignore

    def intersection(self, other: Footprint) -> Image | None:
        """The intersection of two footprints

        Parameters
        ----------
        other:
            The other footprint to compare.

        Returns
        -------
        intersection:
            The intersection of two footprints, computed by wrapping each
            footprint's ``data`` in an `Image` at its ``yx0`` and combining
            them with the `Image` ``&`` operator.
        """
        footprint1 = Image(self.data, yx0=self.yx0)  # type: ignore
        footprint2 = Image(other.data, yx0=other.yx0)  # type: ignore # noqa
        return footprint1 & footprint2

    def union(self, other: Footprint) -> Image | None:
        """The union of two footprints

        Parameters
        ----------
        other:
            The other footprint to combine with.

        Returns
        -------
        union:
            The union of two footprints, computed by wrapping each
            footprint's ``data`` in an `Image` at its ``yx0`` and combining
            them with the `Image` ``|`` operator.
        """
        footprint1 = Image(self.data, yx0=self.yx0)  # type: ignore
        footprint2 = Image(other.data, yx0=other.yx0)  # type: ignore
        return footprint1 | footprint2

144 

145 

def footprints_to_image(footprints: Sequence[Footprint], bbox: Box) -> Image:
    """Convert a set of scarlet footprints to a pixelized image.

    Each footprint's ``data`` is multiplied by a 1-based label (the first
    footprint gets label 1, the second label 2, ...) and added into an
    integer image covering ``bbox``. Pixels outside every footprint stay 0.

    Parameters
    ----------
    footprints:
        The footprints to convert into an image.
    bbox:
        The full box of the image that will contain the footprints.

    Returns
    -------
    result:
        The integer image created from the footprints.

    Notes
    -----
    Contributions are *added*, so where footprints overlap the pixel value
    is the sum of the overlapping labels rather than a single label.
    """
    result = Image.from_box(bbox, dtype=int)
    # Labels start at 1 so that 0 remains "no footprint".
    for label, footprint in enumerate(footprints, start=1):
        # Only the region shared by the image box and the footprint box is
        # written; slices[0] indexes `result`, slices[1] indexes the footprint.
        slices = overlapped_slices(result.bbox, footprint.bbox)
        result.data[slices[0]] += footprint.data[slices[1]] * label
    return result

166 

167 

def get_wavelets(
    images: np.ndarray,
    variance: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
) -> np.ndarray:
    """Calculate wavelet coefficients given a set of images and their variances

    Each band is transformed independently and only coefficients inside
    that band's multiresolution support are kept; all others are zeroed.

    Parameters
    ----------
    images:
        The array of images with shape `(bands, Ny, Nx)` for which to
        calculate wavelet coefficients.
    variance:
        An array of variances with the same shape as `images`.
    scales:
        The maximum number of wavelet scales to use. It is resolved by
        `get_starlet_scales` from the image shape, which is also used when
        `scales` is `None`.
    generation:
        The generation of the starlet transform, passed to
        `starlet_transform`.

    Returns
    -------
    coeffs:
        The array of coefficients with shape `(scales+1, bands, Ny, Nx)`
        and the same dtype as `images`.
        Note that the result has `scales+1` total arrays,
        since the last set of coefficients is the image of all
        flux with frequency greater than the last wavelet scale.
    """
    # One noise level per band: the median of the per-pixel standard deviation.
    sigma = np.median(np.sqrt(variance), axis=(1, 2))
    # Create the wavelet coefficients for the significant pixels
    scales = get_starlet_scales(images[0].shape, scales)
    coeffs = np.empty((scales + 1,) + images.shape, dtype=images.dtype)
    for b, image in enumerate(images):
        _coeffs = starlet_transform(image, scales=scales, generation=generation)
        support = get_multiresolution_support(
            image=image,
            starlets=_coeffs,
            sigma=sigma[b],
            sigma_scaling=3,
            epsilon=1e-1,
            max_iter=20,
        )
        # Zero out coefficients outside the significant-pixel support.
        coeffs[:, b] = (support.support * _coeffs).astype(images.dtype)
    return coeffs

210 

211 

def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int = 3) -> np.ndarray:
    """Get an array of wavelet coefficients to use for detection

    The bands are summed into a single detection image, which is then
    starlet transformed. Only coefficients inside the multiresolution
    support are kept; all others are zeroed.

    Parameters
    ----------
    images:
        The array of images with shape `(bands, Ny, Nx)` for which to
        calculate wavelet coefficients.
    variance:
        An array of variances with the same shape as `images`.
    scales:
        The maximum number of wavelet scales to use.
        Note that the result will have `scales+1` total arrays,
        where the last set of coefficients is the image of all
        flux with frequency greater than the last wavelet scale.

    Returns
    -------
    starlets:
        The array of wavelet coefficients for pixels with significant
        amplitude in each scale, cast to the dtype of `images`.
    """
    # NOTE(review): this is the median per-pixel sigma over *all* bands, while
    # `detect` below is the band *sum*, whose noise would be closer to
    # sqrt(sum of variances). Confirm whether the lower sigma is intentional.
    sigma = np.median(np.sqrt(variance))
    # Create the wavelet coefficients for the significant pixels
    detect = np.sum(images, axis=0)
    _coeffs = starlet_transform(detect, scales=scales)
    support = get_multiresolution_support(
        image=detect,
        starlets=_coeffs,
        sigma=sigma,  # type: ignore
        sigma_scaling=3,
        epsilon=1e-1,
        max_iter=20,
    )
    return (support.support * _coeffs).astype(images.dtype)

247 

248 

def detect_footprints(
    images: np.ndarray,
    variance: np.ndarray,
    scales: int = 1,
    generation: int = 2,
    origin: tuple[int, int] | None = None,
    min_separation: float = 4,
    min_area: int = 4,
    peak_thresh: float = 5,
    footprint_thresh: float = 5,
    find_peaks: bool = True,
    remove_high_freq: bool = True,
    min_pixel_detect: int = 1,
) -> list[Footprint]:
    """Detect footprints in an image

    Builds a noise-weighted detection image by summing the bands, each
    divided by a per-band noise estimate, then passes it to `get_footprints`.

    Parameters
    ----------
    images:
        The array of images with shape `(bands, Ny, Nx)` in which to
        detect footprints.
    variance:
        An array of variances with the same shape as `images`.
    scales:
        The maximum number of wavelet scales to use.
        If `remove_high_freq` is `False`, then this argument is ignored.
    generation:
        The generation of the starlet transform to use.
        If `remove_high_freq` is `False`, then this argument is ignored.
    origin:
        The location (y, x) of the lower corner of the image.
        Defaults to ``(0, 0)``.
    min_separation:
        The minimum separation between peaks in pixels.
    min_area:
        The minimum area of a footprint in pixels.
    peak_thresh:
        The threshold for peak detection, passed to `get_footprints`
        and applied to the noise-weighted detection image.
    footprint_thresh:
        The threshold for footprint detection, passed to `get_footprints`
        and applied to the noise-weighted detection image.
    find_peaks:
        If `True`, then detect peaks in the detection image,
        otherwise only the footprints are returned.
    remove_high_freq:
        If `True`, then remove the highest frequency wavelet scale
        before detecting peaks.
    min_pixel_detect:
        The minimum number of bands in which the input `images` must be
        positive (``> 0``) for a pixel to be kept in the detection image.
        Pixels failing the test are set to zero. The check is only applied
        when this value is greater than 1, and it uses the original
        `images`, not the high-frequency-filtered reconstruction.

    Returns
    -------
    footprints:
        The footprints returned by `get_footprints` for the detection image.
    """
    if origin is None:
        origin = (0, 0)
    if remove_high_freq:
        # Build the wavelet coefficients
        wavelets = get_wavelets(
            images,
            variance,
            scales=scales,
            generation=generation,
        )
        # Remove the high frequency wavelets.
        # This has the effect of preventing high frequency noise
        # from interfering with the detection of peak positions.
        wavelets[0] = 0
        # Reconstruct the image from the remaining wavelet coefficients
        _images = multiband_starlet_reconstruction(
            wavelets,
            generation=generation,
        )
    else:
        _images = images
    # Build a SNR weighted detection image.
    # NOTE(review): the extra factor of 1/2 on the per-band median sigma is not
    # explained here; it scales the detection image (and so the effective
    # thresholds) up by 2. Confirm its origin before changing it.
    sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2
    detection = np.sum(_images / sigma[:, None, None], axis=0)
    if min_pixel_detect > 1:
        # Count, per pixel, the bands with positive flux in the input images.
        mask = np.sum(images > 0, axis=0) >= min_pixel_detect
        detection[~mask] = 0
    # Detect peaks on the detection image
    footprints = get_footprints(
        detection,
        min_separation,
        min_area,
        peak_thresh,
        footprint_thresh,
        find_peaks,
        origin[0],
        origin[1],
    )

    return footprints