Coverage for python/lsst/drp/tasks/assemble_cell_coadd.py: 33%

135 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-14 12:27 -0700

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ( 

23 "AssembleCellCoaddTask", 

24 "AssembleCellCoaddConfig", 

25 "ConvertMultipleCellCoaddToExposureTask", 

26) 

27 

28 

29import lsst.afw.image as afwImage 

30import lsst.afw.math as afwMath 

31import numpy as np 

32from lsst.cell_coadds import ( 

33 CellIdentifiers, 

34 CoaddUnits, 

35 CommonComponents, 

36 GridContainer, 

37 MultipleCellCoadd, 

38 OwnedImagePlanes, 

39 PatchIdentifiers, 

40 SingleCellCoadd, 

41 UniformGrid, 

42) 

43from lsst.meas.algorithms import AccumulatorMeanStack, CoaddPsf, CoaddPsfConfig 

44from lsst.pex.config import ConfigField, ConfigurableField, Field, ListField, RangeField 

45from lsst.pipe.base import NoWorkFound, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

46from lsst.pipe.base.connectionTypes import Input, Output 

47from lsst.pipe.tasks.coaddBase import makeSkyInfo 

48from lsst.pipe.tasks.coaddInputRecorder import CoaddInputRecorderTask 

49from lsst.pipe.tasks.interpImage import InterpImageTask 

50from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask 

51from lsst.skymap import BaseSkyMap 

52 

53 

54class AssembleCellCoaddConnections( 

55 PipelineTaskConnections, 

56 dimensions=("tract", "patch", "band", "skymap"), 

57 defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"}, 

58): 

59 inputWarps = Input( 

60 doc="Input warps", 

61 name="{inputWarpName}Coadd_directWarp", 

62 storageClass="ExposureF", 

63 dimensions=("tract", "patch", "skymap", "visit", "instrument"), 

64 deferLoad=True, 

65 multiple=True, 

66 ) 

67 

68 skyMap = Input( 

69 doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.", 

70 name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, 

71 storageClass="SkyMap", 

72 dimensions=("skymap",), 

73 ) 

74 

75 multipleCellCoadd = Output( 

76 doc="Output multiple cell coadd", 

77 name="{inputWarpName}Coadd{outputCoaddSuffix}", 

78 storageClass="MultipleCellCoadd", 

79 dimensions=("tract", "patch", "band", "skymap"), 

80 ) 

81 

82 

83class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections): 

84 do_interpolate_coadd = Field[bool](doc="Interpolate over pixels with NO_DATA mask set?", default=False) 

85 interpolate_coadd = ConfigurableField( 

86 target=InterpImageTask, 

87 doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds", 

88 ) 

89 scale_zero_point = ConfigurableField( 

90 target=ScaleZeroPointTask, 

91 doc="Task to scale warps to a common zero point", 

92 ) 

93 bad_mask_planes = ListField[str]( 

94 doc="Mask planes that count towards the masked fraction within a cell.", 

95 default=("BAD", "NO_DATA", "SAT"), 

96 ) 

97 calc_error_from_input_variance = Field[bool]( 

98 doc="Calculate coadd variance from input variance by stacking " 

99 "statistic. Passed to AccumulatorMeanStack.", 

100 default=False, 

101 ) 

102 max_maskfrac = RangeField[float]( 

103 doc="Maximum fraction of masked pixels in a cell. This is currently " 

104 "just a placeholder and is not used now", 

105 default=0.99, 

106 min=0.0, 

107 max=1.0, 

108 inclusiveMin=True, 

109 inclusiveMax=False, 

110 ) 

111 # The following config options are specific to the CoaddPsf. 

112 coadd_psf = ConfigField( 

113 doc="Configuration for CoaddPsf", 

114 dtype=CoaddPsfConfig, 

115 ) 

116 input_recorder = ConfigurableField( 

117 doc="Subtask that helps fill CoaddInputs catalogs added to the final Exposure", 

118 target=CoaddInputRecorderTask, 

119 ) 

120 

121 

122class AssembleCellCoaddTask(PipelineTask): 

