Coverage for python / lsst / drp / tasks / assemble_cell_coadd.py: 16%

386 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 10:50 +0000

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public names exported by a star import of this module.
__all__ = ("AssembleCellCoaddTask", "AssembleCellCoaddConfig", "ConvertMultipleCellCoaddToExposureTask")

29 

30import dataclasses 

31import itertools 

32import logging 

33 

34import numpy as np 

35 

36import lsst.afw.geom as afwGeom 

37import lsst.afw.image as afwImage 

38import lsst.afw.math as afwMath 

39import lsst.geom as geom 

40from lsst.afw.detection import InvalidPsfError 

41from lsst.afw.geom import SinglePolygonException, makeWcsPairTransform 

42from lsst.cell_coadds import ( 

43 CellIdentifiers, 

44 CoaddApCorrMapStacker, 

45 CoaddInputs, 

46 CoaddUnits, 

47 CommonComponents, 

48 GridContainer, 

49 MultipleCellCoadd, 

50 ObservationIdentifiers, 

51 OwnedImagePlanes, 

52 PatchIdentifiers, 

53 SingleCellCoadd, 

54 UniformGrid, 

55) 

56from lsst.daf.butler import DataCoordinate, DeferredDatasetHandle 

57from lsst.meas.algorithms import AccumulatorMeanStack 

58from lsst.pex.config import ConfigField, ConfigurableField, DictField, Field, ListField, RangeField 

59from lsst.pipe.base import ( 

60 InMemoryDatasetHandle, 

61 NoWorkFound, 

62 PipelineTask, 

63 PipelineTaskConfig, 

64 PipelineTaskConnections, 

65 Struct, 

66) 

67from lsst.pipe.base.connectionTypes import Input, Output 

68from lsst.pipe.tasks.coaddBase import makeSkyInfo, removeMaskPlanes, setRejectedMaskMapping 

69from lsst.pipe.tasks.healSparseMapping import HealSparseInputMapTask 

70from lsst.pipe.tasks.interpImage import InterpImageTask 

71from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask 

72from lsst.skymap import BaseSkyMap 

73 

74 

@dataclasses.dataclass
class WarpInputs:
    """Bundle of a warp handle together with its associated inputs.

    Attributes
    ----------
    warp : `~lsst.daf.butler.DeferredDatasetHandle` or `~lsst.pipe.base.InMemoryDatasetHandle`
        Handle for the warped exposure.
    masked_fraction : `~lsst.daf.butler.DeferredDatasetHandle`, `~lsst.pipe.base.InMemoryDatasetHandle` or `None`
        Handle for the masked fraction image, or `None` if not available.
    artifact_mask : `~lsst.daf.butler.DeferredDatasetHandle`, `~lsst.pipe.base.InMemoryDatasetHandle` or `None`
        Handle for the CompareWarp artifact mask, or `None` if not available.
    noise_warps : `list`
        Handles for the noise warps. Each instance gets its own, initially
        empty, list by default.
    """

    # Required: handle for the warped exposure.
    warp: DeferredDatasetHandle | InMemoryDatasetHandle
    # Optional: handle for the masked fraction image.
    masked_fraction: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    # Optional: handle for the CompareWarp artifact mask.
    artifact_mask: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    # Handles for the noise warps (a fresh list per instance).
    noise_warps: list[DeferredDatasetHandle | InMemoryDatasetHandle] = dataclasses.field(default_factory=list)

    @property
    def dataId(self) -> DataCoordinate:
        """Data ID of the wrapped warp.

        Returns
        -------
        data_id : `~lsst.daf.butler.DataCoordinate`
            The ``dataId`` attribute of the ``warp`` handle.
        """
        warp_handle = self.warp
        return warp_handle.dataId

101 

102 

