Coverage for python/lsst/drp/tasks/assemble_cell_coadd.py: 29%

148 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-04 04:29 -0700

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: the two tasks and their configuration classes.
__all__ = (
    "AssembleCellCoaddTask",
    "AssembleCellCoaddConfig",
    "ConvertMultipleCellCoaddToExposureTask",
    "ConvertMultipleCellCoaddToExposureConfig",
)

27 

28 

29import lsst.afw.image as afwImage 

30import lsst.afw.math as afwMath 

31import numpy as np 

32from lsst.cell_coadds import ( 

33 CellIdentifiers, 

34 CoaddUnits, 

35 CommonComponents, 

36 GridContainer, 

37 MultipleCellCoadd, 

38 ObservationIdentifiers, 

39 OwnedImagePlanes, 

40 PatchIdentifiers, 

41 SingleCellCoadd, 

42 UniformGrid, 

43) 

44from lsst.meas.algorithms import AccumulatorMeanStack, CoaddPsf, CoaddPsfConfig 

45from lsst.pex.config import ConfigField, ConfigurableField, Field, ListField, RangeField 

46from lsst.pipe.base import NoWorkFound, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

47from lsst.pipe.base.connectionTypes import Input, Output 

48from lsst.pipe.tasks.coaddBase import makeSkyInfo 

49from lsst.pipe.tasks.coaddInputRecorder import CoaddInputRecorderTask 

50from lsst.pipe.tasks.interpImage import InterpImageTask 

51from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask 

52from lsst.skymap import BaseSkyMap 

53 

54 

class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Butler connections for `AssembleCellCoaddTask`."""

    # Per-visit warps of this patch; deferLoad means the task receives
    # handles and reads each warp only when it is needed.
    inputWarps = Input(
        name="{inputWarpName}Coadd_directWarp",
        doc="Input warps",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        storageClass="ExposureF",
        multiple=True,
        deferLoad=True,
    )

    # The skymap provides the patch/cell geometry and the WCS.
    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    # One cell-based coadd per (tract, patch, band).
    multipleCellCoadd = Output(
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        doc="Output multiple cell coadd",
        dimensions=("tract", "patch", "band", "skymap"),
        storageClass="MultipleCellCoadd",
    )

82 

83 

class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration for `AssembleCellCoaddTask`."""

    do_interpolate_coadd = Field[bool](doc="Interpolate over pixels with NO_DATA mask set?", default=False)
    interpolate_coadd = ConfigurableField(
        target=InterpImageTask,
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
    )
    do_scale_zero_point = Field[bool](
        doc="Scale warps to a common zero point? This is not needed if they have absolute flux calibration.",
        default=False,
    )
    scale_zero_point = ConfigurableField(
        target=ScaleZeroPointTask,
        doc="Task to scale warps to a common zero point",
    )
    # These planes become the and-mask of the StatisticsControl used for the
    # per-warp weight and the bit mask passed to every AccumulatorMeanStack.
    bad_mask_planes = ListField[str](
        doc="Mask planes whose pixels are excluded when computing the per-warp "
        "weight and when stacking pixels within a cell.",
        default=("BAD", "NO_DATA", "SAT"),
    )
    calc_error_from_input_variance = Field[bool](
        doc="Calculate coadd variance from input variance by stacking "
        "statistic. Passed to AccumulatorMeanStack.",
        default=False,
    )
    max_maskfrac = RangeField[float](
        doc="Maximum fraction of masked pixels in a cell. This is currently "
        "just a placeholder and is not used now",
        default=0.99,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=False,
    )
    # The following config options are specific to the CoaddPsf.
    coadd_psf = ConfigField(
        doc="Configuration for CoaddPsf",
        dtype=CoaddPsfConfig,
    )
    input_recorder = ConfigurableField(
        doc="Subtask that helps fill CoaddInputs catalogs added to the final Exposure",
        target=CoaddInputRecorderTask,
    )

