# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations
import logging

import numpy as np

from lsst.afw.table import SourceCatalog
from lsst.afw.image import MaskedImage, Exposure
from lsst.afw.detection import Footprint as afwFootprint
from lsst.geom import Box2I, Extent2I, Point2I
import lsst.scarlet.lite as scl
from lsst.scarlet.lite import Blend, Source, Box, Component, FixedParameter, FactorizedComponent, Image

from .metrics import setDeblenderMetrics
from .utils import scarletModelToHeavy


logger = logging.getLogger(__name__)


def monochromaticDataToScarlet(
    blendData: scl.io.ScarletBlendData,
    bandIndex: int,
    observation: scl.Observation,
):
    """Convert the storage data model into a scarlet lite blend

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    bandIndex:
        Index of the band of the model to extract.
    observation:
        Observation of the region inside the bounding box.

    Returns
    -------
    blend : `lsst.scarlet.lite.Blend`
        A scarlet blend model extracted from persisted data.
    """
    sources = []
    # Use a dummy band, since we are only extracting a monochromatic model
    # that will be turned into a HeavyFootprint.
    bands = ("dummy", )
    for sourceId, sourceData in blendData.sources.items():
        components: list[Component] = []
        # There is no need to distinguish factorized components from regular
        # components, since there is only one band being used.
        for componentData in sourceData.components:
            bbox = Box(componentData.shape, origin=componentData.origin)
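            # Keep only the requested band from the persisted model and add a
            # length-1 channel axis to match the single dummy band.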
            model = scl.io.Image(componentData.model[bandIndex][None, :, :], yx0=bbox.origin, bands=bands)
            component = scl.io.ComponentCube(
                bands=bands,
                model=model,
                peak=tuple(componentData.peak[::-1]),
                bbox=bbox,
            )
            components.append(component)

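        # Factorized components persist only a spectrum and a morphology, so
        # rebuild the full model before slicing out the requested band.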
        for factorizedData in sourceData.factorized_components:
            bbox = Box(factorizedData.shape, origin=factorizedData.origin)
            # Add dummy values for properties only needed for
            # model fitting.
            spectrum = FixedParameter(factorizedData.spectrum)
            totalBands = len(spectrum.x)
            morph = FixedParameter(factorizedData.morph)
            factorized = FactorizedComponent(
                bands=("dummy", ) * totalBands,
                spectrum=spectrum,
                morph=morph,
                peak=tuple(int(np.round(p)) for p in factorizedData.peak),  # type: ignore
                bbox=bbox,
                bg_rms=np.full((totalBands,), np.nan),
            )
            model = factorized.get_model().data[bandIndex][None, :, :]
            model = scl.io.Image(model, yx0=bbox.origin, bands=bands)
            component = scl.io.ComponentCube(
                bands=bands,
                model=model,
                peak=factorized.peak,
                bbox=factorized.bbox,
            )
            components.append(component)

        source = Source(components=components)
        source.record_id = sourceId
        source.peak_id = sourceData.peak_id
        sources.append(source)

    return Blend(sources=sources, observation=observation)


def updateCatalogFootprints(
    modelData: scl.io.ScarletModelData,
    catalog: SourceCatalog,
    band: str,
    imageForRedistribution: MaskedImage | Exposure | None = None,
    removeScarletData: bool = True,
    updateFluxColumns: bool = True
):
    """Use the scarlet models to set HeavyFootprints for modeled sources

    Parameters
    ----------
    modelData:
        The persisted scarlet models for all of the blends in the catalog.
    catalog:
        The catalog missing heavy footprints for deblended sources.
    band:
        The name of the band that the catalog data describes.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    removeScarletData:
        Whether or not to remove `ScarletBlendData` for each blend
        in order to save memory.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    # Iterate over the blends, since flux re-distribution must be done on
    # all of the children with the same parent
    parents = catalog[catalog["parent"] == 0]

    for parentRecord in parents:
        parentId = parentRecord.getId()

        try:
            blendModel = modelData.blends[parentId]
        except KeyError:
            # The parent was skipped in the deblender, so there are
            # no models for its sources.
            continue

        parent = catalog.find(parentId)
        if updateFluxColumns and imageForRedistribution is not None:
            # Update the data coverage (1 - # of NO_DATA pixels/# of pixels)
            parentRecord["deblend_dataCoverage"] = calculateFootprintCoverage(
                parent.getFootprint(),
                imageForRedistribution.mask
            )

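        # The deblender may not have modeled this band; in that case give
        # each child an empty footprint containing only its detected peak.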
        if band not in blendModel.bands:
            peaks = parent.getFootprint().peaks
            # Set the footprint and coverage of the sources in this blend
            # to zero
            for sourceId, sourceData in blendModel.sources.items():
                sourceRecord = catalog.find(sourceId)
                footprint = afwFootprint()
                peakIdx = np.where(peaks["id"] == sourceData.peak_id)[0][0]
                peak = peaks[peakIdx]
                footprint.addPeak(peak.getIx(), peak.getIy(), peak.getPeakValue())
                sourceRecord.setFootprint(footprint)
                if updateFluxColumns:
                    sourceRecord["deblend_dataCoverage"] = 0
            continue

        # Get the index of the model for the given band
        bandIndex = blendModel.bands.index(band)

        updateBlendRecords(
            blendData=blendModel,
            catalog=catalog,
            modelPsf=modelData.psf,
            imageForRedistribution=imageForRedistribution,
            bandIndex=bandIndex,
            parentFootprint=parentRecord.getFootprint(),
            updateFluxColumns=updateFluxColumns,
        )

        # Save memory by removing the data for the blend
        if removeScarletData:
            del modelData.blends[parentId]


def calculateFootprintCoverage(footprint, maskImage):
    """Calculate the fraction of a Footprint that is covered by valid data

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for missing data.
    maskImage : `lsst.afw.image.MaskX`
        The mask image with the ``NO_DATA`` bit set.

    Returns
    -------
    coverage : `float`
        The fraction of pixels in `footprint` that do not have the
        ``NO_DATA`` bit set.
    """
    # Store the value of "NO_DATA" from the mask plane.
    noDataInt = 2**maskImage.getMaskPlaneDict()["NO_DATA"]

    # Calculate the coverage in the footprint
    bbox = footprint.getBBox()
    if bbox.area == 0:
        # The source has no footprint, so it has no coverage
        return 0
    spans = footprint.spans.asArray()
    totalArea = footprint.getArea()
    mask = maskImage[bbox].array & noDataInt
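    # spans is nonzero only inside the footprint, so the product counts
    # NO_DATA pixels that fall within the footprint itself.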
    noData = (mask * spans) > 0
    coverage = 1 - np.sum(noData)/totalArea
    return coverage


def updateBlendRecords(
    blendData: scl.io.ScarletBlendData,
    catalog: SourceCatalog,
    modelPsf: np.ndarray,
    imageForRedistribution: MaskedImage | Exposure | None,
    bandIndex: int,
    parentFootprint: afwFootprint,
    updateFluxColumns: bool,
):
    """Create footprints and update band-dependent columns in the catalog

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    catalog:
        The catalog that is being updated.
    modelPsf:
        The 2D model of the PSF.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    bandIndex:
        The index of the band to extract.
    parentFootprint:
        The footprint of the parent, used for masking out the model
        when re-distributing flux.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    useFlux = imageForRedistribution is not None
    bands = ("dummy",)
    # Only use the PSF for the current image
    psfs = np.array([blendData.psf[bandIndex]])

    if useFlux:
        # Extract the image array to re-distribute its flux
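        # blendData.origin and blendData.shape are stored as (y, x), while
        # the afw geometry classes expect (x, y).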
        xy0 = Point2I(*blendData.origin[::-1])
        extent = Extent2I(*blendData.shape[::-1])
        bbox = Box2I(xy0, extent)

        images = Image(
            imageForRedistribution[bbox].image.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        variance = Image(
            imageForRedistribution[bbox].variance.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        weights = Image(
            parentFootprint.spans.asArray()[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        observation = scl.io.Observation(
            images=images,
            variance=variance,
            weights=weights,
            psfs=psfs,
            model_psf=modelPsf[None, :, :],
        )
    else:
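        # With no image to redistribute flux into, an empty observation that
        # only carries the geometry and PSFs is sufficient for building models.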
        observation = scl.io.Observation.empty(
            bands=bands,
            psfs=psfs,
            model_psf=modelPsf[None, :, :],
            bbox=Box(blendData.shape, blendData.origin),
            dtype=np.float32,
        )

    blend = monochromaticDataToScarlet(
        blendData=blendData,
        bandIndex=bandIndex,
        observation=observation,
    )

    if useFlux:
        # Re-distribute the flux in the images
        blend.conserve_flux()

    # Set the metrics for the blend.
    # TODO: remove this once DM-34558 runs all deblender metrics
    # in a separate task.
    if updateFluxColumns:
        setDeblenderMetrics(blend)

    # Update the HeavyFootprints for deblended sources
    # and update the band-dependent catalog columns.
    for source in blend.sources:
        sourceRecord = catalog.find(source.record_id)
        parent = catalog.find(sourceRecord["parent"])
        peaks = parent.getFootprint().peaks
        peakIdx = np.where(peaks["id"] == source.peak_id)[0][0]
        source.detectedPeak = peaks[peakIdx]
        # Set the Footprint
        heavy = scarletModelToHeavy(
            source=source,
            blend=blend,
            useFlux=useFlux,
        )
        sourceRecord.setFootprint(heavy)

        if updateFluxColumns:
            if useFlux:
                # Set the fraction of pixels with valid data.
                coverage = calculateFootprintCoverage(heavy, imageForRedistribution.mask)
                sourceRecord.set("deblend_dataCoverage", coverage)

            # Set the flux of the scarlet model
            # TODO: this field should probably be deprecated,
            # since DM-33710 gives users access to the scarlet models.
            model = source.get_model().data[0]
            sourceRecord.set("deblend_scarletFlux", np.sum(model))

            # Set the flux at the center of the model
            peak = heavy.peaks[0]

            img = heavy.extractImage(fill=0.0)
            try:
                sourceRecord.set("deblend_peak_instFlux", img[Point2I(peak["i_x"], peak["i_y"])])
            except Exception:
                srcId = sourceRecord.getId()
                x = peak["i_x"]
                y = peak["i_y"]
                logger.warning(
                    f"Source {srcId} at {x},{y} could not set the peak flux with error:",
                    exc_info=1
                )
                sourceRecord.set("deblend_peak_instFlux", np.nan)

            # Set the metrics columns.
            # TODO: remove this once DM-34558 runs all deblender metrics
            # in a separate task.
            sourceRecord.set("deblend_maxOverlap", source.metrics.maxOverlap[0])
            sourceRecord.set("deblend_fluxOverlap", source.metrics.fluxOverlap[0])
            sourceRecord.set("deblend_fluxOverlapFraction", source.metrics.fluxOverlapFraction[0])
            sourceRecord.set("deblend_blendedness", source.metrics.blendedness[0])


def oldScarletToData(blend: Blend, psfCenter: tuple[int, int], xy0: Point2I):
    """Convert a scarlet.lite blend into a persistable data object

    Note: This converts a blend from the old version of scarlet.lite,
    which is deprecated, to the persistable data format used in the
    new scarlet lite package.
    This is kept to compare the two scarlet versions,
    and can be removed once the new lsst.scarlet.lite package is
    used in production.

    Parameters
    ----------
    blend:
        The blend that is being persisted.
    psfCenter:
        The center of the PSF.
    xy0:
        The lower coordinate of the entire blend.

    Returns
    -------
    blendData : `lsst.scarlet.lite.io.ScarletBlendData`
        The data model for a single blend.
    """
    from scarlet import lite
    yx0 = (xy0.y, xy0.x)

    sources = {}
    for source in blend.sources:
        components = []
        factorizedComponents = []
        for component in source.components:
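            # Shift the component's origin and peak from blend-local (y, x)
            # coordinates into the frame of the full image.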
            origin = tuple(component.bbox.origin[i+1] + yx0[i] for i in range(2))
            peak = tuple(component.center[i] + yx0[i] for i in range(2))

            if isinstance(component, lite.LiteFactorizedComponent):
                componentData = scl.io.ScarletFactorizedComponentData(
                    origin=origin,
                    peak=peak,
                    spectrum=component.sed,
                    morph=component.morph,
                )
                factorizedComponents.append(componentData)
            else:
                componentData = scl.io.ScarletComponentData(
                    origin=origin,
                    peak=peak,
                    model=component.get_model(),
                )
                components.append(componentData)
        sourceData = scl.io.ScarletSourceData(
            components=components,
            factorized_components=factorizedComponents,
            peak_id=source.peak_id,
        )
        sources[source.record_id] = sourceData

    blendData = scl.io.ScarletBlendData(
        origin=(xy0.y, xy0.x),
        shape=blend.observation.bbox.shape[-2:],
        sources=sources,
        psf_center=psfCenter,
        psf=blend.observation.psfs,
        bands=blend.observation.bands,
    )

    return blendData