class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Butler connections for `AssembleCellCoaddTask`.

    Depending on the configuration:
    - ``visitSummaryList`` is removed when weights are computed from the
      warps.
    - ``artifactMasks`` is removed when artifact masks are not used.
    - ``inputMap`` is removed when no input map is made.
    - One ``noise{n}_warps`` input is added per requested noise realization.
    """

    inputWarps = Input(
        doc="Input warps",
        name="{inputWarpName}Coadd_directWarp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    maskedFractionWarps = Input(
        doc="Mask fraction warps",
        name="{inputWarpName}Coadd_directWarp_maskedFraction",
        storageClass="ImageF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    artifactMasks = Input(
        doc="Artifact masks to be applied to the input warps",
        name="compare_warp_artifact_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    visitSummaryList = Input(
        doc="Input visit-summary catalogs with updated calibration objects. Mainly used for coadd weights.",
        name="finalVisitSummary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )

    skyMap = Input(
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    multipleCellCoadd = Output(
        doc="Output multiple cell coadd",
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    inputMap = Output(
        doc="Output healsparse map of input images",
        name="{inputWarpName}Coadd_inputMap",
        storageClass="HealSparseMap",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Without a config there is nothing to prune or add.
        if not config:
            return

        # (connection name, whether this configuration makes it unnecessary)
        removals = (
            ("visitSummaryList", config.do_calculate_weight_from_warp),
            ("artifactMasks", not config.do_use_artifact_mask),
            ("inputMap", not config.do_input_map),
        )
        for connection_name, should_remove in removals:
            if should_remove:
                delattr(self, connection_name)

        # Add one deferred-load input per noise realization. The count comes
        # from the config. Realization ``index`` is read from the
        # ``direct_warp_noise{index}`` dataset type.
        for index in range(config.num_noise_realizations):
            setattr(
                self,
                f"noise{index}_warps",
                Input(
                    doc="Input noise warps",
                    name=f"direct_warp_noise{index}",
                    storageClass="MaskedImageF",
                    dimensions=("tract", "patch", "skymap", "visit", "instrument"),
                    deferLoad=True,
                    multiple=True,
                ),
            )

192 

193 

class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration for `AssembleCellCoaddTask`."""

    do_interpolate_coadd = Field[bool](
        doc="Interpolate over pixels with NO_DATA mask set?",
        default=True,
    )
    interpolate_coadd = ConfigurableField(
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
        target=InterpImageTask,
    )
    do_scale_zero_point = Field[bool](
        doc="Scale warps to a common zero point? This is not needed if they have absolute flux calibration.",
        default=False,
        deprecated=(
            "Now that visits are scaled to nJy it is no longer necessary or "
            "recommended to scale the zero point, so this will be removed "
            "after v29."
        ),
    )
    scale_zero_point = ConfigurableField(
        doc="Task to scale warps to a common zero point",
        target=ScaleZeroPointTask,
        deprecated=(
            "Now that visits are scaled to nJy it is no longer necessary or "
            "recommended to scale the zero point, so this will be removed "
            "after v29."
        ),
    )
    do_calculate_weight_from_warp = Field[bool](
        doc=(
            "Calculate coadd weight from the input warp? Otherwise, the weight is obtained from the "
            "visitSummaryList connection. This is meant as a fallback when run outside the pipeline."
        ),
        default=False,
    )
    do_use_artifact_mask = Field[bool](
        doc="Substitute the mask planes input warp with an alternative artifact mask?",
        default=True,
    )
    do_coadd_inverse_aperture_corrections = Field[bool](
        doc=(
            "Coadd the inverse aperture corrections for each cell? This is formally the more accurate way "
            "but may be turned off for parity with deepCoadd."
        ),
        default=False,
    )
    # A value of 1.0 corresponds to ideal, edge-free cells.
    # A value of 0.0 corresponds to deep_coadd-style coadds.
    # This has to be at least 0.5 to ensure that an input overlaps the
    # cell center. Inputs with an overlap fraction below 0.25 will
    # definitely not overlap the cell center.
    min_overlap_fraction = RangeField[float](
        doc=(
            "The minimum overlap fraction required for a single (visit, detector) input to be included in a "
            "cell."
        ),
        default=1.0,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=True,
    )
    bad_mask_planes = ListField[str](
        doc="Mask planes that count towards the masked fraction within a cell.",
        default=("BAD", "NO_DATA", "SAT", "CLIPPED"),
    )
    remove_mask_planes = ListField[str](
        doc="Mask planes to remove before coadding",
        default=["EDGE", "NOT_DEBLENDED"],
    )
    calc_error_from_input_variance = Field[bool](
        doc=(
            "Calculate coadd variance from input variance by stacking "
            "statistic. Passed to AccumulatorMeanStack."
        ),
        default=True,
    )
    mask_propagation_thresholds = DictField[str, float](
        doc=(
            "Threshold (in fractional weight) of rejection at which we "
            "propagate a mask plane to the coadd; that is, we set the mask "
            "bit on the coadd if the fraction the rejected frames "
            "would have contributed exceeds this value."
        ),
        default={"SAT": 0.1},
    )
    max_maskfrac = RangeField[float](
        doc=(
            "Maximum fraction of masked pixels in a cell. This is currently "
            "just a placeholder and is not used now"
        ),
        default=0.99,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=False,
    )
    num_noise_realizations = Field[int](
        doc=(
            "Number of noise planes to include in the coadd. "
            "This should not exceed the corresponding config parameter "
            "specified in `MakeDirectWarpConfig`. "
        ),
        default=0,
        check=lambda count: count >= 0,
    )
    psf_warper = ConfigField(
        dtype=afwMath.Warper.ConfigClass,
        doc=(
            "Configuration for the warper that warps the PSFs. It must have the same configuration used to "
            "warp the images."
        ),
    )
    psf_dimensions = Field[int](
        doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
        default=35,
        check=lambda size: size > 0 and size % 2 == 1,
    )
    require_artifact_mask = Field[bool](
        doc=(
            "Require presence of artifact mask for each warp? Use true if using artifact rejection outputs"
            " from CompareWarpTask"
        ),
        default=True,
    )
    do_input_map = Field[bool](
        doc="Create a bitwise map of coadd inputs.",
        default=False,
    )
    input_mapper = ConfigurableField(
        doc="Input map creation subtask.",
        target=HealSparseInputMapTask,
    )

305 

306 

307class AssembleCellCoaddTask(PipelineTask): 

308 """Assemble a cell-based coadded image from a set of warps. 

309 

310 This task reads in the warp one at a time, and accumulates it in all the 

311 cells that it completely overlaps with. This is the optimal I/O pattern but 

312 this also implies that it is not possible to build one or only a few cells. 

313 

314 Each cell coadds is guaranteed to have a well-defined PSF. This is done by 

315 1) excluding warps that only partially overlap a cell from that cell coadd; 

316 2) interpolating bad pixels in the warps rather than excluding them; 

317 3) by computing the coadd as a weighted mean of the warps without clipping; 

318 4) by computing the coadd PSF as the weighted mean of the PSF of the warps 

319 with the same weights. 

320 

321 The cells are (and must be) defined in the skymap, and cannot be configured 

322 or redefined here. The cells are assumed to be small enough that the PSF is 

323 assumed to be spatially constant within a cell. 

324 

325 Raises 

326 ------ 

327 NoWorkFound 

328 Raised if no input warps are provided, or no cells could be populated. 

329 RuntimeError 

330 Raised if the skymap is not cell-based. 

331 

332 Notes 

333 ----- 

334 This is not yet a part of the standard DRP pipeline. As such, the Task and 

335 especially its Config and Connections are experimental and subject to 

336 change any time without a formal RFC or standard deprecation procedures 

337 until it is included in the DRP pipeline. 

338 """ 

