Coverage for python/lsst/meas/extensions/scarlet/io.py: 9%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-29 03:37 -0700

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23import logging 

24 

25import numpy as np 

26 

27from lsst.afw.table import SourceCatalog 

28from lsst.afw.image import MaskedImage, Exposure 

29from lsst.afw.detection import Footprint as afwFootprint, HeavyFootprintF 

30from lsst.afw.geom import SpanSet, Span 

31from lsst.geom import Box2I, Extent2I, Point2I 

32import lsst.scarlet.lite as scl 

33from lsst.scarlet.lite import Blend, Source, Box, Component, FixedParameter, FactorizedComponent, Image 

34 

35from .metrics import setDeblenderMetrics 

36from .utils import scarletModelToHeavy 

37 

38 

39logger = logging.getLogger(__name__) 

40 

41 

def monochromaticDataToScarlet(
    blendData: scl.io.ScarletBlendData,
    bandIndex: int,
    observation: scl.Observation,
):
    """Convert the storage data model into a scarlet lite blend

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    bandIndex:
        Index of model to extract.
    observation:
        Observation of region inside the bounding box.

    Returns
    -------
    blend : `scarlet.lite.LiteBlend`
        A scarlet blend model extracted from persisted data.
    """
    # A single dummy band is sufficient here, since only a monochromatic
    # model is extracted, which will be turned into a HeavyFootprint.
    bands = ("dummy", )
    sources = []
    for sourceId, sourceData in blendData.sources.items():
        components: list[Component] = []
        # With only one band in use there is no need to distinguish
        # factorized components from regular components.
        for componentData in sourceData.components:
            componentBox = Box(componentData.shape, origin=componentData.origin)
            monoModel = scl.io.Image(
                componentData.model[bandIndex][None, :, :],
                yx0=componentBox.origin,
                bands=bands,
            )
            components.append(
                scl.io.ComponentCube(
                    bands=bands,
                    model=monoModel,
                    peak=tuple(componentData.peak[::-1]),
                    bbox=componentBox,
                )
            )

        for factorizedData in sourceData.factorized_components:
            factorizedBox = Box(factorizedData.shape, origin=factorizedData.origin)
            # Dummy values are used for the properties that are only
            # needed during model fitting.
            spectrum = FixedParameter(factorizedData.spectrum)
            nBands = len(spectrum.x)
            factorized = FactorizedComponent(
                bands=("dummy", ) * nBands,
                spectrum=spectrum,
                morph=FixedParameter(factorizedData.morph),
                peak=tuple(int(np.round(p)) for p in factorizedData.peak),  # type: ignore
                bbox=factorizedBox,
                bg_rms=np.full((nBands,), np.nan),
            )
            monoModel = scl.io.Image(
                factorized.get_model().data[bandIndex][None, :, :],
                yx0=factorizedBox.origin,
                bands=bands,
            )
            components.append(
                scl.io.ComponentCube(
                    bands=bands,
                    model=monoModel,
                    peak=factorized.peak,
                    bbox=factorized.bbox,
                )
            )

        source = Source(components=components)
        source.record_id = sourceId
        source.peak_id = sourceData.peak_id
        sources.append(source)

    return Blend(sources=sources, observation=observation)

113 

114 