123 """Assemble a cell-based coadded image from a set of warps. 

124 

125 This task reads in the warp one at a time, and accumulates it in all the 

126 cells that it completely overlaps with. This is the optimal I/O pattern but 

127 this also implies that it is not possible to build one or only a few cells. 

128 

129 Each cell coadds is guaranteed to have a well-defined PSF. This is done by 

130 1) excluding warps that only partially overlap a cell from that cell coadd; 

131 2) interpolating bad pixels in the warps rather than excluding them; 

132 3) by computing the coadd as a weighted mean of the warps without clipping; 

133 4) by computing the coadd PSF as the weighted mean of the PSF of the warps 

134 with the same weights. 

135 

136 The cells are (and must be) defined in the skymap, and cannot be configured 

137 or redefined here. The cells are assumed to be small enough that the PSF is 

138 assumed to be spatially constant within a cell. 

139 

140 Raises 

141 ------ 

142 NoWorkFound 

143 Raised if no input warps are provided. 

144 RuntimeError 

145 Raised if the skymap is not cell-based. 

146 

147 Notes 

148 ----- 

149 This is not yet a part of the standard DRP pipeline. As such, the Task and 

150 especially its Config and Connections are experimental and subject to 

151 change any time without a formal RFC or standard deprecation procedures 

152 until it is included in the DRP pipeline. 

153 """ 

154 

155 ConfigClass = AssembleCellCoaddConfig 

156 _DefaultName = "assembleCellCoadd" 

157 

158 def __init__(self, *args, **kwargs): 

159 super().__init__(*args, **kwargs) 

160 self.makeSubtask("input_recorder") 

161 self.makeSubtask("interpolate_coadd") 

162 self.makeSubtask("scale_zero_point") 

163 

164 def runQuantum(self, butlerQC, inputRefs, outputRefs): 

165 # Docstring inherited. 

166 inputData = butlerQC.get(inputRefs) 

167 

168 if not inputData["inputWarps"]: 

169 raise NoWorkFound("No input warps provided for co-addition") 

170 self.log.info("Found %d input warps", len(inputData["inputWarps"])) 

171 

172 # Construct skyInfo expected by run 

173 # Do not remove skyMap from inputData in case _makeSupplementaryData 

174 # needs it 

175 skyMap = inputData["skyMap"] 

176 

177 if not skyMap.config.tractBuilder.name == "cells": 

178 raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.") 

179 

180 outputDataId = butlerQC.quantum.dataId 

181 

182 inputData["skyInfo"] = makeSkyInfo( 

183 skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"] 

184 ) 

185 

186 self.common = CommonComponents( 

187 units=CoaddUnits.legacy, # until the ScaleZeroPointTask can scale it to nJy. 

188 wcs=inputData["skyInfo"].patchInfo.wcs, 

189 band=outputDataId.get("band", None), 

190 identifiers=PatchIdentifiers.from_data_id(outputDataId), 

191 ) 

192 

193 returnStruct = self.run(**inputData) 

194 butlerQC.put(returnStruct, outputRefs) 

195 return returnStruct 

196 

197 @staticmethod 

198 def _compute_weight(maskedImage, statsCtrl): 

199 """Compute a weight for a masked image. 

200 

201 Parameters 

202 ---------- 

203 maskedImage : `~lsst.afw.image.MaskedImage` 

204 The masked image to compute the weight. 

205 statsCtrl : `~lsst.afw.math.StatisticsControl` 

206 A control (config-like) object for StatisticsStack. 

207 

208 Returns 

209 ------- 

210 weight : `float` 

211 Inverse of the clipped mean variance of the masked image. 

212 """ 

213 statObj = afwMath.makeStatistics( 

214 maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl 

215 ) 

216 meanVar, _ = statObj.getResult(afwMath.MEANCLIP) 

217 weight = 1.0 / float(meanVar) 

218 return weight 

219 

220 @staticmethod 

221 def _construct_grid(skyInfo): 

222 """Construct a UniformGrid object from a SkyInfo struct. 

223 

224 Parameters 

225 ---------- 

226 skyInfo : `~lsst.pipe.base.Struct` 

227 A Struct object 

228 

229 Returns 

230 ------- 

231 grid : `~lsst.cell_coadds.UniformGrid` 

232 A UniformGrid object. 

233 """ 

234 # grid has no notion about border or inner/outer boundaries. 

235 # So we have to clip the outermost border when constructing the grid. 

236 grid_bbox = skyInfo.patchInfo.outer_bbox.erodedBy(skyInfo.patchInfo.getCellBorder()) 