339 

340 ConfigClass = AssembleCellCoaddConfig 

341 _DefaultName = "assembleCellCoadd" 

342 

343 def __init__(self, *args, **kwargs): 

344 super().__init__(*args, **kwargs) 

345 if self.config.do_interpolate_coadd: 

346 self.makeSubtask("interpolate_coadd") 

347 # Suppress the warning message about fallback. 

348 self.interpolate_coadd.log.setLevel(logging.ERROR) 

349 if self.config.do_scale_zero_point: 

350 self.makeSubtask("scale_zero_point") 

351 if self.config.do_input_map: 

352 self.makeSubtask("input_mapper") 

353 

354 self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper) 

355 if (warping_kernel_name := self.config.psf_warper.warpingKernelName.lower()).startswith("lanczos"): 

356 psf_padding = 2 * int(warping_kernel_name.lstrip("lanczos")) - 1 

357 self.log.debug( 

358 "Padding PSF image by %d pixels since the warping kernel is %s.", 

359 psf_padding, 

360 self.config.psf_warper.warpingKernelName, 

361 ) 

362 else: 

363 psf_padding = 10 

364 self.log.info( 

365 "Padding PSF image by %d pixels since the warping kernel is not Lanczos.", 

366 psf_padding, 

367 ) 

368 self.psf_padding = psf_padding 

369 

def runQuantum(self, butlerQC, inputRefs, outputRefs):
    # Docstring inherited.
    if not inputRefs.inputWarps:
        raise NoWorkFound("No input warps provided for co-addition")
    self.log.info("Found %d input warps", len(inputRefs.inputWarps))

    # Cell coadds can only be assembled on a cell-based skymap.
    skyMap = butlerQC.get(inputRefs.skyMap)
    if skyMap.config.tractBuilder.name != "cells":
        raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.")

    outputDataId = butlerQC.quantum.dataId
    skyInfo = makeSkyInfo(skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"])

    # The visitSummaryList connection is removed when weights are computed
    # from the warps themselves. In that case fall back to an empty list.
    visitSummaryList = butlerQC.get(getattr(inputRefs, "visitSummaryList", []))

    # Components common to the whole output patch. They are not passed to
    # ``run`` as arguments, so they are stored on the task. Presumably
    # ``run`` reads ``self.common`` further down — confirm.
    units = CoaddUnits.legacy if self.config.do_scale_zero_point else CoaddUnits.nJy
    self.common = CommonComponents(
        units=units,
        wcs=skyInfo.patchInfo.wcs,
        band=outputDataId.get("band", None),
        identifiers=PatchIdentifiers.from_data_id(outputDataId),
    )

    # Group every per-visit input under the data ID of its warp.
    inputs: dict[DataCoordinate, WarpInputs] = {}
    for handle in butlerQC.get(inputRefs.inputWarps):
        inputs[handle.dataId] = WarpInputs(warp=handle, noise_warps=[])

    # ``artifactMasks`` may have been removed by the configuration, hence
    # the ``getattr`` default.
    for ref in getattr(inputRefs, "artifactMasks", []):
        inputs[ref.dataId].artifact_mask = butlerQC.get(ref)
    for ref in getattr(inputRefs, "maskedFractionWarps", []):
        inputs[ref.dataId].masked_fraction = butlerQC.get(ref)
    # Appending in realization order keeps ``noise_warps[n]`` aligned with
    # the ``noise{n}_warps`` connection.
    for n in range(self.config.num_noise_realizations):
        for ref in getattr(inputRefs, f"noise{n}_warps"):
            inputs[ref.dataId].noise_warps.append(butlerQC.get(ref))

    returnStruct = self.run(inputs=inputs, skyInfo=skyInfo, visitSummaryList=visitSummaryList)
    butlerQC.put(returnStruct, outputRefs)
    return returnStruct

412 

@staticmethod
def _compute_weight(maskedImage, statsCtrl):
    """Compute a weight for a masked image.

    Parameters
    ----------
    maskedImage : `~lsst.afw.image.MaskedImage`
        The masked image to compute the weight.
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control (config-like) object for StatisticsStack.

    Returns
    -------
    weight : `float`
        Inverse of the clipped mean (``MEANCLIP``) of the variance plane.
        It is ``inf`` when that mean is zero and NaN when it is NaN. Callers
        should check `numpy.isfinite` before using it.
    """
    statObj = afwMath.makeStatistics(
        maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl
    )
    meanVar, _ = statObj.getResult(afwMath.MEANCLIP)
    mean_variance = float(meanVar)
    if mean_variance == 0.0:
        # ``1.0 / 0.0`` raises ZeroDivisionError in Python. Report a
        # non-finite weight instead, which callers treat as "skip this input".
        return float("inf")
    return 1.0 / mean_variance

435 

@staticmethod
def _construct_grid(skyInfo):
    """Build the `~lsst.cell_coadds.UniformGrid` describing a patch's cells.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        Sky information for the patch. Only ``skyInfo.patchInfo`` is read.

    Returns
    -------
    grid : `~lsst.cell_coadds.UniformGrid`
        The grid is built as follows:
        - Its bounding box is the patch outer bounding box eroded by the
          cell border.
        - Its cell size is the cell inner dimensions.
        - Its padding is the cell border.
    """
    patch_info = skyInfo.patchInfo
    border = patch_info.getCellBorder()
    return UniformGrid.from_bbox_cell_size(
        patch_info.outer_bbox.erodedBy(border),
        patch_info.getCellInnerDimensions(),
        padding=border,
    )

458 

def _construct_grid_container(self, skyInfo, statsCtrl):
    """Build a grid holding one `AccumulatorMeanStack` per cell.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        Sky information for the patch whose cells are stacked.
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        Supplies the AND mask, the no-good-pixels mask, the rejected-mask
        mapping and the mask propagation thresholds given to every stacker.

    Returns
    -------
    gc : `~lsst.cell_coadds.GridContainer`
        A GridContainer containing one AccumulatorMeanStack per cell. Each
        stacker is sized to the cell's outer bounding box.
    """
    shape = self._construct_grid(skyInfo).shape

    mask_map = setRejectedMaskMapping(statsCtrl)
    self.log.debug("Obtained maskMap = %s for %s", mask_map, skyInfo.patchInfo)
    threshold_dict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl)
    and_mask = statsCtrl.getAndMask()
    no_good_mask = statsCtrl.getNoGoodPixelsMask()

    container = GridContainer[AccumulatorMeanStack](shape)
    for cell in skyInfo.patchInfo:
        outer = cell.outer_bbox
        container[cell.index] = AccumulatorMeanStack(
            # numpy arrays are indexed (row, column), i.e. (height, width).
            shape=(outer.height, outer.width),
            bit_mask_value=and_mask,
            mask_threshold_dict=threshold_dict,
            calc_error_from_input_variance=self.config.calc_error_from_input_variance,
            compute_n_image=False,
            mask_map=mask_map,
            no_good_pixels_mask=no_good_mask,
        )

    return container

