Coverage for python/lsst/drp/tasks/assemble_cell_coadd.py: 28%

152 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-02 04:59 -0700

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names of this module, i.e. what ``from <module> import *`` exposes.
__all__ = (
    "AssembleCellCoaddTask",
    "AssembleCellCoaddConfig",
    "ConvertMultipleCellCoaddToExposureTask",
)

27 

28 

29import lsst.afw.image as afwImage 

30import lsst.afw.math as afwMath 

31import lsst.geom as geom 

32import numpy as np 

33from lsst.cell_coadds import ( 

34 CellIdentifiers, 

35 CoaddUnits, 

36 CommonComponents, 

37 GridContainer, 

38 MultipleCellCoadd, 

39 ObservationIdentifiers, 

40 OwnedImagePlanes, 

41 PatchIdentifiers, 

42 SingleCellCoadd, 

43 UniformGrid, 

44) 

45from lsst.meas.algorithms import AccumulatorMeanStack 

46from lsst.pex.config import ConfigField, ConfigurableField, Field, ListField, RangeField 

47from lsst.pipe.base import NoWorkFound, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

48from lsst.pipe.base.connectionTypes import Input, Output 

49from lsst.pipe.tasks.coaddBase import makeSkyInfo 

50from lsst.pipe.tasks.interpImage import InterpImageTask 

51from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask 

52from lsst.skymap import BaseSkyMap 

53 

54 

class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Butler connections for `AssembleCellCoaddTask`.

    The task consumes a set of direct warps (indexed by visit) for one patch
    plus the skymap, and produces one `MultipleCellCoadd` per
    tract/patch/band.
    """

    # deferLoad=True hands run() lightweight handles, so the warps can be
    # read one at a time rather than all being held in memory at once.
    inputWarps = Input(
        name="{inputWarpName}Coadd_directWarp",
        doc="Input warps",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        storageClass="ExposureF",
        multiple=True,
        deferLoad=True,
    )

    # runQuantum checks that this skymap uses a cell-based tract builder.
    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    multipleCellCoadd = Output(
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        doc="Output multiple cell coadd",
        dimensions=("tract", "patch", "band", "skymap"),
        storageClass="MultipleCellCoadd",
    )

82 

83 

class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration for `AssembleCellCoaddTask`."""

    # Optional post-processing of each cell coadd.
    do_interpolate_coadd = Field[bool](default=False, doc="Interpolate over pixels with NO_DATA mask set?")
    interpolate_coadd = ConfigurableField(
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
        target=InterpImageTask,
    )

    # Optional pre-processing of each warp.
    do_scale_zero_point = Field[bool](
        default=False,
        doc="Scale warps to a common zero point? This is not needed if they have absolute flux calibration.",
    )
    scale_zero_point = ConfigurableField(
        doc="Task to scale warps to a common zero point",
        target=ScaleZeroPointTask,
    )

    # These planes form the and-mask of the StatisticsControl used when
    # computing per-cell weights.
    bad_mask_planes = ListField[str](
        default=("BAD", "NO_DATA", "SAT"),
        doc="Mask planes that count towards the masked fraction within a cell.",
    )
    calc_error_from_input_variance = Field[bool](
        default=False,
        doc="Calculate coadd variance from input variance by stacking "
        "statistic. Passed to AccumulatorMeanStack.",
    )
    max_maskfrac = RangeField[float](
        default=0.99,
        doc="Maximum fraction of masked pixels in a cell. This is currently "
        "just a placeholder and is not used now",
        min=0.0,
        inclusiveMin=True,
        max=1.0,
        inclusiveMax=False,
    )

    # PSF handling: per-CCD PSF images are warped onto the coadd WCS and
    # stacked with the same weights as the science pixels.
    psf_warper = ConfigField(
        dtype=afwMath.Warper.ConfigClass,
        doc="Configuration for the warper that warps the PSFs. It must have the same configuration used to "
        "warp the images.",
    )
    psf_dimensions = Field[int](
        doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
        default=21,
        check=lambda size: size > 0 and size % 2 == 1,
    )