237 grid = UniformGrid.from_bbox_cell_size(grid_bbox, skyInfo.patchInfo.getCellInnerDimensions()) 

238 return grid 

239 

240 def _construct_grid_container(self, skyInfo, statsCtrl): 

241 """Construct a grid of AccumulatorMeanStack instances. 

242 

243 Parameters 

244 ---------- 

245 skyInfo : `~lsst.pipe.base.Struct` 

246 A Struct object 

247 statsCtrl : `~lsst.afw.math.StatisticsControl` 

248 A control (config-like) object for StatisticsStack. 

249 

250 Returns 

251 ------- 

252 gc : `~lsst.cell_coadds.GridContainer` 

253 A GridContainer object container one AccumulatorMeanStack per cell. 

254 """ 

255 grid = self._construct_grid(skyInfo) 

256 

257 # Initialize the grid container with AccumulatorMeanStacks 

258 gc = GridContainer[AccumulatorMeanStack](grid.shape) 

259 for cellInfo in skyInfo.patchInfo: 

260 stacker = AccumulatorMeanStack( 

261 # The shape is for the numpy arrays, hence transposed. 

262 shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width), 

263 bit_mask_value=afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes), 

264 calc_error_from_input_variance=self.config.calc_error_from_input_variance, 

265 compute_n_image=False, 

266 ) 

267 gc[cellInfo.index] = stacker 

268 

269 return gc 

270 

271 def _construct_stats_control(self): 

272 statsCtrl = afwMath.StatisticsControl() 

273 statsCtrl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes)) 

274 statsCtrl.setNanSafe(True) 

275 return statsCtrl 

276 

277 def run(self, inputWarps, skyInfo, **kwargs): 

278 statsCtrl = self._construct_stats_control() 

279 

280 gc = self._construct_grid_container(skyInfo, statsCtrl) 

281 coadd_inputs_gc = GridContainer(gc.shape) 

282 for cellInfo in skyInfo.patchInfo: 

283 coadd_inputs = self.input_recorder.makeCoaddInputs() 

284 # Reserve the absolute maximum of how many ccds, visits 

285 # we could potentially have. 

286 coadd_inputs.ccds.reserve(len(inputWarps)) 

287 coadd_inputs.visits.reserve(len(inputWarps)) 

288 coadd_inputs_gc[cellInfo.index] = coadd_inputs 

289 # Read in one warp at a time, and accumulate it in all the cells that 

290 # it completely overlaps. 

291 

292 for warpRef in inputWarps: 

293 warp = warpRef.get() 

294 

295 # Pre-process the warp before coadding. 

296 # Each Warp that goes into a coadd will typically have an 

297 # independent photometric zero-point. Therefore, we must scale each 

298 # Warp to set it to a common photometric zeropoint. 

299 self.scale_zero_point.run(exposure=warp, dataRef=warpRef) 

300 

301 # Coadd the warp onto the cells it completely overlaps. 

302 edge = afwImage.Mask.getPlaneBitMask("EDGE") 

303 for cellInfo in skyInfo.patchInfo: 

304 bbox = cellInfo.outer_bbox 

305 stacker = gc[cellInfo.index] 

306 mi = warp[bbox].getMaskedImage() 

307 

308 if (mi.getMask().array & edge).any(): 

309 self.log.debug( 

310 "Skipping %s in cell %s because it has an EDGE", warpRef.dataId, cellInfo.index 

311 ) 

312 continue 

313 

314 weight = self._compute_weight(mi, statsCtrl) 

315 if not np.isfinite(weight): 

316 # Log at the debug level, because this can be quite common. 

317 self.log.debug( 

318 "Non-finite weight for %s in cell %s: skipping", warpRef.dataId, cellInfo.index 

319 ) 

320 continue 

321 

322 stacker.add_masked_image(mi, weight=weight) 

323 

324 coadd_inputs = coadd_inputs_gc[cellInfo.index] 

325 self.input_recorder.addVisitToCoadd(coadd_inputs, warp[bbox], weight) 

326 

327 del warp 

328 

329 cells: list[SingleCellCoadd] = [] 

330 for cellInfo in skyInfo.patchInfo: 

331 stacker = gc[cellInfo.index] 

332 cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox) 

333 stacker.fill_stacked_masked_image(cell_masked_image) 

334 

335 # Post-process the coadd before converting to new data structures. 