496 

def _construct_stats_control(self):
    """Build the `~lsst.afw.math.StatisticsControl` used for coaddition.

    In contrast to AssembleCoaddTask or CompareWarpAssembleCoaddTask, only
    two things are configurable here: the rejected mask planes and the
    optional per-plane mask propagation thresholds.

    Returns
    -------
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A NaN-safe control object with:
        - two clipping iterations;
        - an AND mask built from ``config.bad_mask_planes``;
        - one propagation threshold per entry of
          ``config.mask_propagation_thresholds``.
    """
    control = afwMath.StatisticsControl()
    # numIter is fixed at 2, the default of CompareWarpAssembleCoaddTask, so
    # that weights are consistent with it. It is intentionally not a config
    # parameter: warp-based weights are only a fallback, not recommended for
    # production.
    control.setNumIter(2)
    control.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes))
    control.setNanSafe(True)

    thresholds = self.config.mask_propagation_thresholds
    for plane_name in thresholds:
        plane_bit = afwImage.Mask.getMaskPlane(plane_name)
        control.setMaskPropagationThreshold(plane_bit, thresholds[plane_name])

    return control

521 

def _construct_ap_corr_grid_container(self, skyInfo):
    """Build a grid holding one `CoaddApCorrMapStacker` per cell.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        Sky information for the patch whose cells are stacked.

    Returns
    -------
    gc : `~lsst.cell_coadds.GridContainer`
        A GridContainer containing one CoaddApCorrMapStacker per cell.
        Each stacker evaluates at the center of the cell's inner bounding
        box.
    """
    container = GridContainer[CoaddApCorrMapStacker](self._construct_grid(skyInfo).shape)
    do_inverse = self.config.do_coadd_inverse_aperture_corrections

    for cell in skyInfo.patchInfo:
        container[cell.index] = CoaddApCorrMapStacker(
            evaluation_point=cell.inner_bbox.getCenter(),
            do_coadd_inverse_ap_corr=do_inverse,
        )

    return container

548 

549 def run( 

550 self, 

551 *, 

552 inputs: dict[DataCoordinate, WarpInputs], 

553 skyInfo, 

554 visitSummaryList: list | None = None, 

555 ): 

556 for mask_plane in self.config.bad_mask_planes: 

557 afwImage.Mask.addMaskPlane(mask_plane) 

558 for mask_plane in self.config.mask_propagation_thresholds: 

559 afwImage.Mask.addMaskPlane(mask_plane) 

560 

561 statsCtrl = self._construct_stats_control() 

562 

563 warp_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl) 

564 maskfrac_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl) 

565 noise_stacker_gc_list = [ 

566 self._construct_grid_container(skyInfo, statsCtrl) 

567 for n in range(self.config.num_noise_realizations) 

568 ] 

569 psf_stacker_gc = GridContainer[AccumulatorMeanStack](warp_stacker_gc.shape) 

570 psf_bbox_gc = GridContainer[geom.Box2I](warp_stacker_gc.shape) 

571 ap_corr_stacker_gc = self._construct_ap_corr_grid_container(skyInfo) 

572 

573 # Make a container to hold the cell centers in sky coordinates now, 

574 # so we don't have to recompute them for each warp 

575 # (they share a common WCS). These are needed to find the various 

576 # warp + detector combinations that contributed to each cell, and later 

577 # get the corresponding PSFs as well. 

578 cell_centers_sky = GridContainer[geom.SpherePoint](warp_stacker_gc.shape) 

579 # Make a container to hold the observation identifiers for each cell. 

580 observation_identifiers_gc = GridContainer[dict](warp_stacker_gc.shape) 

581 

582 if self.config.do_input_map: 

583 # We need to know all the visit + detector pairs in the inputs. 

584 warp_input_list = [warp_ref.warp.get(component="coaddInputs") for warp_ref in inputs.values()] 

585 visit_detectors = [] 

586 for warp_input in warp_input_list: 