125 

126 

class AssembleCellCoaddTask(PipelineTask):
    """Assemble a cell-based coadded image from a set of warps.

    This task reads in the warps one at a time, and accumulates each of them
    in all the cells that it completely overlaps. This is the optimal I/O
    pattern, but it also implies that it is not possible to build only one or
    a few cells.

    Each cell coadd is guaranteed to have a well-defined PSF. This is done by
    1) excluding warps that only partially overlap a cell from that cell coadd;
    2) interpolating bad pixels in the warps rather than excluding them;
    3) computing the coadd as a weighted mean of the warps without clipping;
    4) computing the coadd PSF as the weighted mean of the PSF of the warps
    with the same weights.

    The cells are (and must be) defined in the skymap, and cannot be configured
    or redefined here. The cells are assumed to be small enough that the PSF is
    spatially constant within a cell.

    Raises
    ------
    NoWorkFound
        Raised if no input warps are provided, or if no cell received any
        usable input warp.
    RuntimeError
        Raised if the skymap is not cell-based, or if a warp does not have
        exactly one CCD covering the center of a cell it contributes to.

    Notes
    -----
    This is not yet a part of the standard DRP pipeline. As such, the Task and
    especially its Config and Connections are experimental and subject to
    change any time without a formal RFC or standard deprecation procedures
    until it is included in the DRP pipeline.
    """

    ConfigClass = AssembleCellCoaddConfig
    _DefaultName = "assembleCellCoadd"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.makeSubtask("input_recorder")
        # Optional subtasks are only constructed when enabled.
        if self.config.do_interpolate_coadd:
            self.makeSubtask("interpolate_coadd")
        if self.config.do_scale_zero_point:
            self.makeSubtask("scale_zero_point")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docstring inherited.
        inputData = butlerQC.get(inputRefs)

        if not inputData["inputWarps"]:
            raise NoWorkFound("No input warps provided for co-addition")
        self.log.info("Found %d input warps", len(inputData["inputWarps"]))

        # Construct skyInfo expected by run.
        # Do not remove skyMap from inputData in case _makeSupplementaryData
        # needs it.
        skyMap = inputData["skyMap"]

        if skyMap.config.tractBuilder.name != "cells":
            raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.")

        outputDataId = butlerQC.quantum.dataId

        inputData["skyInfo"] = makeSkyInfo(
            skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"]
        )

        # ``run`` reads ``self.common``, so it must be set before calling it.
        self.common = CommonComponents(
            units=CoaddUnits.legacy,  # until the ScaleZeroPointTask can scale it to nJy.
            wcs=inputData["skyInfo"].patchInfo.wcs,
            band=outputDataId.get("band", None),
            identifiers=PatchIdentifiers.from_data_id(outputDataId),
        )

        returnStruct = self.run(**inputData)
        butlerQC.put(returnStruct, outputRefs)
        return returnStruct

    @staticmethod
    def _compute_weight(maskedImage, statsCtrl):
        """Compute a weight for a masked image.

        Parameters
        ----------
        maskedImage : `~lsst.afw.image.MaskedImage`
            The masked image to compute the weight.
        statsCtrl : `~lsst.afw.math.StatisticsControl`
            A control (config-like) object for StatisticsStack.

        Returns
        -------
        weight : `float`
            Inverse of the clipped mean variance of the masked image. This is
            non-finite (``inf``) when the clipped mean variance is exactly
            zero, and NaN when the statistic itself is NaN; callers are
            expected to skip non-finite weights.
        """
        statObj = afwMath.makeStatistics(
            maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl
        )
        meanVar, _ = statObj.getResult(afwMath.MEANCLIP)
        meanVar = float(meanVar)
        if meanVar == 0.0:
            # Plain float division by zero raises ZeroDivisionError; return a
            # non-finite weight instead so the caller skips this input.
            return np.inf
        return 1.0 / meanVar

    @staticmethod
    def _construct_grid(skyInfo):
        """Construct a UniformGrid object from a SkyInfo struct.

        Parameters
        ----------
        skyInfo : `~lsst.pipe.base.Struct`
            A Struct object

        Returns
        -------
        grid : `~lsst.cell_coadds.UniformGrid`
            A UniformGrid object.
        """
        # grid has no notion about border or inner/outer boundaries.
        # So we have to clip the outermost border when constructing the grid.
        grid_bbox = skyInfo.patchInfo.outer_bbox.erodedBy(skyInfo.patchInfo.getCellBorder())
        grid = UniformGrid.from_bbox_cell_size(grid_bbox, skyInfo.patchInfo.getCellInnerDimensions())
        return grid

    def _construct_grid_container(self, skyInfo, statsCtrl):
        """Construct a grid of AccumulatorMeanStack instances.

        Parameters
        ----------
        skyInfo : `~lsst.pipe.base.Struct`
            A Struct object
        statsCtrl : `~lsst.afw.math.StatisticsControl`
            A control (config-like) object for StatisticsStack. Currently
            unused; kept for interface compatibility.

        Returns
        -------
        gc : `~lsst.cell_coadds.GridContainer`
            A GridContainer object containing one AccumulatorMeanStack per cell.
        """
        grid = self._construct_grid(skyInfo)

        # The bad-pixel bit mask is the same for every cell; compute it once.
        bit_mask_value = afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes)

        # Initialize the grid container with AccumulatorMeanStacks
        gc = GridContainer[AccumulatorMeanStack](grid.shape)
        for cellInfo in skyInfo.patchInfo:
            stacker = AccumulatorMeanStack(
                # The shape is for the numpy arrays, hence transposed.
                shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width),
                bit_mask_value=bit_mask_value,
                calc_error_from_input_variance=self.config.calc_error_from_input_variance,
                compute_n_image=False,
            )
            gc[cellInfo.index] = stacker

        return gc

    def _construct_stats_control(self):
        """Build the StatisticsControl used to compute per-warp weights.

        Returns
        -------
        statsCtrl : `~lsst.afw.math.StatisticsControl`
            Control that ignores pixels flagged with any of
            ``config.bad_mask_planes`` and is NaN-safe.
        """
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes))
        statsCtrl.setNanSafe(True)
        return statsCtrl

    @staticmethod
    def _find_ccd_row(warp, cell_center_sky):
        """Return the CCD record of ``warp`` that covers a cell center.

        Parameters
        ----------
        warp : `~lsst.afw.image.Exposure`
            Warp whose CoaddInputs ``ccds`` catalog is searched.
        cell_center_sky : `~lsst.geom.SpherePoint`
            Sky position of the center of the cell's inner bounding box.

        Returns
        -------
        ccd_row : record
            The single record returned by ``ccds.subsetContaining``.

        Raises
        ------
        RuntimeError
            Raised if zero or more than one CCD contains the cell center.
        """
        ccd_table = warp.getInfo().getCoaddInputs().ccds.subsetContaining(cell_center_sky)
        # Explicit raises rather than ``assert`` so the checks survive
        # ``python -O``.
        if len(ccd_table) == 0:
            raise RuntimeError("No CCD from a warp found within a cell.")
        if len(ccd_table) > 1:
            raise RuntimeError("More than one CCD from a warp found within a cell.")
        return ccd_table[0]

    def run(self, inputWarps, skyInfo, **kwargs):
        """Assemble the cell-based coadd for one patch.

        Parameters
        ----------
        inputWarps : `list`
            Deferred-load handles to the input warps; each must provide
            ``get()`` and ``dataId``.
        skyInfo : `~lsst.pipe.base.Struct`
            Struct as returned by `makeSkyInfo`, providing ``wcs`` and a
            cell-based ``patchInfo``.
        **kwargs
            Additional keyword arguments are accepted and ignored.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Struct with a single ``multipleCellCoadd`` attribute.

        Raises
        ------
        NoWorkFound
            Raised if no cell received any usable input warp.
        RuntimeError
            Raised if a contributing warp does not have exactly one CCD
            covering a cell center.

        Notes
        -----
        ``self.common`` must be set before calling this method; `runQuantum`
        does so.
        """
        statsCtrl = self._construct_stats_control()

        gc = self._construct_grid_container(skyInfo, statsCtrl)
        coadd_inputs_gc = GridContainer(gc.shape)

        # Make a container to hold the cell centers in sky coordinates now,
        # so we don't have to recompute them for each warp
        # (they share a common WCS). These are needed to find the various
        # warp + detector combinations that contributed to each cell, and later
        # get the corresponding PSFs as well.
        cell_centers_sky = GridContainer(gc.shape)
        # Make a container to hold the observation identifiers for each cell.
        observation_identifiers_gc = GridContainer(gc.shape)
        # Populate them.
        for cellInfo in skyInfo.patchInfo:
            coadd_inputs = self.input_recorder.makeCoaddInputs()
            # Reserve the absolute maximum of how many ccds, visits
            # we could potentially have.
            coadd_inputs.ccds.reserve(len(inputWarps))
            coadd_inputs.visits.reserve(len(inputWarps))
            coadd_inputs_gc[cellInfo.index] = coadd_inputs
            observation_identifiers_gc[cellInfo.index] = []
            cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cellInfo.inner_bbox.getCenter())

        # The EDGE bit does not depend on the warp; look it up once.
        edge = afwImage.Mask.getPlaneBitMask("EDGE")

        # Read in one warp at a time, and accumulate it in all the cells that
        # it completely overlaps.
        for warpRef in inputWarps:
            warp = warpRef.get()

            # Pre-process the warp before coadding.
            # Each Warp that goes into a coadd will typically have an
            # independent photometric zero-point. Therefore, we must scale each
            # Warp to set it to a common photometric zeropoint.
            if self.config.do_scale_zero_point:
                self.scale_zero_point.run(exposure=warp, dataRef=warpRef)

            # Coadd the warp onto the cells it completely overlaps.
            for cellInfo in skyInfo.patchInfo:
                cell_warp = warp[cellInfo.outer_bbox]
                mi = cell_warp.getMaskedImage()

                # Any EDGE pixel means the warp only partially covers the
                # cell, which would break the well-defined-PSF guarantee.
                if (mi.getMask().array & edge).any():
                    self.log.debug(
                        "Skipping %s in cell %s because it has an EDGE", warpRef.dataId, cellInfo.index
                    )
                    continue

                weight = self._compute_weight(mi, statsCtrl)
                if not np.isfinite(weight):
                    # Log at the debug level, because this can be quite common.
                    self.log.debug(
                        "Non-finite weight for %s in cell %s: skipping", warpRef.dataId, cellInfo.index
                    )
                    continue

                gc[cellInfo.index].add_masked_image(mi, weight=weight)
                self.input_recorder.addVisitToCoadd(coadd_inputs_gc[cellInfo.index], cell_warp, weight)

                ccd_row = self._find_ccd_row(warp, cell_centers_sky[cellInfo.index])
                observation_identifier = ObservationIdentifiers.from_data_id(
                    warpRef.dataId,
                    backup_detector=ccd_row["ccd"],
                )
                observation_identifiers_gc[cellInfo.index].append(observation_identifier)

            # Release the warp before reading the next one.
            del warp

        # The CoaddPsf control is the same for every cell.
        coadd_psf_control = self.config.coadd_psf.makeControl()

        cells: list[SingleCellCoadd] = []
        outer_cell_size = None
        for cellInfo in skyInfo.patchInfo:
            # Record the outer cell size explicitly (last cell wins) instead of
            # relying on the loop variable leaking out of this loop.
            outer_cell_size = cellInfo.outer_bbox.getDimensions()

            if len(observation_identifiers_gc[cellInfo.index]) == 0:
                self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index)
                continue

            stacker = gc[cellInfo.index]
            cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox)
            stacker.fill_stacked_masked_image(cell_masked_image)

            # Post-process the coadd before converting to new data structures.
            if self.config.do_interpolate_coadd:
                self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA")
                # The variance must be positive; work around for DM-3201.
                varArray = cell_masked_image.variance.array
                with np.errstate(invalid="ignore"):
                    varArray[:] = np.where(varArray > 0, varArray, np.inf)

            # Finalize the PSF on the cell coadds.
            coadd_inputs = coadd_inputs_gc[cellInfo.index]
            coadd_inputs.ccds.sort()
            coadd_inputs.visits.sort()
            cell_coadd_psf = CoaddPsf(coadd_inputs.ccds, skyInfo.wcs, coadd_psf_control)

            image_planes = OwnedImagePlanes.from_masked_image(cell_masked_image)
            identifiers = CellIdentifiers(
                cell=cellInfo.index,
                skymap=self.common.identifiers.skymap,
                tract=self.common.identifiers.tract,
                patch=self.common.identifiers.patch,
                band=self.common.identifiers.band,
            )

            singleCellCoadd = SingleCellCoadd(
                outer=image_planes,
                psf=cell_coadd_psf.computeKernelImage(cell_coadd_psf.getAveragePosition()),
                inner_bbox=cellInfo.inner_bbox,
                inputs=frozenset(observation_identifiers_gc[cellInfo.index]),
                common=self.common,
                identifiers=identifiers,
            )
            # TODO: Attach transmission curve when they become available.
            cells.append(singleCellCoadd)

        if not cells:
            # Without this, cells[0] below would raise an opaque IndexError.
            raise NoWorkFound("No cell received any usable input warp")

        grid = self._construct_grid(skyInfo)
        multipleCellCoadd = MultipleCellCoadd(
            cells,
            grid=grid,
            outer_cell_size=outer_cell_size,
            inner_bbox=None,
            common=self.common,
            psf_image_size=cells[0].psf_image.getDimensions(),
        )

        return Struct(
            multipleCellCoadd=multipleCellCoadd,
        )