336 if self.config.do_interpolate_coadd: 

337 self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA") 

338 # The variance must be positive; work around for DM-3201. 

339 varArray = cell_masked_image.variance.array 

340 with np.errstate(invalid="ignore"): 

341 varArray[:] = np.where(varArray > 0, varArray, np.inf) 

342 

343 # Finalize the PSF on the cell coadds. 

344 coadd_inputs = coadd_inputs_gc[cellInfo.index] 

345 coadd_inputs.ccds.sort() 

346 coadd_inputs.visits.sort() 

347 cell_coadd_psf = CoaddPsf(coadd_inputs.ccds, skyInfo.wcs, self.config.coadd_psf.makeControl()) 

348 

349 image_planes = OwnedImagePlanes.from_masked_image(cell_masked_image) 

350 identifiers = CellIdentifiers( 

351 cell=cellInfo.index, 

352 skymap=self.common.identifiers.skymap, 

353 tract=self.common.identifiers.tract, 

354 patch=self.common.identifiers.patch, 

355 band=self.common.identifiers.band, 

356 ) 

357 

358 singleCellCoadd = SingleCellCoadd( 

359 outer=image_planes, 

360 psf=cell_coadd_psf.computeKernelImage(cell_coadd_psf.getAveragePosition()), 

361 inner_bbox=cellInfo.inner_bbox, 

362 inputs=None, # TODO 

363 common=self.common, 

364 identifiers=identifiers, 

365 ) 

366 # TODO: Attach transmission curve when they become available. 

367 cells.append(singleCellCoadd) 

368 

369 grid = self._construct_grid(skyInfo) 

370 multipleCellCoadd = MultipleCellCoadd( 

371 cells, 

372 grid=grid, 

373 outer_cell_size=cellInfo.outer_bbox.getDimensions(), 

374 inner_bbox=None, 

375 common=self.common, 

376 psf_image_size=cells[0].psf_image.getDimensions(), 

377 ) 

378 

379 return Struct( 

380 multipleCellCoadd=multipleCellCoadd, 

381 ) 

382 

383 

384class ConvertMulipleCellCoaddToExposureConnections( 

385 PipelineTaskConnections, 

386 dimensions=("tract", "patch", "band", "skymap"), 

387 defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"}, 

388): 

389 cellCoaddExposure = Input( 

390 doc="Output coadded exposure, produced by stacking input warps", 

391 name="{inputCoaddName}Coadd{inputCoaddSuffix}", 

392 storageClass="MultipleCellCoadd", 

393 dimensions=("tract", "patch", "skymap", "band"), 

394 ) 

395 

396 stitchedCoaddExposure = Output( 

397 doc="Output stitched coadded exposure, produced by stacking input warps", 

398 name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched", 

399 storageClass="ExposureF", 

400 dimensions=("tract", "patch", "skymap", "band"), 

401 ) 

402 

403 

404class ConvertMultipleCellCoaddToExposureConfig( 

405 PipelineTaskConfig, pipelineConnections=ConvertMulipleCellCoaddToExposureConnections 

406): 

407 """A trivial PipelineTaskConfig class for 

408 ConvertMultipleCellCoaddToExposureTask. 

409 """ 

410 

411 pass 

412 

413 

414class ConvertMultipleCellCoaddToExposureTask(PipelineTask): 

415 """An after burner PipelineTask that converts a cell-based coadd from 

416 `MultipleCellCoadd` format to `ExposureF` format. 

417 

418 The run method stitches the cell-based coadd into contiguous exposure and 

419 returns it in as an `Exposure` object. This is lossy as it preserves only 

420 the pixels in the inner bounding box of the cells and discards the values 

421 in the buffer region. 

422 

423 Notes 

424 ----- 

425 This task has no configurable parameters. 

426 """ 

427 

428 ConfigClass = ConvertMultipleCellCoaddToExposureConfig 

429 _DefaultName = "convertMultipleCellCoaddToExposure" 

430 

431 def runQuantum(self, butlerQC, inputRefs, outputRefs): 

432 inputData = butlerQC.get(inputRefs) 

433 returnStruct = self.run(**inputData) 

434 butlerQC.put(returnStruct, outputRefs) 

435 

436 def run(self, cellCoaddExposure): 

437 return Struct( 

438 stitchedCoaddExposure=cellCoaddExposure.stitch().asExposure(), 

439 )