587 for row in warp_input.ccds: 

588 visit_detectors.append((int(row["visit"]), int(row["ccd"]))) 

589 

590 self.input_mapper.initialize_cell_input_map( 

591 skyInfo.patchInfo.getOuterBBox(), 

592 skyInfo.patchInfo.wcs, 

593 visit_detectors, 

594 ) 

595 

596 # Populate them. 

597 for cellInfo in skyInfo.patchInfo: 

598 # Make a list to hold the observation identifiers for each cell. 

599 observation_identifiers_gc[cellInfo.index] = {} 

600 cell_center_pixel = geom.Point2D(geom.Point2I(cellInfo.inner_bbox.getCenter())) 

601 cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cell_center_pixel) 

602 psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox( 

603 cell_center_pixel, 

604 geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions), 

605 ) 

606 psf_stacker_gc[cellInfo.index] = AccumulatorMeanStack( 

607 # The shape is for the numpy arrays, hence transposed. 

608 shape=(self.config.psf_dimensions, self.config.psf_dimensions), 

609 bit_mask_value=0, 

610 calc_error_from_input_variance=self.config.calc_error_from_input_variance, 

611 compute_n_image=False, 

612 ) 

613 

614 if self.config.do_input_map: 

615 self.input_mapper.build_cell_input_map(cellInfo) 

616 

617 # visit_summary do not have (tract, patch, band, skymap) dimensions. 

618 if not visitSummaryList: 

619 visitSummaryList = [] 

620 visitSummaryRefDict = { 

621 visitSummaryRef.dataId["visit"]: visitSummaryRef for visitSummaryRef in visitSummaryList 

622 } 

623 

624 # Keep track of the polygons corresponding to each (visit, detector). 

625 visit_polygons: dict[ObservationIdentifiers, afwGeom.Polygon] = {} 

626 

627 # Read in one warp at a time, and accumulate it in all the cells that 

628 # it completely overlaps. 

629 for _, warp_input in inputs.items(): 

630 # warps that have been excluded from CompareWarp via visit 

631 # selection from SelectVisitsTasks will not have artifact masks. 

632 # Exclude them from the cell coadds too. 

633 if self.config.require_artifact_mask and warp_input.artifact_mask is None: 

634 self.log.info( 

635 "Excluding warp %s from cell coadds because it has no artifact mask", 

636 warp_input.dataId["visit"], 

637 ) 

638 continue 

639 

640 warp = warp_input.warp.get(parameters={"bbox": skyInfo.bbox}) 

641 masked_fraction_image = ( 

642 warp_input.masked_fraction.get(parameters={"bbox": skyInfo.bbox}) 

643 if warp_input.masked_fraction 

644 else None 

645 ) 

646 

647 # Pre-process the warp before coadding. 

648 # TODO: Can we get these mask names from artifactMask? 

649 warp.mask.addMaskPlane("CLIPPED") 

650 warp.mask.addMaskPlane("REJECTED") 

651 warp.mask.addMaskPlane("SENSOR_EDGE") 

652 warp.mask.addMaskPlane("INEXACT_PSF") 

653 

654 if artifact_mask_ref := warp_input.artifact_mask: 

655 # Apply the artifact mask to the warp. 

656 artifact_mask = artifact_mask_ref.get() 

657 assert ( 

658 warp.mask.getMaskPlaneDict() == artifact_mask.getMaskPlaneDict() 

659 ), "Mask dicts do not agree." 

660 warp.mask.array = artifact_mask.array 

661 del artifact_mask 

662 

663 if self.config.do_scale_zero_point: 

664 # Each Warp that goes into a coadd will typically have an 

665 # independent photometric zero-point. Therefore, we must scale 

666 # each Warp to set it to a common photometric zeropoint. 

667 imageScaler = self.scale_zero_point.run(exposure=warp, dataRef=warp_input.warp).imageScaler 

668 zero_point_scale_factor = imageScaler.scale 

669 self.log.debug( 

670 "Scaled the warp %s by %f to match zero points", 

671 warp_input.dataId, 

672 zero_point_scale_factor, 

673 ) 

674 else: 

675 zero_point_scale_factor = 1.0 

676 if "BUNIT" not in warp.metadata: 

677 raise ValueError(f"Warp {warp_input.dataId} has no BUNIT metadata") 

678 if warp.metadata["BUNIT"] != "nJy": 

679 raise ValueError( 

680 f"Warp {warp_input.dataId} has BUNIT {warp.metadata['BUNIT']}, expected nJy" 

681 ) 

682 

683 # Only try to remove maks planes that have been registered. 

684 to_remove = [] 

685 for plane in self.config.remove_mask_planes: 

686 if plane in warp.mask.getMaskPlaneDict(): 

687 to_remove.append(plane) 

688 removeMaskPlanes(warp.mask, to_remove, self.log) 

689 # Instead of using self.config.bad_mask_planes, we explicitly 

690 # ask statsCtrl which pixels are going to be ignored/rejected. 

691 rejected = afwImage.Mask.getPlaneBitMask( 

692 ["CLIPPED", "REJECTED"] + afwImage.Mask.interpret(statsCtrl.getAndMask()).split(",") 

693 ) 

694 

695 # Compute the weight for each CCD in the warp from the visitSummary 

696 # or from the warp itself, if not provided. Computing the weight 

697 # from the warp is not recommended, and in that case we compute one 

698 # weight per warp and not bother with per-detector weights. 

699 full_ccd_table = warp.getInfo().getCoaddInputs().ccds 