def updateCatalogFootprints(
    modelData: scl.io.ScarletModelData,
    catalog: SourceCatalog,
    band: str,
    imageForRedistribution: MaskedImage | Exposure | None = None,
    removeScarletData: bool = True,
    updateFluxColumns: bool = True
):
    """Use the scarlet models to set HeavyFootprints for modeled sources

    Parameters
    ----------
    modelData:
        Persisted scarlet model data for all of the blends in the catalog.
    catalog:
        The catalog missing heavy footprints for deblended sources.
    band:
        The name of the band that the catalog data describes.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    removeScarletData:
        Whether or not to remove `ScarletBlendData` for each blend
        in order to save memory.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    # Iterate over the blends, since flux re-distribution must be done on
    # all of the children with the same parent
    parents = catalog[catalog["parent"] == 0]

    for parentRecord in parents:
        parentId = parentRecord.getId()

        try:
            blendModel = modelData.blends[parentId]
        except KeyError:
            # The parent was skipped in the deblender, so there are
            # no models for its sources.
            continue

        # `parentRecord` is already the record with id `parentId` (subset
        # catalogs share their records), so no `catalog.find` lookup is
        # needed here.
        if updateFluxColumns and imageForRedistribution is not None:
            # Update the data coverage (1 - # of NO_DATA pixels/# of pixels)
            parentRecord["deblend_dataCoverage"] = calculateFootprintCoverage(
                parentRecord.getFootprint(),
                imageForRedistribution.mask
            )

        if band not in blendModel.bands:
            peaks = parentRecord.getFootprint().peaks
            # Set the footprint and coverage of the sources in this blend
            # to zero
            for sourceId, sourceData in blendModel.sources.items():
                sourceRecord = catalog.find(sourceId)
                footprint = afwFootprint()
                peakIdx = np.where(peaks["id"] == sourceData.peak_id)[0][0]
                peak = peaks[peakIdx]
                footprint.addPeak(peak.getIx(), peak.getIy(), peak.getPeakValue())
                sourceRecord.setFootprint(footprint)
                if updateFluxColumns:
                    sourceRecord["deblend_dataCoverage"] = 0
            continue

        # Get the index of the model for the given band
        bandIndex = blendModel.bands.index(band)

        updateBlendRecords(
            blendData=blendModel,
            catalog=catalog,
            modelPsf=modelData.psf,
            imageForRedistribution=imageForRedistribution,
            bandIndex=bandIndex,
            parentFootprint=parentRecord.getFootprint(),
            updateFluxColumns=updateFluxColumns,
        )

        # Save memory by removing the data for the blend
        if removeScarletData:
            del modelData.blends[parentId]

196 

197 

def calculateFootprintCoverage(footprint, maskImage):
    """Calculate the fraction of pixels with data in a Footprint

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for missing data.
    maskImage : `lsst.afw.image.MaskX`
        The mask image with the ``NO_DATA`` bit set.

    Returns
    -------
    coverage : `float`
        The fraction of pixels in `footprint` where the ``NO_DATA`` bit is
        *not* set, i.e. ``1 - noDataPixels/totalPixels``.
        A footprint with no area has zero coverage.
    """
    # Store the value of "NO_DATA" from the mask plane.
    noDataInt = 2**maskImage.getMaskPlaneDict()["NO_DATA"]

    # Calculate the coverage in the footprint
    bbox = footprint.getBBox()
    if bbox.area == 0:
        # The source has no footprint, so it has no coverage
        return 0
    spans = footprint.spans.asArray()
    totalArea = footprint.getArea()
    # Restrict the NO_DATA test to pixels inside the footprint's spans.
    mask = maskImage[bbox].array & noDataInt
    noData = (mask * spans) > 0
    coverage = 1 - np.sum(noData)/totalArea
    return coverage

225 

226 

def updateBlendRecords(
    blendData: scl.io.ScarletBlendData,
    catalog: SourceCatalog,
    modelPsf: np.ndarray,
    imageForRedistribution: MaskedImage | Exposure | None,
    bandIndex: int,
    parentFootprint: afwFootprint,
    updateFluxColumns: bool,
):
    """Create footprints and update band-dependent columns in the catalog

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    catalog:
        The catalog that is being updated.
    modelPsf:
        The 2D model of the PSF.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    bandIndex:
        The number of the band to extract.
    parentFootprint:
        The footprint of the parent, used for masking out the model
        when re-distributing flux.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    useFlux = imageForRedistribution is not None
    bands = ("dummy",)
    # Only use the PSF for the current image
    psfs = np.array([blendData.psf[bandIndex]])

    if useFlux:
        # Extract the image array to re-distribute its flux
        xy0 = Point2I(*blendData.origin[::-1])
        extent = Extent2I(*blendData.shape[::-1])
        bbox = Box2I(xy0, extent)

        images = Image(
            imageForRedistribution[bbox].image.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        variance = Image(
            imageForRedistribution[bbox].variance.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        weights = Image(
            parentFootprint.spans.asArray()[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        observation = scl.io.Observation(
            images=images,
            variance=variance,
            weights=weights,
            psfs=psfs,
            model_psf=modelPsf[None, :, :],
        )
    else:
        observation = scl.io.Observation.empty(
            bands=bands,
            psfs=psfs,
            model_psf=modelPsf[None, :, :],
            bbox=Box(blendData.shape, blendData.origin),
            dtype=np.float32,
        )

    blend = monochromaticDataToScarlet(
        blendData=blendData,
        bandIndex=bandIndex,
        observation=observation,
    )

    if useFlux:
        # Re-distribute the flux in the images
        blend.conserve_flux()

    # Set the metrics for the blend.
    # TODO: remove this once DM-34558 runs all deblender metrics
    # in a separate task.
    if updateFluxColumns:
        setDeblenderMetrics(blend)

    # Update the HeavyFootprints for deblended sources
    # and update the band-dependent catalog columns.
    for source in blend.sources:
        sourceRecord = catalog.find(source.record_id)
        parent = catalog.find(sourceRecord["parent"])
        peaks = parent.getFootprint().peaks
        peakIdx = np.where(peaks["id"] == source.peak_id)[0][0]
        source.detectedPeak = peaks[peakIdx]
        # Set the Footprint
        heavy = scarletModelToHeavy(
            source=source,
            blend=blend,
            useFlux=useFlux,
        )

        if updateFluxColumns:
            if heavy.getArea() == 0:
                # The source has no flux after being weighted with the PSF
                # in this particular band (it might have flux in others).
                sourceRecord.set("deblend_zeroFlux", True)
                # Create a Footprint with a single pixel, set to zero,
                # to avoid breakage in measurement algorithms.
                center = Point2I(heavy.peaks[0]["i_x"], heavy.peaks[0]["i_y"])
                spanList = [Span(center.y, center.x, center.x)]
                footprint = afwFootprint(SpanSet(spanList))
                footprint.setPeakCatalog(heavy.peaks)
                heavy = HeavyFootprintF(footprint)
                heavy.getImageArray()[0] = 0.0
            else:
                sourceRecord.set("deblend_zeroFlux", False)
            sourceRecord.setFootprint(heavy)

            if useFlux:
                # Set the fraction of pixels with valid data.
                coverage = calculateFootprintCoverage(heavy, imageForRedistribution.mask)
                sourceRecord.set("deblend_dataCoverage", coverage)

            # Set the flux of the scarlet model
            # TODO: this field should probably be deprecated,
            # since DM-33710 gives users access to the scarlet models.
            model = source.get_model().data[0]
            sourceRecord.set("deblend_scarletFlux", np.sum(model))

            # Set the flux at the center of the model
            peak = heavy.peaks[0]

            img = heavy.extractImage(fill=0.0)
            try:
                sourceRecord.set("deblend_peak_instFlux", img[Point2I(peak["i_x"], peak["i_y"])])
            except Exception:
                # Log lazily (no f-string) and include the traceback.
                logger.warning(
                    "Source %s at %s,%s could not set the peak flux with error:",
                    sourceRecord.getId(),
                    peak["i_x"],
                    peak["i_y"],
                    exc_info=True,
                )
                sourceRecord.set("deblend_peak_instFlux", np.nan)

            # Set the metrics columns.
            # TODO: remove this once DM-34558 runs all deblender metrics
            # in a separate task.
            sourceRecord.set("deblend_maxOverlap", source.metrics.maxOverlap[0])
            sourceRecord.set("deblend_fluxOverlap", source.metrics.fluxOverlap[0])
            sourceRecord.set("deblend_fluxOverlapFraction", source.metrics.fluxOverlapFraction[0])
            sourceRecord.set("deblend_blendedness", source.metrics.blendedness[0])
        else:
            sourceRecord.setFootprint(heavy)

391 

392 

def oldScarletToData(blend: Blend, psfCenter: tuple[int, int], xy0: Point2I):
    """Convert a scarlet.lite blend into a persistable data object

    Note: This converts a blend from the old version of scarlet.lite,
    which is deprecated, to the persistable data format used in the
    new scarlet lite package.
    This is kept to compare the two scarlet versions,
    and can be removed once the new lsst.scarlet.lite package is
    used in production.

    Parameters
    ----------
    blend:
        The blend that is being persisted.
    psfCenter:
        The center of the PSF.
    xy0:
        The lower coordinate of the entire blend.
    Returns
    -------
    blendData : `ScarletBlendDataModel`
        The data model for a single blend.
    """
    from scarlet import lite
    # The data model uses (y, x) ordering for offsets.
    yx0 = (xy0.y, xy0.x)

    sources = {}
    for source in blend.sources:
        plainComponents = []
        factorizedComponents = []
        for component in source.components:
            # Shift the component origin and peak into the full-image frame.
            shiftedOrigin = tuple(component.bbox.origin[axis + 1] + yx0[axis] for axis in range(2))
            shiftedPeak = tuple(component.center[axis] + yx0[axis] for axis in range(2))

            if isinstance(component, lite.LiteFactorizedComponent):
                factorizedComponents.append(
                    scl.io.ScarletFactorizedComponentData(
                        origin=shiftedOrigin,
                        peak=shiftedPeak,
                        spectrum=component.sed,
                        morph=component.morph,
                    )
                )
            else:
                plainComponents.append(
                    scl.io.ScarletComponentData(
                        origin=shiftedOrigin,
                        peak=shiftedPeak,
                        model=component.get_model(),
                    )
                )
        sources[source.record_id] = scl.io.ScarletSourceData(
            components=plainComponents,
            factorized_components=factorizedComponents,
            peak_id=source.peak_id,
        )

    return scl.io.ScarletBlendData(
        origin=(xy0.y, xy0.x),
        shape=blend.observation.bbox.shape[-2:],
        sources=sources,
        psf_center=psfCenter,
        psf=blend.observation.psfs,
        bands=blend.observation.bands,
    )