Coverage for python/lsst/meas/extensions/scarlet/io.py: 9%

139 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-12 10:45 +0000

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23import logging 

24 

25import numpy as np 

26 

27from lsst.afw.table import SourceCatalog 

28from lsst.afw.image import MaskedImage, Exposure 

29from lsst.afw.detection import Footprint as afwFootprint 

30from lsst.geom import Box2I, Extent2I, Point2I 

31import lsst.scarlet.lite as scl 

32from lsst.scarlet.lite import Blend, Source, Box, Component, FixedParameter, FactorizedComponent, Image 

33 

34from .metrics import setDeblenderMetrics 

35from .utils import scarletModelToHeavy 

36 

37 

38logger = logging.getLogger(__name__) 

39 

40 

def monochromaticDataToScarlet(
    blendData: scl.io.ScarletBlendData,
    bandIndex: int,
    observation: scl.Observation,
):
    """Convert the storage data model into a scarlet lite blend

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    bandIndex:
        Index of model to extract.
    observation:
        Observation of region inside the bounding box.

    Returns
    -------
    blend : `scarlet.lite.LiteBlend`
        A scarlet blend model extracted from persisted data.
    """
    # A single dummy band is sufficient because only a monochromatic
    # model (later turned into a HeavyFootprint) is extracted.
    dummyBands = ("dummy",)
    extractedSources = []
    for recordId, srcData in blendData.sources.items():
        allComponents: list[Component] = []
        # With only one band in use there is no need to distinguish
        # factorized components from regular components.
        for compData in srcData.components:
            compBox = Box(compData.shape, origin=compData.origin)
            monoModel = scl.io.Image(
                compData.model[bandIndex][None, :, :],
                yx0=compBox.origin,
                bands=dummyBands,
            )
            allComponents.append(
                scl.io.ComponentCube(
                    bands=dummyBands,
                    model=monoModel,
                    peak=tuple(compData.peak[::-1]),
                    bbox=compBox,
                )
            )

        for facData in srcData.factorized_components:
            facBox = Box(facData.shape, origin=facData.origin)
            # Dummy values stand in for the properties that are only
            # needed during model fitting.
            spectrum = FixedParameter(facData.spectrum)
            nBands = len(spectrum.x)
            factorized = FactorizedComponent(
                bands=("dummy",) * nBands,
                spectrum=spectrum,
                morph=FixedParameter(facData.morph),
                peak=tuple(int(np.round(p)) for p in facData.peak),  # type: ignore
                bbox=facBox,
                bg_rms=np.full((nBands,), np.nan),
            )
            monoData = factorized.get_model().data[bandIndex][None, :, :]
            monoModel = scl.io.Image(monoData, yx0=facBox.origin, bands=dummyBands)
            allComponents.append(
                scl.io.ComponentCube(
                    bands=dummyBands,
                    model=monoModel,
                    peak=factorized.peak,
                    bbox=factorized.bbox,
                )
            )

        src = Source(components=allComponents)
        src.record_id = recordId
        src.peak_id = srcData.peak_id
        extractedSources.append(src)

    return Blend(sources=extractedSources, observation=observation)

112 

113 