700 weights: dict[int, float] = dict.fromkeys( 

701 full_ccd_table["ccd"].tolist(), 

702 0.0, 

703 ) # Mapping from detector to weight. 

704 

705 if visitSummaryRef := visitSummaryRefDict.get(warp_input.dataId["visit"]): 

706 visitSummary = visitSummaryRef.get() 

707 for detector in full_ccd_table["ccd"].tolist(): 

708 visitSummaryRow = visitSummary.find(detector) 

709 mean_variance = visitSummaryRow["meanVar"] 

710 mean_variance *= zero_point_scale_factor**2 

711 if warp.metadata.get("BUNIT", None) == "nJy": 

712 mean_variance *= visitSummaryRow.photoCalib.getCalibrationMean() ** 2 

713 weights[detector] = 1.0 / mean_variance 

714 del visitSummary 

715 else: 

716 self.log.debug("No visit summary found for %s; using warp-based weights", warp_input.dataId) 

717 weight = self._compute_weight(warp, statsCtrl) 

718 if not np.isfinite(weight): 

719 self.log.warn("Non-finite weight for %s: skipping", warp_input.dataId) 

720 continue 

721 

722 for detector in weights: 

723 weights[detector] = weight 

724 

725 noise_warps = [ref.get(parameters={"bbox": skyInfo.bbox}) for ref in warp_input.noise_warps] 

726 

727 # Create an image where each pixel value corresponds to the 

728 # detector ID that pixel comes from. 

729 detector_map = afwImage.ImageI(bbox=warp.getBBox(), initialValue=-1) 

730 for row in full_ccd_table: 

731 transform = makeWcsPairTransform(row.wcs, warp.wcs) 

732 if (src_polygon := row.validPolygon) is None: 

733 src_polygon = afwGeom.Polygon(geom.Box2D(row.getBBox())) 

734 try: 

735 dest_polygon = src_polygon.transform(transform).intersectionSingle( 

736 geom.Box2D(warp.getBBox()) 

737 ) 

738 except SinglePolygonException: 

739 continue 

740 

741 observation_identifier = ObservationIdentifiers.from_data_id( 

742 warp_input.dataId, 

743 backup_detector=row["ccd"], 

744 ) 

745 visit_polygons[observation_identifier] = dest_polygon 

746 

747 detector_map_slice = dest_polygon.createImage(detector_map.getBBox()).array > 0 

748 if not (detector_map.array[detector_map_slice] < 0).all(): 

749 self.log.warning("Multiple detectors from visit %s are overlapping", warp_input.dataId) 

750 detector_map.array[detector_map_slice] = row["ccd"] 

751 

752 if (detector_map.array < 0).all(): 

753 self.log.warning("Unable to split the warp %s into single-detector warps.", warp_input.dataId) 

754 detector_map.array[:, :] = 0 

755 

756 for cellInfo, ccd_row in itertools.product(skyInfo.patchInfo, full_ccd_table): 

757 bbox = cellInfo.outer_bbox 

758 inner_bbox = cellInfo.inner_bbox 

759 

760 overlap_fraction = (detector_map[inner_bbox].array == ccd_row["ccd"]).mean() 

761 assert -1e-4 < overlap_fraction < 1.0001, "Overlap fraction is not within [0, 1]." 

762 if (overlap_fraction < self.config.min_overlap_fraction) or (overlap_fraction <= 0.0): 

763 self.log.debug( 

764 "Skipping %s in cell %s because it had only %.3f < %.3f fractional overlap.", 

765 warp_input.dataId, 

766 cellInfo.index, 

767 overlap_fraction, 

768 self.config.min_overlap_fraction, 

769 ) 

770 continue 

771 

772 weight = weights[int(ccd_row["ccd"])] 

773 if not np.isfinite(weight): 

774 self.log.warn( 

775 "Non-finite weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index 

776 ) 

777 continue 

778 

779 if weight == 0: 

780 self.log.info( 

781 "Zero weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index 

782 ) 

783 continue 

784 

785 # Decide if a deep copy is necessary to apply the single 

786 # detector cuts since it involves modifying the image in-place. 

787 # If within the inner cell, there are three or more different 

788 # values that detector map takes, then there are definitely 

789 # multiple detectors (one for chip gaps, two for two detectors) 

790 deep_copy = len(set(detector_map[inner_bbox].array.ravel())) >= 3 

791 if deep_copy: 

792 single_detector_mask_array = detector_map[bbox].array != ccd_row["ccd"] 

793 

794 mi = afwImage.MaskedImageF(warp[bbox].maskedImage, deep=deep_copy) 

795 if deep_copy: 

796 mi.image.array[single_detector_mask_array] = 0.0 

797 mi.variance.array[single_detector_mask_array] = np.inf 

798 nodata_or_mask = (single_detector_mask_array) * afwImage.Mask.getPlaneBitMask("NO_DATA") 

799 mi.mask[bbox].array |= nodata_or_mask 

800 warp_stacker_gc[cellInfo.index].add_masked_image(mi, weight=weight) 

801 

802 if masked_fraction_image: 

803 mi = afwImage.ImageF(masked_fraction_image[bbox], deep=deep_copy) 

804 if deep_copy: 

805 mi.array[single_detector_mask_array] = 0.0 

806 maskfrac_stacker_gc[cellInfo.index].add_image(masked_fraction_image[bbox], weight=weight) 

807 

808 for n in range(self.config.num_noise_realizations): 

809 mi = afwImage.MaskedImageF(noise_warps[n][bbox], deep=deep_copy) 

810 if deep_copy: 

811 mi.image.array[single_detector_mask_array] = 0.0 

812 mi.variance.array[single_detector_mask_array] = np.inf 