426 

427 

class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Butler connections for `ConvertMultipleCellCoaddToExposureTask`."""

    # The cell-based coadd to convert; its default dataset type name matches
    # the output of AssembleCellCoaddTask with default templates.
    cellCoaddExposure = Input(
        doc="Input cell-based coadd, as produced by AssembleCellCoaddTask",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    stitchedCoaddExposure = Output(
        doc="Output coadded exposure, produced by stitching together the cells of the cell-based coadd",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

446 

447 

class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for `ConvertMultipleCellCoaddToExposureTask`.

    The task has no tunable parameters; this class only binds the task to
    its connections.
    """

454 

455 

class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """Afterburner task that converts a cell-based coadd from
    `MultipleCellCoadd` format to `ExposureF` format.

    ``run`` stitches the cells into one contiguous image and returns it as an
    `Exposure`. The conversion is lossy: only the pixels in the inner bounding
    box of each cell are kept, and the values in the buffer region are
    discarded.

    Notes
    -----
    This task has no configurable parameters.
    """

    _DefaultName = "convertMultipleCellCoaddToExposure"
    ConfigClass = ConvertMultipleCellCoaddToExposureConfig

    def run(self, cellCoaddExposure):
        """Stitch a cell-based coadd into a single exposure.

        Parameters
        ----------
        cellCoaddExposure : `~lsst.cell_coadds.MultipleCellCoadd`
            The cell-based coadd to convert.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Struct whose ``stitchedCoaddExposure`` attribute holds the
            stitched coadd as an exposure.
        """
        stitched = cellCoaddExposure.stitch()
        exposure = stitched.asExposure()
        return Struct(stitchedCoaddExposure=exposure)