def updateCatalogFootprints(
    modelData: scl.io.ScarletModelData,
    catalog: SourceCatalog,
    band: str,
    imageForRedistribution: MaskedImage | Exposure | None = None,
    removeScarletData: bool = True,
    updateFluxColumns: bool = True
):
    """Use the scarlet models to set HeavyFootprints for modeled sources

    Parameters
    ----------
    modelData:
        Persisted scarlet model data for all of the blends in the catalog.
    catalog:
        The catalog missing heavy footprints for deblended sources.
    band:
        The name of the band that the catalog data describes.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    removeScarletData:
        Whether or not to remove `ScarletBlendData` for each blend
        in order to save memory.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    # Iterate over the blends, since flux re-distribution must be done on
    # all of the children with the same parent
    parents = catalog[catalog["parent"] == 0]

    for parentRecord in parents:
        parentId = parentRecord.getId()

        try:
            blendModel = modelData.blends[parentId]
        except KeyError:
            # The parent was skipped in the deblender, so there are
            # no models for its sources.
            continue

        # Note: `parentRecord` is already the record for `parentId`;
        # a second `catalog.find(parentId)` lookup is unnecessary.
        if updateFluxColumns and imageForRedistribution is not None:
            # Update the data coverage (1 - # of NO_DATA pixels/# of pixels)
            parentRecord["deblend_dataCoverage"] = calculateFootprintCoverage(
                parentRecord.getFootprint(),
                imageForRedistribution.mask
            )

        if band not in blendModel.bands:
            peaks = parentRecord.getFootprint().peaks
            # Set the footprint and coverage of the sources in this blend
            # to zero
            for sourceId, sourceData in blendModel.sources.items():
                sourceRecord = catalog.find(sourceId)
                footprint = afwFootprint()
                # The persisted source data stores the peak ID as
                # ``peak_id`` (see `monochromaticDataToScarlet` and
                # `oldScarletToData`), not ``peakId``.
                peakIdx = np.where(peaks["id"] == sourceData.peak_id)[0][0]
                peak = peaks[peakIdx]
                footprint.addPeak(peak.getIx(), peak.getIy(), peak.getPeakValue())
                sourceRecord.setFootprint(footprint)
                if updateFluxColumns:
                    sourceRecord["deblend_dataCoverage"] = 0
            continue

        # Get the index of the model for the given band
        bandIndex = blendModel.bands.index(band)

        updateBlendRecords(
            blendData=blendModel,
            catalog=catalog,
            modelPsf=modelData.psf,
            imageForRedistribution=imageForRedistribution,
            bandIndex=bandIndex,
            parentFootprint=parentRecord.getFootprint(),
            updateFluxColumns=updateFluxColumns,
        )

        # Save memory by removing the data for the blend
        if removeScarletData:
            del modelData.blends[parentId]

195 

196 

def calculateFootprintCoverage(footprint, maskImage):
    """Calculate the fraction of a Footprint that has valid data

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for missing data.
    maskImage : `lsst.afw.image.MaskX`
        The mask image with the ``NO_DATA`` bit set.

    Returns
    -------
    coverage : `float`
        The fraction of pixels in `footprint` where the ``NO_DATA`` bit
        is *not* set: 1 means full coverage, 0 means no valid data.
        (The code computes ``1 - noData/totalArea``; the previous docstring
        incorrectly described the complementary quantity.)
    """
    # Store the value of "NO_DATA" from the mask plane.
    noDataInt = 2**maskImage.getMaskPlaneDict()["NO_DATA"]

    # Calculate the coverage in the footprint
    bbox = footprint.getBBox()
    if bbox.area == 0:
        # The source has no footprint, so it has no coverage
        return 0
    spans = footprint.spans.asArray()
    totalArea = footprint.getArea()
    # Restrict the NO_DATA test to pixels inside the footprint's spans.
    mask = maskImage[bbox].array & noDataInt
    noData = (mask * spans) > 0
    coverage = 1 - np.sum(noData)/totalArea
    return coverage

224 

225 

def updateBlendRecords(
    blendData: scl.io.ScarletBlendData,
    catalog: SourceCatalog,
    modelPsf: np.ndarray,
    imageForRedistribution: MaskedImage | Exposure | None,
    bandIndex: int,
    parentFootprint: afwFootprint,
    updateFluxColumns: bool,
):
    """Create footprints and update band-dependent columns in the catalog

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    catalog:
        The catalog that is being updated.
    modelPsf:
        The 2D model of the PSF.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    bandIndex:
        The number of the band to extract.
    parentFootprint:
        The footprint of the parent, used for masking out the model
        when re-distributing flux.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    useFlux = imageForRedistribution is not None
    bands = ("dummy",)

    if useFlux:
        # Extract the image array to re-distribute its flux
        xy0 = Point2I(*blendData.origin[::-1])
        extent = Extent2I(*blendData.shape[::-1])
        bbox = Box2I(xy0, extent)

        images = Image(
            imageForRedistribution[bbox].image.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        variance = Image(
            imageForRedistribution[bbox].variance.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        weights = Image(
            parentFootprint.spans.asArray()[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        observation = scl.io.Observation(
            images=images,
            variance=variance,
            weights=weights,
            psfs=blendData.psf,
            model_psf=modelPsf[None, :, :],
        )
    else:
        observation = scl.io.Observation.empty(
            bands=bands,
            psfs=blendData.psf,
            model_psf=modelPsf[None, :, :],
            bbox=Box(blendData.shape, blendData.origin),
            dtype=np.float32,
        )

    blend = monochromaticDataToScarlet(
        blendData=blendData,
        bandIndex=bandIndex,
        observation=observation,
    )

    if useFlux:
        # Re-distribute the flux in the images
        blend.conserve_flux()

    # Set the metrics for the blend.
    # TODO: remove this once DM-34558 runs all deblender metrics
    # in a separate task.
    if updateFluxColumns:
        setDeblenderMetrics(blend)

    # Update the HeavyFootprints for deblended sources
    # and update the band-dependent catalog columns.
    for source in blend.sources:
        sourceRecord = catalog.find(source.record_id)
        parent = catalog.find(sourceRecord["parent"])
        peaks = parent.getFootprint().peaks
        peakIdx = np.where(peaks["id"] == source.peak_id)[0][0]
        source.detectedPeak = peaks[peakIdx]
        # Set the Footprint
        heavy = scarletModelToHeavy(
            source=source,
            blend=blend,
            useFlux=useFlux,
        )
        sourceRecord.setFootprint(heavy)

        if updateFluxColumns:
            if useFlux:
                # Set the fraction of pixels with valid data.
                coverage = calculateFootprintCoverage(heavy, imageForRedistribution.mask)
                sourceRecord.set("deblend_dataCoverage", coverage)

            # Set the flux of the scarlet model
            # TODO: this field should probably be deprecated,
            # since DM-33710 gives users access to the scarlet models.
            model = source.get_model().data[0]
            sourceRecord.set("deblend_scarletFlux", np.sum(model))

            # Set the flux at the center of the model
            peak = heavy.peaks[0]

            img = heavy.extractImage(fill=0.0)
            try:
                sourceRecord.set("deblend_peak_instFlux", img[Point2I(peak["i_x"], peak["i_y"])])
            except Exception:
                # Use lazy %-args so the message is only formatted when
                # the warning is actually emitted, and a proper boolean
                # for exc_info.
                logger.warning(
                    "Source %s at %s,%s could not set the peak flux with error:",
                    sourceRecord.getId(),
                    peak["i_x"],
                    peak["i_y"],
                    exc_info=True,
                )
                sourceRecord.set("deblend_peak_instFlux", np.nan)

            # Set the metrics columns.
            # TODO: remove this once DM-34558 runs all deblender metrics
            # in a separate task.
            sourceRecord.set("deblend_maxOverlap", source.metrics.maxOverlap[0])
            sourceRecord.set("deblend_fluxOverlap", source.metrics.fluxOverlap[0])
            sourceRecord.set("deblend_fluxOverlapFraction", source.metrics.fluxOverlapFraction[0])
            sourceRecord.set("deblend_blendedness", source.metrics.blendedness[0])

371 

372 

def oldScarletToData(blend: Blend, psfCenter: tuple[int, int], xy0: Point2I):
    """Convert a scarlet.lite blend into a persistable data object

    Note: This converts a blend from the old version of scarlet.lite,
    which is deprecated, to the persistable data format used in the
    new scarlet lite package.
    This is kept to compare the two scarlet versions,
    and can be removed once the new lsst.scarlet.lite package is
    used in production.

    Parameters
    ----------
    blend:
        The blend that is being persisted.
    psfCenter:
        The center of the PSF.
    xy0:
        The lower coordinate of the entire blend.

    Returns
    -------
    blendData : `ScarletBlendDataModel`
        The data model for a single blend.
    """
    from scarlet import lite

    # Offset in (y, x) order used to shift local coordinates into
    # the full-image frame.
    offset = (xy0.y, xy0.x)

    persistedSources = {}
    for src in blend.sources:
        plainComponents = []
        factorizedComponents = []
        for comp in src.components:
            # comp.bbox.origin has a leading band axis, hence the +1.
            shiftedOrigin = tuple(comp.bbox.origin[axis + 1] + offset[axis] for axis in range(2))
            shiftedPeak = tuple(comp.center[axis] + offset[axis] for axis in range(2))

            if isinstance(comp, lite.LiteFactorizedComponent):
                factorizedComponents.append(
                    scl.io.ScarletFactorizedComponentData(
                        origin=shiftedOrigin,
                        peak=shiftedPeak,
                        spectrum=comp.sed,
                        morph=comp.morph,
                    )
                )
            else:
                plainComponents.append(
                    scl.io.ScarletComponentData(
                        origin=shiftedOrigin,
                        peak=shiftedPeak,
                        model=comp.get_model(),
                    )
                )

        persistedSources[src.record_id] = scl.io.ScarletSourceData(
            components=plainComponents,
            factorized_components=factorizedComponents,
            peak_id=src.peak_id,
        )

    return scl.io.ScarletBlendData(
        origin=(xy0.y, xy0.x),
        shape=blend.observation.bbox.shape[-2:],
        sources=persistedSources,
        psf_center=psfCenter,
        psf=blend.observation.psfs,
        bands=blend.observation.bands,
    )