813 mi.mask[bbox].array |= nodata_or_mask 

814 noise_stacker_gc_list[n][cellInfo.index].add_masked_image(mi, weight=weight) 

815 

816 # Set the defaults for PSF shape quantities. 

817 psf_shape = afwGeom.Quadrupole() 

818 psf_shape_flag = True 

819 psf_eval_point = None 

820 try: 

821 # The `if` branch is buggy. `dest_polygon` is technically 

822 # out of scope, but Python does not raise an error. 

823 # TODO: Fix this properly in DM-53479, but sweep it under 

824 # the rug for now. 

825 if overlap_fraction < 0.5: 

826 psf_eval_point = dest_polygon.intersectionSingle( 

827 geom.Box2D(inner_bbox) 

828 ).calculateCenter() 

829 else: 

830 psf_eval_point = geom.Point2D(geom.Point2I(inner_bbox.getCenter())) 

831 psf_shape = warp.psf.computeShape(psf_eval_point) 

832 psf_shape_flag = False 

833 except SinglePolygonException: 

834 self.log.info( 

835 "Unable to find the overlapping polygon between %d detector in %s and cell %s", 

836 ccd_row["ccd"], 

837 warp_input.dataId, 

838 cellInfo.index, 

839 ) 

840 except InvalidPsfError: 

841 self.log.info( 

842 "Unable to compute PSF shape from %d detector in %s at %s", 

843 ccd_row["ccd"], 

844 warp_input.dataId, 

845 psf_eval_point, 

846 ) 

847 

848 overlaps_center = detector_map[geom.Point2I(bbox.getCenter())] == ccd_row["ccd"] 

849 

850 observation_identifier = ObservationIdentifiers.from_data_id( 

851 warp_input.dataId, 

852 backup_detector=int(ccd_row["ccd"]), 

853 ) 

854 observation_identifiers_gc[cellInfo.index][observation_identifier] = CoaddInputs( 

855 overlaps_center=overlaps_center, 

856 overlap_fraction=overlap_fraction, 

857 weight=weight, 

858 psf_shape=psf_shape, 

859 psf_shape_flag=psf_shape_flag, 

860 ) 

861 if overlaps_center is False: 

862 self.log.debug( 

863 "%s does not overlap with the center of the cell %s", 

864 warp_input.dataId, 

865 cellInfo.index, 

866 ) 

867 continue 

868 

869 # Everything below this has to do with the center of the cell 

870 calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index]) 

871 undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point) 

872 

873 assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox( 

874 calexp_point, 

875 undistorted_psf_im.getDimensions(), 

876 ), "PSF image does not share the coordinates of the 'calexp'" 

877 

878 # Convert the PSF image from Image to MaskedImage and 

879 # zero-pad the image. 

880 undistorted_psf_bbox = undistorted_psf_im.getBBox() 

881 undistorted_psf_maskedImage = afwImage.MaskedImageD( 

882 undistorted_psf_bbox.dilatedBy(self.psf_padding) 

883 ) 

884 undistorted_psf_maskedImage.image[undistorted_psf_bbox].array[:, :] = undistorted_psf_im.array 

885 # TODO: In DM-43585, use the variance plane value from noise. 

886 undistorted_psf_maskedImage.variance += 1.0 # Set variance to 1 

887 

888 warped_psf_maskedImage = self.psf_warper.warpImage( 

889 destWcs=skyInfo.wcs, 

890 srcImage=undistorted_psf_maskedImage, 

891 srcWcs=ccd_row.getWcs(), 

892 destBBox=psf_bbox_gc[cellInfo.index], 

893 ) 

894 

895 # There may be NaNs in the PSF image. Set them to 0.0 

896 warped_psf_maskedImage.variance.array[np.isnan(warped_psf_maskedImage.image.array)] = 1.0 

897 warped_psf_maskedImage.image.array[np.isnan(warped_psf_maskedImage.image.array)] = 0.0 

898 

899 psf_stacker = psf_stacker_gc[cellInfo.index] 

900 psf_stacker.add_masked_image(warped_psf_maskedImage, weight=weight) 

901 

902 if not (0.995 < (psf_normalization := warped_psf_maskedImage.image.array.sum()) < 1.005): 

903 self.log.warning( 

904 "PSF image for %s in %s is not normalized to 1.0, but instead %f", 

905 warp_input.dataId, 

906 cellInfo.index, 

907 psf_normalization, 

908 ) 

909 

910 if (ap_corr_map := warp.getInfo().getApCorrMap()) is not None: 

911 ap_corr_stacker_gc[cellInfo.index].add(ap_corr_map, weight=weight) 

912 

913 if self.config.do_input_map: 

914 self.input_mapper.add_warp_to_cell_input_map( 

915 ccd_row, 

916 weight, 

917 cellInfo, 

918 ) 

919 

920 del warp 

921 

922 if self.config.do_input_map: 

923 inputMap = self.input_mapper.cell_input_map 

924 else: 

925 inputMap = None 

926 

927 # Update common with the visit polygons. 

928 self.common = dataclasses.replace( 

929 self.common, 

930 visit_polygons=visit_polygons, 

931 ) 

932 

933 cells: list[SingleCellCoadd] = [] 

934 for cellInfo in skyInfo.patchInfo: 

935 if len(observation_identifiers_gc[cellInfo.index]) == 0: 

936 self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index) 

937 continue 

938 

939 cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox) 

940 cell_maskfrac_image = afwImage.ImageF(cellInfo.outer_bbox) 

941 cell_noise_images = [ 

942 afwImage.MaskedImageF(cellInfo.outer_bbox) for n in range(self.config.num_noise_realizations) 

943 ] 