126 

127 

class AssembleCellCoaddTask(PipelineTask):
    """Assemble a cell-based coadded image from a set of warps.

    This task reads in the warps one at a time, and accumulates each in all
    the cells that it completely overlaps with. This is the optimal I/O
    pattern but this also implies that it is not possible to build one or
    only a few cells.

    Each cell coadd is guaranteed to have a well-defined PSF. This is done by
    1) excluding warps that only partially overlap a cell from that cell coadd;
    2) interpolating bad pixels in the warps rather than excluding them;
    3) by computing the coadd as a weighted mean of the warps without clipping;
    4) by computing the coadd PSF as the weighted mean of the PSF of the warps
       with the same weights.

    The cells are (and must be) defined in the skymap, and cannot be configured
    or redefined here. The cells are assumed to be small enough that the PSF is
    assumed to be spatially constant within a cell.

    Raises
    ------
    NoWorkFound
        Raised if no input warps are provided, or if none of the warps could
        be accumulated into any cell.
    RuntimeError
        Raised if the skymap is not cell-based.

    Notes
    -----
    This is not yet a part of the standard DRP pipeline. As such, the Task and
    especially its Config and Connections are experimental and subject to
    change any time without a formal RFC or standard deprecation procedures
    until it is included in the DRP pipeline.
    """

    ConfigClass = AssembleCellCoaddConfig
    _DefaultName = "assembleCellCoadd"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.config.do_interpolate_coadd:
            self.makeSubtask("interpolate_coadd")
        if self.config.do_scale_zero_point:
            self.makeSubtask("scale_zero_point")

        self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docstring inherited.
        inputData = butlerQC.get(inputRefs)

        if not inputData["inputWarps"]:
            raise NoWorkFound("No input warps provided for co-addition")
        self.log.info("Found %d input warps", len(inputData["inputWarps"]))

        # Construct skyInfo expected by run.
        # Do not remove skyMap from inputData in case _makeSupplementaryData
        # needs it.
        skyMap = inputData["skyMap"]

        if skyMap.config.tractBuilder.name != "cells":
            raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.")

        outputDataId = butlerQC.quantum.dataId

        inputData["skyInfo"] = makeSkyInfo(
            skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"]
        )

        # run() reads self.common, so it has to be set before calling run().
        self.common = CommonComponents(
            units=CoaddUnits.legacy,  # until the ScaleZeroPointTask can scale it to nJy.
            wcs=inputData["skyInfo"].patchInfo.wcs,
            band=outputDataId.get("band", None),
            identifiers=PatchIdentifiers.from_data_id(outputDataId),
        )

        returnStruct = self.run(**inputData)
        butlerQC.put(returnStruct, outputRefs)
        return returnStruct

    @staticmethod
    def _compute_weight(maskedImage, statsCtrl):
        """Compute a weight for a masked image.

        Parameters
        ----------
        maskedImage : `~lsst.afw.image.MaskedImage`
            The masked image to compute the weight.
        statsCtrl : `~lsst.afw.math.StatisticsControl`
            A control (config-like) object for StatisticsStack.

        Returns
        -------
        weight : `float`
            Inverse of the clipped mean variance of the masked image. NaN is
            returned when the clipped mean variance is NaN, zero or negative,
            so that callers testing with `numpy.isfinite` skip the image.
        """
        statObj = afwMath.makeStatistics(
            maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl
        )
        meanVar, _ = statObj.getResult(afwMath.MEANCLIP)
        meanVar = float(meanVar)
        # A zero mean variance would raise ZeroDivisionError and a negative
        # one would give a (finite) negative weight; neither is usable.
        # ``not meanVar > 0`` is also True for NaN, which yields NaN as before.
        if not meanVar > 0.0:
            return np.nan
        return 1.0 / meanVar

    @staticmethod
    def _construct_grid(skyInfo):
        """Construct a UniformGrid object from a SkyInfo struct.

        Parameters
        ----------
        skyInfo : `~lsst.pipe.base.Struct`
            A Struct object

        Returns
        -------
        grid : `~lsst.cell_coadds.UniformGrid`
            A UniformGrid object.
        """
        # grid has no notion about border or inner/outer boundaries.
        # So we have to clip the outermost border when constructing the grid.
        grid_bbox = skyInfo.patchInfo.outer_bbox.erodedBy(skyInfo.patchInfo.getCellBorder())
        grid = UniformGrid.from_bbox_cell_size(grid_bbox, skyInfo.patchInfo.getCellInnerDimensions())
        return grid

    def _construct_grid_container(self, skyInfo):
        """Construct a grid of AccumulatorMeanStack instances.

        Parameters
        ----------
        skyInfo : `~lsst.pipe.base.Struct`
            A Struct object

        Returns
        -------
        gc : `~lsst.cell_coadds.GridContainer`
            A GridContainer object containing one AccumulatorMeanStack per cell.
        """
        grid = self._construct_grid(skyInfo)

        # Initialize the grid container with AccumulatorMeanStacks
        gc = GridContainer[AccumulatorMeanStack](grid.shape)
        for cellInfo in skyInfo.patchInfo:
            stacker = AccumulatorMeanStack(
                # The shape is for the numpy arrays, hence transposed.
                shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width),
                bit_mask_value=0,
                calc_error_from_input_variance=self.config.calc_error_from_input_variance,
                compute_n_image=False,
            )
            gc[cellInfo.index] = stacker

        return gc

    def _construct_stats_control(self):
        """Build the StatisticsControl used to compute per-cell weights.

        Pixels with any of ``config.bad_mask_planes`` set are excluded, and
        the statistics are computed in NaN-safe mode.
        """
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes))
        statsCtrl.setNanSafe(True)
        return statsCtrl

    def run(self, inputWarps, skyInfo, **kwargs):
        """Coadd the input warps onto every cell of a patch.

        Parameters
        ----------
        inputWarps : iterable
            Deferred-load handles to the input warps; each one is read with
            ``get()`` inside the loop, one warp at a time.
        skyInfo : `~lsst.pipe.base.Struct`
            Sky information for the patch, as returned by `makeSkyInfo`.
        **kwargs
            Ignored. Accepted so that the whole ``inputData`` dictionary
            (which also contains ``skyMap``) can be forwarded by `runQuantum`.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Struct with a single ``multipleCellCoadd`` attribute.

        Raises
        ------
        NoWorkFound
            Raised if no warp could be accumulated into any cell (e.g. every
            warp/cell combination had EDGE pixels or a non-finite weight).

        Notes
        -----
        This method reads ``self.common``, which is set by `runQuantum`.
        """
        statsCtrl = self._construct_stats_control()
        # The EDGE bit mask does not depend on the warp; look it up once.
        edge = afwImage.Mask.getPlaneBitMask("EDGE")

        gc = self._construct_grid_container(skyInfo)
        psf_gc = GridContainer[AccumulatorMeanStack](gc.shape)
        psf_bbox_gc = GridContainer[geom.Box2I](gc.shape)

        # Make a container to hold the cell centers in sky coordinates now,
        # so we don't have to recompute them for each warp
        # (they share a common WCS). These are needed to find the various
        # warp + detector combinations that contributed to each cell, and later
        # get the corresponding PSFs as well.
        cell_centers_sky = GridContainer[geom.SpherePoint](gc.shape)
        # Make a container to hold the observation identifiers for each cell.
        observation_identifiers_gc = GridContainer[list](gc.shape)
        # Populate them.
        for cellInfo in skyInfo.patchInfo:
            # Make a list to hold the observation identifiers for each cell.
            observation_identifiers_gc[cellInfo.index] = []
            cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cellInfo.inner_bbox.getCenter())
            psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox(
                geom.Point2D(cellInfo.inner_bbox.getCenter()),
                geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions),
            )
            psf_gc[cellInfo.index] = AccumulatorMeanStack(
                # The shape is for the numpy arrays, hence transposed.
                shape=(self.config.psf_dimensions, self.config.psf_dimensions),
                bit_mask_value=0,
                calc_error_from_input_variance=self.config.calc_error_from_input_variance,
                compute_n_image=False,
            )

        # Read in one warp at a time, and accumulate it in all the cells that
        # it completely overlaps.
        for warpRef in inputWarps:
            warp = warpRef.get()

            # Pre-process the warp before coadding.
            # Each Warp that goes into a coadd will typically have an
            # independent photometric zero-point. Therefore, we must scale each
            # Warp to set it to a common photometric zeropoint.
            if self.config.do_scale_zero_point:
                self.scale_zero_point.run(exposure=warp, dataRef=warpRef)

            # Coadd the warp onto the cells it completely overlaps.
            for cellInfo in skyInfo.patchInfo:
                bbox = cellInfo.outer_bbox
                mi = warp[bbox].getMaskedImage()

                # Any EDGE pixel means the warp only partially covers this
                # cell, so it is excluded from this cell's coadd.
                if (mi.getMask().array & edge).any():
                    self.log.debug(
                        "Skipping %s in cell %s because it has an EDGE", warpRef.dataId, cellInfo.index
                    )
                    continue

                weight = self._compute_weight(mi, statsCtrl)
                if not np.isfinite(weight):
                    # Log at the debug level, because this can be quite common.
                    self.log.debug(
                        "Non-finite weight for %s in cell %s: skipping", warpRef.dataId, cellInfo.index
                    )
                    continue

                ccd_table = (
                    warp.getInfo().getCoaddInputs().ccds.subsetContaining(cell_centers_sky[cellInfo.index])
                )
                assert len(ccd_table) > 0, "No CCD from a warp found within a cell."
                assert len(ccd_table) == 1, "More than one CCD from a warp found within a cell."
                ccd_row = ccd_table[0]

                observation_identifier = ObservationIdentifiers.from_data_id(
                    warpRef.dataId,
                    backup_detector=ccd_row["ccd"],
                )
                observation_identifiers_gc[cellInfo.index].append(observation_identifier)

                gc[cellInfo.index].add_masked_image(mi, weight=weight)

                calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index])
                undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point)

                assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox(
                    calexp_point,
                    undistorted_psf_im.getDimensions(),
                ), "PSF image does not share the coordinates of the 'calexp'"

                # Convert the PSF image from Image to MaskedImage.
                undistorted_psf_maskedImage = afwImage.MaskedImageD(image=undistorted_psf_im)
                # TODO: In DM-43585, use the variance plane value from noise.
                undistorted_psf_maskedImage.variance += 1.0  # Set variance to 1

                warped_psf_maskedImage = self.psf_warper.warpImage(
                    destWcs=skyInfo.wcs,
                    srcImage=undistorted_psf_maskedImage,
                    srcWcs=ccd_row.getWcs(),
                    destBBox=psf_bbox_gc[cellInfo.index],
                )

                # There may be NaNs in the warped PSF image. Give those pixels
                # zero flux and unit variance. The mask is computed once since
                # both assignments below use the same pixels.
                nan_pixels = np.isnan(warped_psf_maskedImage.image.array)
                warped_psf_maskedImage.variance.array[nan_pixels] = 1.0
                warped_psf_maskedImage.image.array[nan_pixels] = 0.0

                psf_gc[cellInfo.index].add_masked_image(warped_psf_maskedImage, weight=weight)

            del warp

        cells: list[SingleCellCoadd] = []
        for cellInfo in skyInfo.patchInfo:
            if not observation_identifiers_gc[cellInfo.index]:
                self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index)
                continue

            cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox)
            psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index])
            gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image)
            psf_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image)

            # Post-process the coadd before converting to new data structures.
            if self.config.do_interpolate_coadd:
                self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA")
                # The variance must be positive; work around for DM-3201.
                varArray = cell_masked_image.variance.array
                with np.errstate(invalid="ignore"):
                    varArray[:] = np.where(varArray > 0, varArray, np.inf)

            image_planes = OwnedImagePlanes.from_masked_image(cell_masked_image)
            identifiers = CellIdentifiers(
                cell=cellInfo.index,
                skymap=self.common.identifiers.skymap,
                tract=self.common.identifiers.tract,
                patch=self.common.identifiers.patch,
                band=self.common.identifiers.band,
            )

            singleCellCoadd = SingleCellCoadd(
                outer=image_planes,
                psf=psf_masked_image.image,
                inner_bbox=cellInfo.inner_bbox,
                inputs=observation_identifiers_gc[cellInfo.index],
                common=self.common,
                identifiers=identifiers,
            )
            # TODO: Attach transmission curve when they become available.
            cells.append(singleCellCoadd)

        # Without this check, cells[0] below would raise an IndexError.
        if not cells:
            raise NoWorkFound("No input warps could be accumulated into any cell of this patch.")

        grid = self._construct_grid(skyInfo)
        multipleCellCoadd = MultipleCellCoadd(
            cells,
            grid=grid,
            # Taken from the last cell iterated above; assumes every cell in
            # the patch has the same outer dimensions -- TODO confirm.
            outer_cell_size=cellInfo.outer_bbox.getDimensions(),
            inner_bbox=None,
            common=self.common,
            psf_image_size=cells[0].psf_image.getDimensions(),
        )

        return Struct(
            multipleCellCoadd=multipleCellCoadd,
        )