944 psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index]) 

945 

946 warp_stacker_gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image) 

947 maskfrac_stacker_gc[cellInfo.index].fill_stacked_image(cell_maskfrac_image) 

948 for n in range(self.config.num_noise_realizations): 

949 noise_stacker_gc_list[n][cellInfo.index].fill_stacked_masked_image(cell_noise_images[n]) 

950 psf_stacker_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image) 

951 

952 if ap_corr_stacker_gc[cellInfo.index].ap_corr_names: 

953 ap_corr_map = ap_corr_stacker_gc[cellInfo.index].final_ap_corr_map 

954 else: 

955 ap_corr_map = None 

956 

957 # Post-process the coadd before converting to new data structures. 

958 if np.isnan(cell_masked_image.image.array).all(): 

959 cell_masked_image.image.array[:, :] = 0.0 

960 cell_masked_image.variance.array[:, :] = np.inf 

961 elif self.config.do_interpolate_coadd: 

962 self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA") 

963 for noise_image in cell_noise_images: 

964 self.interpolate_coadd.run(noise_image, planeName="NO_DATA") 

965 # The variance must be positive; work around for DM-3201. 

966 varArray = cell_masked_image.variance.array 

967 with np.errstate(invalid="ignore"): 

968 varArray[:] = np.where(varArray > 0, varArray, np.inf) 

969 

970 afwImage.Mask.addMaskPlane("INEXACT_PSF") 

971 cell_masked_image.mask.array[ 

972 (cell_masked_image.mask.array & rejected) > 0 

973 ] |= cell_masked_image.mask.getPlaneBitMask("INEXACT_PSF") 

974 

975 image_planes = OwnedImagePlanes.from_masked_image( 

976 masked_image=cell_masked_image, 

977 mask_fractions=cell_maskfrac_image, 

978 noise_realizations=[noise_image.image for noise_image in cell_noise_images], 

979 ) 

980 identifiers = CellIdentifiers( 

981 cell=cellInfo.index, 

982 skymap=self.common.identifiers.skymap, 

983 tract=self.common.identifiers.tract, 

984 patch=self.common.identifiers.patch, 

985 band=self.common.identifiers.band, 

986 ) 

987 

988 singleCellCoadd = SingleCellCoadd( 

989 outer=image_planes, 

990 psf=psf_masked_image.image, 

991 inner_bbox=cellInfo.inner_bbox, 

992 inputs=observation_identifiers_gc[cellInfo.index], 

993 common=self.common, 

994 identifiers=identifiers, 

995 aperture_correction_map=ap_corr_map, 

996 ) 

997 # TODO: Attach transmission curve when they become available. 

998 cells.append(singleCellCoadd) 

999 

1000 if not cells: 

1001 raise NoWorkFound("No cells could be populated for the cell coadd.") 

1002 

1003 grid = self._construct_grid(skyInfo) 

1004 multipleCellCoadd = MultipleCellCoadd( 

1005 cells, 

1006 grid=grid, 

1007 outer_cell_size=cellInfo.outer_bbox.getDimensions(), 

1008 inner_bbox=None, 

1009 common=self.common, 

1010 psf_image_size=cells[0].psf_image.getDimensions(), 

1011 ) 

1012 

1013 return Struct( 

1014 multipleCellCoadd=multipleCellCoadd, 

1015 inputMap=inputMap, 

1016 ) 

1017 

1018 

class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Connections for ConvertMultipleCellCoaddToExposureTask.

    The single input is a cell-based coadd stored as a ``MultipleCellCoadd``.
    The single output is that coadd stitched into an ``ExposureF``.
    Both dataset-type names are built from the
    ``{inputCoaddName}Coadd{inputCoaddSuffix}`` template; the output name adds
    a ``_stitched`` suffix.
    """

    # Input: the cell-based coadd to be converted. The doc previously (and
    # incorrectly) described this input connection as an output.
    cellCoaddExposure = Input(
        doc="Input cell-based coadd, in MultipleCellCoadd format, to be stitched into an exposure",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    # Output: the contiguous exposure obtained by stitching the input cells.
    stitchedCoaddExposure = Output(
        doc="Output coadded exposure, produced by stitching the cells of the input cell-based coadd",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

1037 

1038 

class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for ConvertMultipleCellCoaddToExposureTask.

    The task has no tunable parameters. This class only binds the task to
    ConvertMultipleCellCoaddToExposureConnections.
    """

1045 

1046 

class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """Convert a cell-based coadd from `MultipleCellCoadd` to `ExposureF`.

    This afterburner task's ``run`` method stitches the cells of the input
    coadd into one contiguous image and returns it as an `Exposure`. The
    conversion is lossy: only the pixels inside each cell's inner bounding
    box are kept, and the values in the surrounding buffer region are
    discarded.

    Notes
    -----
    This task has no configurable parameters.
    """

    ConfigClass = ConvertMultipleCellCoaddToExposureConfig
    _DefaultName = "convertMultipleCellCoaddToExposure"

    def run(self, cellCoaddExposure):
        """Stitch a cell-based coadd into a single exposure.

        Parameters
        ----------
        cellCoaddExposure : `MultipleCellCoadd`
            Cell-based coadd to convert.

        Returns
        -------
        result : `Struct`
            Struct with one attribute, ``stitchedCoaddExposure``, holding the
            `Exposure` built from the stitched coadd.
        """
        # Stitch the cells into one image first, then convert that to an
        # Exposure.
        stitched = cellCoaddExposure.stitch()
        exposure = stitched.asExposure()
        return Struct(stitchedCoaddExposure=exposure)