446 

447 

class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Butler connections for `ConvertMultipleCellCoaddToExposureTask`.

    The input is the `MultipleCellCoadd` written by `AssembleCellCoaddTask`
    (same default name template), and the output is its stitched `ExposureF`
    counterpart with a ``_stitched`` suffix.
    """

    # The previous doc string described this *input* as an "Output coadded
    # exposure"; it is the cell-based coadd read by the conversion task.
    cellCoaddExposure = Input(
        doc="Input cell-based coadd in MultipleCellCoadd format, to be stitched into a single exposure",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    stitchedCoaddExposure = Output(
        doc="Output coadded exposure, produced by stitching together the inner regions of the cells "
        "of the input cell-based coadd",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

466 

467 

class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for `ConvertMultipleCellCoaddToExposureTask`.

    It defines no fields of its own; it exists only to attach the task's
    connections.
    """

474 

475 

class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """Convert a cell-based coadd from `MultipleCellCoadd` to `ExposureF`.

    This is an afterburner task. Its `run` method stitches the cells of the
    input coadd into one contiguous image and returns it as an `Exposure`.
    The conversion is lossy: only the pixels in the inner bounding box of
    each cell are kept, and the values in the buffer region are discarded.

    Notes
    -----
    This task has no configurable parameters.
    """

    _DefaultName = "convertMultipleCellCoaddToExposure"
    ConfigClass = ConvertMultipleCellCoaddToExposureConfig

    def run(self, cellCoaddExposure):
        """Stitch a cell-based coadd into a single exposure.

        Parameters
        ----------
        cellCoaddExposure : `~lsst.cell_coadds.MultipleCellCoadd`
            The cell-based coadd to convert.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Struct whose ``stitchedCoaddExposure`` attribute holds the result
            of ``cellCoaddExposure.stitch().asExposure()``.
        """
        stitched = cellCoaddExposure.stitch()
        exposure = stitched.asExposure()
        return Struct(stitchedCoaddExposure=exposure)