Coverage for python / lsst / drp / tasks / assemble_cell_coadd.py: 16%

384 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:04 +0000

1# This file is part of drp_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Names exported by ``from lsst.drp.tasks.assemble_cell_coadd import *``.
__all__ = ("AssembleCellCoaddTask", "AssembleCellCoaddConfig", "ConvertMultipleCellCoaddToExposureTask")

29 

30import dataclasses 

31import itertools 

32 

33import numpy as np 

34 

35import lsst.afw.geom as afwGeom 

36import lsst.afw.image as afwImage 

37import lsst.afw.math as afwMath 

38import lsst.geom as geom 

39from lsst.afw.detection import InvalidPsfError 

40from lsst.afw.geom import SinglePolygonException, makeWcsPairTransform 

41from lsst.cell_coadds import ( 

42 CellIdentifiers, 

43 CoaddApCorrMapStacker, 

44 CoaddInputs, 

45 CoaddUnits, 

46 CommonComponents, 

47 GridContainer, 

48 MultipleCellCoadd, 

49 ObservationIdentifiers, 

50 OwnedImagePlanes, 

51 PatchIdentifiers, 

52 SingleCellCoadd, 

53 UniformGrid, 

54) 

55from lsst.daf.butler import DataCoordinate, DeferredDatasetHandle 

56from lsst.meas.algorithms import AccumulatorMeanStack 

57from lsst.pex.config import ConfigField, ConfigurableField, DictField, Field, ListField, RangeField 

58from lsst.pipe.base import ( 

59 InMemoryDatasetHandle, 

60 NoWorkFound, 

61 PipelineTask, 

62 PipelineTaskConfig, 

63 PipelineTaskConnections, 

64 Struct, 

65) 

66from lsst.pipe.base.connectionTypes import Input, Output 

67from lsst.pipe.tasks.coaddBase import makeSkyInfo, removeMaskPlanes, setRejectedMaskMapping 

68from lsst.pipe.tasks.healSparseMapping import HealSparseInputMapTask 

69from lsst.pipe.tasks.interpImage import InterpImageTask 

70from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask 

71from lsst.skymap import BaseSkyMap 

72 

73 

@dataclasses.dataclass
class WarpInputs:
    """Bundle of a warp handle together with its associated input handles.

    Only ``warp`` is required; the auxiliary handles default to `None`, and
    ``noise_warps`` defaults to a new empty list for every instance.
    """

    #: Handle for the warped exposure.
    warp: DeferredDatasetHandle | InMemoryDatasetHandle

    #: Handle for the masked fraction image, if one is available.
    masked_fraction: DeferredDatasetHandle | InMemoryDatasetHandle | None = None

    #: Handle for the CompareWarp artifact mask, if one is available.
    artifact_mask: DeferredDatasetHandle | InMemoryDatasetHandle | None = None

    #: Handles for the noise warps, one per noise realization.
    noise_warps: list[DeferredDatasetHandle | InMemoryDatasetHandle] = dataclasses.field(
        default_factory=list
    )

    @property
    def dataId(self) -> DataCoordinate:
        """Data ID of the warp held by this bundle.

        Returns
        -------
        data_id : `~lsst.daf.butler.DataCoordinate`
            The ``dataId`` of the ``warp`` handle.
        """
        warp_handle = self.warp
        return warp_handle.dataId

100 

101 

class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Butler connections for `AssembleCellCoaddTask`.

    Connections that the configuration does not need are removed in
    ``__init__``, and one ``noise{n}_warps`` input connection is added for
    each of ``config.num_noise_realizations`` noise realizations.
    """

    inputWarps = Input(
        doc="Input warps",
        name="{inputWarpName}Coadd_directWarp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    maskedFractionWarps = Input(
        doc="Mask fraction warps",
        name="{inputWarpName}Coadd_directWarp_maskedFraction",
        storageClass="ImageF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    artifactMasks = Input(
        doc="Artifact masks to be applied to the input warps",
        name="compare_warp_artifact_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    visitSummaryList = Input(
        doc="Input visit-summary catalogs with updated calibration objects. Mainly used for coadd weights.",
        name="finalVisitSummary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )

    skyMap = Input(
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    multipleCellCoadd = Output(
        doc="Output multiple cell coadd",
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    inputMap = Output(
        doc="Output healsparse map of input images",
        name="{inputWarpName}Coadd_inputMap",
        storageClass="HealSparseMap",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if not config:
            return

        # Gather the connections made unnecessary by the configuration, then
        # remove them all in the same order as they were examined.
        unused = []
        if config.do_calculate_weight_from_warp:
            unused.append("visitSummaryList")
        if not config.do_use_artifact_mask:
            unused.append("artifactMasks")
        if not config.do_input_map:
            unused.append("inputMap")
        for connection_name in unused:
            delattr(self, connection_name)

        # The number of noise-warp inputs is only known from the config, so
        # these connections are created dynamically here.
        for index in range(config.num_noise_realizations):
            setattr(
                self,
                f"noise{index}_warps",
                Input(
                    doc="Input noise warps",
                    name=f"direct_warp_noise{index}",
                    storageClass="MaskedImageF",
                    dimensions=("tract", "patch", "skymap", "visit", "instrument"),
                    deferLoad=True,
                    multiple=True,
                ),
            )

191 

192 

class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration for `AssembleCellCoaddTask`."""

    # Interpolation of NO_DATA pixels on the cell coadds.
    do_interpolate_coadd = Field[bool](doc="Interpolate over pixels with NO_DATA mask set?", default=True)
    interpolate_coadd = ConfigurableField(
        target=InterpImageTask,
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
    )

    # Deprecated zero-point scaling of the input warps.
    do_scale_zero_point = Field[bool](
        doc=(
            "Scale warps to a common zero point? "
            "This is not needed if they have absolute flux calibration."
        ),
        default=False,
        deprecated=(
            "Now that visits are scaled to nJy it is no longer necessary or "
            "recommended to scale the zero point, so this will be removed "
            "after v29."
        ),
    )
    scale_zero_point = ConfigurableField(
        target=ScaleZeroPointTask,
        doc="Task to scale warps to a common zero point",
        deprecated=(
            "Now that visits are scaled to nJy it is no longer necessary or "
            "recommended to scale the zero point, so this will be removed "
            "after v29."
        ),
    )

    # Source of the per-input weights and of the mask planes.
    do_calculate_weight_from_warp = Field[bool](
        doc=(
            "Calculate coadd weight from the input warp? "
            "Otherwise, the weight is obtained from the visitSummaryList connection. "
            "This is meant as a fallback when run outside the pipeline."
        ),
        default=False,
    )
    do_use_artifact_mask = Field[bool](
        doc="Substitute the mask planes input warp with an alternative artifact mask?",
        default=True,
    )
    do_coadd_inverse_aperture_corrections = Field[bool](
        doc=(
            "Coadd the inverse aperture corrections for each cell? "
            "This is formally the more accurate way but may be turned off for parity with deepCoadd."
        ),
        default=False,
    )

    # 1.0 admits only inputs that cover the whole cell (ideal, edge-free
    # cells); 0.0 gives deep_coadd-style coadds. The value has to be at least
    # 0.5 to ensure that an input overlaps the cell center; inputs with an
    # overlap fraction below 0.25 will definitely not overlap the cell center.
    min_overlap_fraction = RangeField[float](
        doc=(
            "The minimum overlap fraction required for a single (visit, detector) "
            "input to be included in a cell."
        ),
        default=1.0,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=True,
    )

    # Mask-plane handling.
    bad_mask_planes = ListField[str](
        doc="Mask planes that count towards the masked fraction within a cell.",
        default=("BAD", "NO_DATA", "SAT", "CLIPPED"),
    )
    remove_mask_planes = ListField[str](
        doc="Mask planes to remove before coadding",
        default=["EDGE", "NOT_DEBLENDED"],
    )

    # Stacking statistics.
    calc_error_from_input_variance = Field[bool](
        doc=(
            "Calculate coadd variance from input variance by stacking statistic. "
            "Passed to AccumulatorMeanStack."
        ),
        default=True,
    )
    mask_propagation_thresholds = DictField[str, float](
        doc=(
            "Threshold (in fractional weight) of rejection at which we "
            "propagate a mask plane to the coadd; that is, we set the mask "
            "bit on the coadd if the fraction the rejected frames "
            "would have contributed exceeds this value."
        ),
        default={"SAT": 0.1},
    )
    max_maskfrac = RangeField[float](
        doc=(
            "Maximum fraction of masked pixels in a cell. "
            "This is currently just a placeholder and is not used now"
        ),
        default=0.99,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=False,
    )

    # Noise realizations; also controls how many noise-warp connections exist.
    num_noise_realizations = Field[int](
        default=0,
        doc=(
            "Number of noise planes to include in the coadd. "
            "This should not exceed the corresponding config parameter "
            "specified in `MakeDirectWarpConfig`. "
        ),
        check=lambda count: count >= 0,
    )

    # PSF handling.
    psf_warper = ConfigField(
        doc=(
            "Configuration for the warper that warps the PSFs. "
            "It must have the same configuration used to warp the images."
        ),
        dtype=afwMath.Warper.ConfigClass,
    )
    psf_dimensions = Field[int](
        default=35,
        doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
        check=lambda size: size > 0 and size % 2 == 1,
    )

    # Artifact-mask requirement and input map.
    require_artifact_mask = Field[bool](
        default=True,
        doc=(
            "Require presence of artifact mask for each warp? "
            "Use true if using artifact rejection outputs from CompareWarpTask"
        ),
    )
    do_input_map = Field[bool](
        default=False,
        doc="Create a bitwise map of coadd inputs.",
    )
    input_mapper = ConfigurableField(
        target=HealSparseInputMapTask,
        doc="Input map creation subtask.",
    )

304 

305 

306class AssembleCellCoaddTask(PipelineTask): 

307 """Assemble a cell-based coadded image from a set of warps. 

308 

309 This task reads in the warp one at a time, and accumulates it in all the 

310 cells that it completely overlaps with. This is the optimal I/O pattern but 

311 this also implies that it is not possible to build one or only a few cells. 

312 

313 Each cell coadds is guaranteed to have a well-defined PSF. This is done by 

314 1) excluding warps that only partially overlap a cell from that cell coadd; 

315 2) interpolating bad pixels in the warps rather than excluding them; 

316 3) by computing the coadd as a weighted mean of the warps without clipping; 

317 4) by computing the coadd PSF as the weighted mean of the PSF of the warps 

318 with the same weights. 

319 

320 The cells are (and must be) defined in the skymap, and cannot be configured 

321 or redefined here. The cells are assumed to be small enough that the PSF is 

322 assumed to be spatially constant within a cell. 

323 

324 Raises 

325 ------ 

326 NoWorkFound 

327 Raised if no input warps are provided, or no cells could be populated. 

328 RuntimeError 

329 Raised if the skymap is not cell-based. 

330 

331 Notes 

332 ----- 

333 This is not yet a part of the standard DRP pipeline. As such, the Task and 

334 especially its Config and Connections are experimental and subject to 

335 change any time without a formal RFC or standard deprecation procedures 

336 until it is included in the DRP pipeline. 

337 """ 

338 

339 ConfigClass = AssembleCellCoaddConfig 

340 _DefaultName = "assembleCellCoadd" 

341 

342 def __init__(self, *args, **kwargs): 

343 super().__init__(*args, **kwargs) 

344 if self.config.do_interpolate_coadd: 

345 self.makeSubtask("interpolate_coadd") 

346 if self.config.do_scale_zero_point: 

347 self.makeSubtask("scale_zero_point") 

348 if self.config.do_input_map: 

349 self.makeSubtask("input_mapper") 

350 

351 self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper) 

352 if (warping_kernel_name := self.config.psf_warper.warpingKernelName.lower()).startswith("lanczos"): 

353 psf_padding = 2 * int(warping_kernel_name.lstrip("lanczos")) - 1 

354 self.log.debug( 

355 "Padding PSF image by %d pixels since the warping kernel is %s.", 

356 psf_padding, 

357 self.config.psf_warper.warpingKernelName, 

358 ) 

359 else: 

360 psf_padding = 10 

361 self.log.info( 

362 "Padding PSF image by %d pixels since the warping kernel is not Lanczos.", 

363 psf_padding, 

364 ) 

365 self.psf_padding = psf_padding 

366 

367 def runQuantum(self, butlerQC, inputRefs, outputRefs): 

368 # Docstring inherited. 

369 if not inputRefs.inputWarps: 

370 raise NoWorkFound("No input warps provided for co-addition") 

371 self.log.info("Found %d input warps", len(inputRefs.inputWarps)) 

372 

373 # Construct skyInfo expected by run 

374 # Do not remove skyMap from inputData in case _makeSupplementaryData 

375 # needs it 

376 skyMap = butlerQC.get(inputRefs.skyMap) 

377 

378 if not skyMap.config.tractBuilder.name == "cells": 

379 raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.") 

380 

381 outputDataId = butlerQC.quantum.dataId 

382 

383 skyInfo = makeSkyInfo(skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"]) 

384 visitSummaryList = butlerQC.get(getattr(inputRefs, "visitSummaryList", [])) 

385 

386 units = CoaddUnits.legacy if self.config.do_scale_zero_point else CoaddUnits.nJy 

387 self.common = CommonComponents( 

388 units=units, 

389 wcs=skyInfo.patchInfo.wcs, 

390 band=outputDataId.get("band", None), 

391 identifiers=PatchIdentifiers.from_data_id(outputDataId), 

392 ) 

393 

394 inputs: dict[DataCoordinate, WarpInputs] = {} 

395 for handle in butlerQC.get(inputRefs.inputWarps): 

396 inputs[handle.dataId] = WarpInputs(warp=handle, noise_warps=[]) 

397 

398 for ref in getattr(inputRefs, "artifactMasks", []): 

399 inputs[ref.dataId].artifact_mask = butlerQC.get(ref) 

400 for ref in getattr(inputRefs, "maskedFractionWarps", []): 

401 inputs[ref.dataId].masked_fraction = butlerQC.get(ref) 

402 for n in range(self.config.num_noise_realizations): 

403 for ref in getattr(inputRefs, f"noise{n}_warps"): 

404 inputs[ref.dataId].noise_warps.append(butlerQC.get(ref)) 

405 

406 returnStruct = self.run(inputs=inputs, skyInfo=skyInfo, visitSummaryList=visitSummaryList) 

407 butlerQC.put(returnStruct, outputRefs) 

408 return returnStruct 

409 

410 @staticmethod 

411 def _compute_weight(maskedImage, statsCtrl): 

412 """Compute a weight for a masked image. 

413 

414 Parameters 

415 ---------- 

416 maskedImage : `~lsst.afw.image.MaskedImage` 

417 The masked image to compute the weight. 

418 statsCtrl : `~lsst.afw.math.StatisticsControl` 

419 A control (config-like) object for StatisticsStack. 

420 

421 Returns 

422 ------- 

423 weight : `float` 

424 Inverse of the clipped mean variance of the masked image. 

425 """ 

426 statObj = afwMath.makeStatistics( 

427 maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl 

428 ) 

429 meanVar, _ = statObj.getResult(afwMath.MEANCLIP) 

430 weight = 1.0 / float(meanVar) 

431 return weight 

432 

433 @staticmethod 

434 def _construct_grid(skyInfo): 

435 """Construct a UniformGrid object from a SkyInfo struct. 

436 

437 Parameters 

438 ---------- 

439 skyInfo : `~lsst.pipe.base.Struct` 

440 A Struct object 

441 

442 Returns 

443 ------- 

444 grid : `~lsst.cell_coadds.UniformGrid` 

445 A UniformGrid object. 

446 """ 

447 padding = skyInfo.patchInfo.getCellBorder() 

448 grid_bbox = skyInfo.patchInfo.outer_bbox.erodedBy(padding) 

449 grid = UniformGrid.from_bbox_cell_size( 

450 grid_bbox, 

451 skyInfo.patchInfo.getCellInnerDimensions(), 

452 padding=padding, 

453 ) 

454 return grid 

455 

456 def _construct_grid_container(self, skyInfo, statsCtrl): 

457 """Construct a grid of AccumulatorMeanStack instances. 

458 

459 Parameters 

460 ---------- 

461 skyInfo : `~lsst.pipe.base.Struct` 

462 A Struct object 

463 statsCtrl : `~lsst.afw.math.StatisticsControl` 

464 A control (config-like) object for StatisticsStack. 

465 

466 Returns 

467 ------- 

468 gc : `~lsst.cell_coadds.GridContainer` 

469 A GridContainer object container one AccumulatorMeanStack per cell. 

470 """ 

471 grid = self._construct_grid(skyInfo) 

472 

473 maskMap = setRejectedMaskMapping(statsCtrl) 

474 self.log.debug("Obtained maskMap = %s for %s", maskMap, skyInfo.patchInfo) 

475 thresholdDict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl) 

476 

477 # Initialize the grid container with AccumulatorMeanStacks 

478 gc = GridContainer[AccumulatorMeanStack](grid.shape) 

479 for cellInfo in skyInfo.patchInfo: 

480 stacker = AccumulatorMeanStack( 

481 # The shape is for the numpy arrays, hence transposed. 

482 shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width), 

483 bit_mask_value=statsCtrl.getAndMask(), 

484 mask_threshold_dict=thresholdDict, 

485 calc_error_from_input_variance=self.config.calc_error_from_input_variance, 

486 compute_n_image=False, 

487 mask_map=maskMap, 

488 no_good_pixels_mask=statsCtrl.getNoGoodPixelsMask(), 

489 ) 

490 gc[cellInfo.index] = stacker 

491 

492 return gc 

493 

494 def _construct_stats_control(self): 

495 """Construct a StatisticsControl object for coadd. 

496 

497 Unlike AssembleCoaddTask or CompareWarpAssembleCoaddTask, there is 

498 very little to be configured apart from setting the mask planes and 

499 optionally mask propagation thresholds. 

500 

501 Returns 

502 ------- 

503 statsCtrl : `~lsst.afw.math.StatisticsControl` 

504 A control object for StatisticsStack. 

505 """ 

506 statsCtrl = afwMath.StatisticsControl() 

507 # Hardcode the numIter parameter to the default config value set in 

508 # CompareWarpAssembleCoaddTask to get consistent weights. This is NOT 

509 # exposed as a config parameter, since this is only meant to be a 

510 # fallback option that is not recommended for production. 

511 statsCtrl.setNumIter(2) 

512 statsCtrl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes)) 

513 statsCtrl.setNanSafe(True) 

514 for plane, threshold in self.config.mask_propagation_thresholds.items(): 

515 bit = afwImage.Mask.getMaskPlane(plane) 

516 statsCtrl.setMaskPropagationThreshold(bit, threshold) 

517 return statsCtrl 

518 

519 def _construct_ap_corr_grid_container(self, skyInfo): 

520 """Construct a grid of CoaddApCorrMapStacker instances. 

521 

522 Parameters 

523 ---------- 

524 skyInfo : `~lsst.pipe.base.Struct` 

525 A Struct object 

526 

527 Returns 

528 ------- 

529 gc : `~lsst.cell_coadds.GridContainer` 

530 A GridContainer object container one CoaddApCorrMapStacker per 

531 cell. 

532 """ 

533 grid = self._construct_grid(skyInfo) 

534 

535 # Initialize the grid container with CoaddApCorrMapStacker. 

536 gc = GridContainer[CoaddApCorrMapStacker](grid.shape) 

537 for cellInfo in skyInfo.patchInfo: 

538 stacker = CoaddApCorrMapStacker( 

539 evaluation_point=cellInfo.inner_bbox.getCenter(), 

540 do_coadd_inverse_ap_corr=self.config.do_coadd_inverse_aperture_corrections, 

541 ) 

542 gc[cellInfo.index] = stacker 

543 

544 return gc 

545 

546 def run( 

547 self, 

548 *, 

549 inputs: dict[DataCoordinate, WarpInputs], 

550 skyInfo, 

551 visitSummaryList: list | None = None, 

552 ): 

553 for mask_plane in self.config.bad_mask_planes: 

554 afwImage.Mask.addMaskPlane(mask_plane) 

555 for mask_plane in self.config.mask_propagation_thresholds: 

556 afwImage.Mask.addMaskPlane(mask_plane) 

557 

558 statsCtrl = self._construct_stats_control() 

559 

560 warp_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl) 

561 maskfrac_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl) 

562 noise_stacker_gc_list = [ 

563 self._construct_grid_container(skyInfo, statsCtrl) 

564 for n in range(self.config.num_noise_realizations) 

565 ] 

566 psf_stacker_gc = GridContainer[AccumulatorMeanStack](warp_stacker_gc.shape) 

567 psf_bbox_gc = GridContainer[geom.Box2I](warp_stacker_gc.shape) 

568 ap_corr_stacker_gc = self._construct_ap_corr_grid_container(skyInfo) 

569 

570 # Make a container to hold the cell centers in sky coordinates now, 

571 # so we don't have to recompute them for each warp 

572 # (they share a common WCS). These are needed to find the various 

573 # warp + detector combinations that contributed to each cell, and later 

574 # get the corresponding PSFs as well. 

575 cell_centers_sky = GridContainer[geom.SpherePoint](warp_stacker_gc.shape) 

576 # Make a container to hold the observation identifiers for each cell. 

577 observation_identifiers_gc = GridContainer[dict](warp_stacker_gc.shape) 

578 

579 if self.config.do_input_map: 

580 # We need to know all the visit + detector pairs in the inputs. 

581 warp_input_list = [warp_ref.warp.get(component="coaddInputs") for warp_ref in inputs.values()] 

582 visit_detectors = [] 

583 for warp_input in warp_input_list: 

584 for row in warp_input.ccds: 

585 visit_detectors.append((int(row["visit"]), int(row["ccd"]))) 

586 

587 self.input_mapper.initialize_cell_input_map( 

588 skyInfo.patchInfo.getOuterBBox(), 

589 skyInfo.patchInfo.wcs, 

590 visit_detectors, 

591 ) 

592 

593 # Populate them. 

594 for cellInfo in skyInfo.patchInfo: 

595 # Make a list to hold the observation identifiers for each cell. 

596 observation_identifiers_gc[cellInfo.index] = {} 

597 cell_center_pixel = geom.Point2D(geom.Point2I(cellInfo.inner_bbox.getCenter())) 

598 cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cell_center_pixel) 

599 psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox( 

600 cell_center_pixel, 

601 geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions), 

602 ) 

603 psf_stacker_gc[cellInfo.index] = AccumulatorMeanStack( 

604 # The shape is for the numpy arrays, hence transposed. 

605 shape=(self.config.psf_dimensions, self.config.psf_dimensions), 

606 bit_mask_value=0, 

607 calc_error_from_input_variance=self.config.calc_error_from_input_variance, 

608 compute_n_image=False, 

609 ) 

610 

611 if self.config.do_input_map: 

612 self.input_mapper.build_cell_input_map(cellInfo) 

613 

614 # visit_summary do not have (tract, patch, band, skymap) dimensions. 

615 if not visitSummaryList: 

616 visitSummaryList = [] 

617 visitSummaryRefDict = { 

618 visitSummaryRef.dataId["visit"]: visitSummaryRef for visitSummaryRef in visitSummaryList 

619 } 

620 

621 # Keep track of the polygons corresponding to each (visit, detector). 

622 visit_polygons: dict[ObservationIdentifiers, afwGeom.Polygon] = {} 

623 

624 # Read in one warp at a time, and accumulate it in all the cells that 

625 # it completely overlaps. 

626 for _, warp_input in inputs.items(): 

627 # warps that have been excluded from CompareWarp via visit 

628 # selection from SelectVisitsTasks will not have artifact masks. 

629 # Exclude them from the cell coadds too. 

630 if self.config.require_artifact_mask and warp_input.artifact_mask is None: 

631 self.log.info( 

632 "Excluding warp %s from cell coadds because it has no artifact mask", 

633 warp_input.dataId["visit"], 

634 ) 

635 continue 

636 

637 warp = warp_input.warp.get(parameters={"bbox": skyInfo.bbox}) 

638 masked_fraction_image = ( 

639 warp_input.masked_fraction.get(parameters={"bbox": skyInfo.bbox}) 

640 if warp_input.masked_fraction 

641 else None 

642 ) 

643 

644 # Pre-process the warp before coadding. 

645 # TODO: Can we get these mask names from artifactMask? 

646 warp.mask.addMaskPlane("CLIPPED") 

647 warp.mask.addMaskPlane("REJECTED") 

648 warp.mask.addMaskPlane("SENSOR_EDGE") 

649 warp.mask.addMaskPlane("INEXACT_PSF") 

650 

651 if artifact_mask_ref := warp_input.artifact_mask: 

652 # Apply the artifact mask to the warp. 

653 artifact_mask = artifact_mask_ref.get() 

654 assert ( 

655 warp.mask.getMaskPlaneDict() == artifact_mask.getMaskPlaneDict() 

656 ), "Mask dicts do not agree." 

657 warp.mask.array = artifact_mask.array 

658 del artifact_mask 

659 

660 if self.config.do_scale_zero_point: 

661 # Each Warp that goes into a coadd will typically have an 

662 # independent photometric zero-point. Therefore, we must scale 

663 # each Warp to set it to a common photometric zeropoint. 

664 imageScaler = self.scale_zero_point.run(exposure=warp, dataRef=warp_input.warp).imageScaler 

665 zero_point_scale_factor = imageScaler.scale 

666 self.log.debug( 

667 "Scaled the warp %s by %f to match zero points", 

668 warp_input.dataId, 

669 zero_point_scale_factor, 

670 ) 

671 else: 

672 zero_point_scale_factor = 1.0 

673 if "BUNIT" not in warp.metadata: 

674 raise ValueError(f"Warp {warp_input.dataId} has no BUNIT metadata") 

675 if warp.metadata["BUNIT"] != "nJy": 

676 raise ValueError( 

677 f"Warp {warp_input.dataId} has BUNIT {warp.metadata['BUNIT']}, expected nJy" 

678 ) 

679 

680 # Only try to remove maks planes that have been registered. 

681 to_remove = [] 

682 for plane in self.config.remove_mask_planes: 

683 if plane in warp.mask.getMaskPlaneDict(): 

684 to_remove.append(plane) 

685 removeMaskPlanes(warp.mask, to_remove, self.log) 

686 # Instead of using self.config.bad_mask_planes, we explicitly 

687 # ask statsCtrl which pixels are going to be ignored/rejected. 

688 rejected = afwImage.Mask.getPlaneBitMask( 

689 ["CLIPPED", "REJECTED"] + afwImage.Mask.interpret(statsCtrl.getAndMask()).split(",") 

690 ) 

691 

692 # Compute the weight for each CCD in the warp from the visitSummary 

693 # or from the warp itself, if not provided. Computing the weight 

694 # from the warp is not recommended, and in that case we compute one 

695 # weight per warp and not bother with per-detector weights. 

696 full_ccd_table = warp.getInfo().getCoaddInputs().ccds 

697 weights: dict[int, float] = dict.fromkeys( 

698 full_ccd_table["ccd"].tolist(), 

699 0.0, 

700 ) # Mapping from detector to weight. 

701 

702 if visitSummaryRef := visitSummaryRefDict.get(warp_input.dataId["visit"]): 

703 visitSummary = visitSummaryRef.get() 

704 for detector in full_ccd_table["ccd"].tolist(): 

705 visitSummaryRow = visitSummary.find(detector) 

706 mean_variance = visitSummaryRow["meanVar"] 

707 mean_variance *= zero_point_scale_factor**2 

708 if warp.metadata.get("BUNIT", None) == "nJy": 

709 mean_variance *= visitSummaryRow.photoCalib.getCalibrationMean() ** 2 

710 weights[detector] = 1.0 / mean_variance 

711 del visitSummary 

712 else: 

713 self.log.debug("No visit summary found for %s; using warp-based weights", warp_input.dataId) 

714 weight = self._compute_weight(warp, statsCtrl) 

715 if not np.isfinite(weight): 

716 self.log.warn("Non-finite weight for %s: skipping", warp_input.dataId) 

717 continue 

718 

719 for detector in weights: 

720 weights[detector] = weight 

721 

722 noise_warps = [ref.get(parameters={"bbox": skyInfo.bbox}) for ref in warp_input.noise_warps] 

723 

724 # Create an image where each pixel value corresponds to the 

725 # detector ID that pixel comes from. 

726 detector_map = afwImage.ImageI(bbox=warp.getBBox(), initialValue=-1) 

727 for row in full_ccd_table: 

728 transform = makeWcsPairTransform(row.wcs, warp.wcs) 

729 if (src_polygon := row.validPolygon) is None: 

730 src_polygon = afwGeom.Polygon(geom.Box2D(row.getBBox())) 

731 try: 

732 dest_polygon = src_polygon.transform(transform).intersectionSingle( 

733 geom.Box2D(warp.getBBox()) 

734 ) 

735 except SinglePolygonException: 

736 continue 

737 

738 observation_identifier = ObservationIdentifiers.from_data_id( 

739 warp_input.dataId, 

740 backup_detector=row["ccd"], 

741 ) 

742 visit_polygons[observation_identifier] = dest_polygon 

743 

744 detector_map_slice = dest_polygon.createImage(detector_map.getBBox()).array > 0 

745 if not (detector_map.array[detector_map_slice] < 0).all(): 

746 self.log.warning("Multiple detectors from visit %s are overlapping", warp_input.dataId) 

747 detector_map.array[detector_map_slice] = row["ccd"] 

748 

749 if (detector_map.array < 0).all(): 

750 self.log.warning("Unable to split the warp %s into single-detector warps.", warp_input.dataId) 

751 detector_map.array[:, :] = 0 

752 

753 for cellInfo, ccd_row in itertools.product(skyInfo.patchInfo, full_ccd_table): 

754 bbox = cellInfo.outer_bbox 

755 inner_bbox = cellInfo.inner_bbox 

756 

757 overlap_fraction = (detector_map[inner_bbox].array == ccd_row["ccd"]).mean() 

758 assert -1e-4 < overlap_fraction < 1.0001, "Overlap fraction is not within [0, 1]." 

759 if (overlap_fraction < self.config.min_overlap_fraction) or (overlap_fraction <= 0.0): 

760 self.log.debug( 

761 "Skipping %s in cell %s because it had only %.3f < %.3f fractional overlap.", 

762 warp_input.dataId, 

763 cellInfo.index, 

764 overlap_fraction, 

765 self.config.min_overlap_fraction, 

766 ) 

767 continue 

768 

769 weight = weights[int(ccd_row["ccd"])] 

770 if not np.isfinite(weight): 

771 self.log.warn( 

772 "Non-finite weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index 

773 ) 

774 continue 

775 

776 if weight == 0: 

777 self.log.info( 

778 "Zero weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index 

779 ) 

780 continue 

781 

782 # Decide if a deep copy is necessary to apply the single 

783 # detector cuts since it involves modifying the image in-place. 

784 # If within the inner cell, there are three or more different 

785 # values that detector map takes, then there are definitely 

786 # multiple detectors (one for chip gaps, two for two detectors) 

787 deep_copy = len(set(detector_map[inner_bbox].array.ravel())) >= 3 

788 if deep_copy: 

789 single_detector_mask_array = detector_map[bbox].array != ccd_row["ccd"] 

790 

791 mi = afwImage.MaskedImageF(warp[bbox].maskedImage, deep=deep_copy) 

792 if deep_copy: 

793 mi.image.array[single_detector_mask_array] = 0.0 

794 mi.variance.array[single_detector_mask_array] = np.inf 

795 nodata_or_mask = (single_detector_mask_array) * afwImage.Mask.getPlaneBitMask("NO_DATA") 

796 mi.mask[bbox].array |= nodata_or_mask 

797 warp_stacker_gc[cellInfo.index].add_masked_image(mi, weight=weight) 

798 

799 if masked_fraction_image: 

800 mi = afwImage.ImageF(masked_fraction_image[bbox], deep=deep_copy) 

801 if deep_copy: 

802 mi.array[single_detector_mask_array] = 0.0 

803 maskfrac_stacker_gc[cellInfo.index].add_image(masked_fraction_image[bbox], weight=weight) 

804 

805 for n in range(self.config.num_noise_realizations): 

806 mi = afwImage.MaskedImageF(noise_warps[n][bbox], deep=deep_copy) 

807 if deep_copy: 

808 mi.image.array[single_detector_mask_array] = 0.0 

809 mi.variance.array[single_detector_mask_array] = np.inf 

810 mi.mask[bbox].array |= nodata_or_mask 

811 noise_stacker_gc_list[n][cellInfo.index].add_masked_image(mi, weight=weight) 

812 

813 # Set the defaults for PSF shape quantities. 

814 psf_shape = afwGeom.Quadrupole() 

815 psf_shape_flag = True 

816 psf_eval_point = None 

817 try: 

818 # The `if` branch is buggy. `dest_polygon` is technically 

819 # out of scope, but Python does not raise an error. 

820 # TODO: Fix this properly in DM-53479, but sweep it under 

821 # the rug for now. 

822 if overlap_fraction < 0.5: 

823 psf_eval_point = dest_polygon.intersectionSingle( 

824 geom.Box2D(inner_bbox) 

825 ).calculateCenter() 

826 else: 

827 psf_eval_point = geom.Point2D(geom.Point2I(inner_bbox.getCenter())) 

828 psf_shape = warp.psf.computeShape(psf_eval_point) 

829 psf_shape_flag = False 

830 except SinglePolygonException: 

831 self.log.info( 

832 "Unable to find the overlapping polygon between %d detector in %s and cell %s", 

833 ccd_row["ccd"], 

834 warp_input.dataId, 

835 cellInfo.index, 

836 ) 

837 except InvalidPsfError: 

838 self.log.info( 

839 "Unable to compute PSF shape from %d detector in %s at %s", 

840 ccd_row["ccd"], 

841 warp_input.dataId, 

842 psf_eval_point, 

843 ) 

844 

845 overlaps_center = detector_map[geom.Point2I(bbox.getCenter())] == ccd_row["ccd"] 

846 

847 observation_identifier = ObservationIdentifiers.from_data_id( 

848 warp_input.dataId, 

849 backup_detector=int(ccd_row["ccd"]), 

850 ) 

851 observation_identifiers_gc[cellInfo.index][observation_identifier] = CoaddInputs( 

852 overlaps_center=overlaps_center, 

853 overlap_fraction=overlap_fraction, 

854 weight=weight, 

855 psf_shape=psf_shape, 

856 psf_shape_flag=psf_shape_flag, 

857 ) 

858 if overlaps_center is False: 

859 self.log.debug( 

860 "%s does not overlap with the center of the cell %s", 

861 warp_input.dataId, 

862 cellInfo.index, 

863 ) 

864 continue 

865 

866 # Everything below this has to do with the center of the cell 

867 calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index]) 

868 undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point) 

869 

870 assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox( 

871 calexp_point, 

872 undistorted_psf_im.getDimensions(), 

873 ), "PSF image does not share the coordinates of the 'calexp'" 

874 

875 # Convert the PSF image from Image to MaskedImage and 

876 # zero-pad the image. 

877 undistorted_psf_bbox = undistorted_psf_im.getBBox() 

878 undistorted_psf_maskedImage = afwImage.MaskedImageD( 

879 undistorted_psf_bbox.dilatedBy(self.psf_padding) 

880 ) 

881 undistorted_psf_maskedImage.image[undistorted_psf_bbox].array[:, :] = undistorted_psf_im.array 

882 # TODO: In DM-43585, use the variance plane value from noise. 

883 undistorted_psf_maskedImage.variance += 1.0 # Set variance to 1 

884 

885 warped_psf_maskedImage = self.psf_warper.warpImage( 

886 destWcs=skyInfo.wcs, 

887 srcImage=undistorted_psf_maskedImage, 

888 srcWcs=ccd_row.getWcs(), 

889 destBBox=psf_bbox_gc[cellInfo.index], 

890 ) 

891 

892 # There may be NaNs in the PSF image. Set them to 0.0 

893 warped_psf_maskedImage.variance.array[np.isnan(warped_psf_maskedImage.image.array)] = 1.0 

894 warped_psf_maskedImage.image.array[np.isnan(warped_psf_maskedImage.image.array)] = 0.0 

895 

896 psf_stacker = psf_stacker_gc[cellInfo.index] 

897 psf_stacker.add_masked_image(warped_psf_maskedImage, weight=weight) 

898 

899 if not (0.995 < (psf_normalization := warped_psf_maskedImage.image.array.sum()) < 1.005): 

900 self.log.warning( 

901 "PSF image for %s in %s is not normalized to 1.0, but instead %f", 

902 warp_input.dataId, 

903 cellInfo.index, 

904 psf_normalization, 

905 ) 

906 

907 if (ap_corr_map := warp.getInfo().getApCorrMap()) is not None: 

908 ap_corr_stacker_gc[cellInfo.index].add(ap_corr_map, weight=weight) 

909 

910 if self.config.do_input_map: 

911 self.input_mapper.add_warp_to_cell_input_map( 

912 ccd_row, 

913 weight, 

914 cellInfo, 

915 ) 

916 

917 del warp 

918 

919 if self.config.do_input_map: 

920 inputMap = self.input_mapper.cell_input_map 

921 else: 

922 inputMap = None 

923 

924 # Update common with the visit polygons. 

925 self.common = dataclasses.replace( 

926 self.common, 

927 visit_polygons=visit_polygons, 

928 ) 

929 

930 cells: list[SingleCellCoadd] = [] 

931 for cellInfo in skyInfo.patchInfo: 

932 if len(observation_identifiers_gc[cellInfo.index]) == 0: 

933 self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index) 

934 continue 

935 

936 cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox) 

937 cell_maskfrac_image = afwImage.ImageF(cellInfo.outer_bbox) 

938 cell_noise_images = [ 

939 afwImage.MaskedImageF(cellInfo.outer_bbox) for n in range(self.config.num_noise_realizations) 

940 ] 

941 psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index]) 

942 

943 warp_stacker_gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image) 

944 maskfrac_stacker_gc[cellInfo.index].fill_stacked_image(cell_maskfrac_image) 

945 for n in range(self.config.num_noise_realizations): 

946 noise_stacker_gc_list[n][cellInfo.index].fill_stacked_masked_image(cell_noise_images[n]) 

947 psf_stacker_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image) 

948 

949 if ap_corr_stacker_gc[cellInfo.index].ap_corr_names: 

950 ap_corr_map = ap_corr_stacker_gc[cellInfo.index].final_ap_corr_map 

951 else: 

952 ap_corr_map = None 

953 

954 # Post-process the coadd before converting to new data structures. 

955 if np.isnan(cell_masked_image.image.array).all(): 

956 cell_masked_image.image.array[:, :] = 0.0 

957 cell_masked_image.variance.array[:, :] = np.inf 

958 elif self.config.do_interpolate_coadd: 

959 self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA") 

960 for noise_image in cell_noise_images: 

961 self.interpolate_coadd.run(noise_image, planeName="NO_DATA") 

962 # The variance must be positive; work around for DM-3201. 

963 varArray = cell_masked_image.variance.array 

964 with np.errstate(invalid="ignore"): 

965 varArray[:] = np.where(varArray > 0, varArray, np.inf) 

966 

967 afwImage.Mask.addMaskPlane("INEXACT_PSF") 

968 cell_masked_image.mask.array[ 

969 (cell_masked_image.mask.array & rejected) > 0 

970 ] |= cell_masked_image.mask.getPlaneBitMask("INEXACT_PSF") 

971 

972 image_planes = OwnedImagePlanes.from_masked_image( 

973 masked_image=cell_masked_image, 

974 mask_fractions=cell_maskfrac_image, 

975 noise_realizations=[noise_image.image for noise_image in cell_noise_images], 

976 ) 

977 identifiers = CellIdentifiers( 

978 cell=cellInfo.index, 

979 skymap=self.common.identifiers.skymap, 

980 tract=self.common.identifiers.tract, 

981 patch=self.common.identifiers.patch, 

982 band=self.common.identifiers.band, 

983 ) 

984 

985 singleCellCoadd = SingleCellCoadd( 

986 outer=image_planes, 

987 psf=psf_masked_image.image, 

988 inner_bbox=cellInfo.inner_bbox, 

989 inputs=observation_identifiers_gc[cellInfo.index], 

990 common=self.common, 

991 identifiers=identifiers, 

992 aperture_correction_map=ap_corr_map, 

993 ) 

994 # TODO: Attach transmission curve when they become available. 

995 cells.append(singleCellCoadd) 

996 

997 if not cells: 

998 raise NoWorkFound("No cells could be populated for the cell coadd.") 

999 

1000 grid = self._construct_grid(skyInfo) 

1001 multipleCellCoadd = MultipleCellCoadd( 

1002 cells, 

1003 grid=grid, 

1004 outer_cell_size=cellInfo.outer_bbox.getDimensions(), 

1005 inner_bbox=None, 

1006 common=self.common, 

1007 psf_image_size=cells[0].psf_image.getDimensions(), 

1008 ) 

1009 

1010 return Struct( 

1011 multipleCellCoadd=multipleCellCoadd, 

1012 inputMap=inputMap, 

1013 ) 

1014 

1015 

class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Connections for `ConvertMultipleCellCoaddToExposureTask`.

    The task reads one `MultipleCellCoadd` per (tract, patch, skymap, band)
    and writes one ``ExposureF`` with the same dimensions.
    """

    # The cell-based coadd to be stitched. This is the task's *input*; the
    # previous doc string mislabeled it as an output.
    cellCoaddExposure = Input(
        doc="Input cell-based coadd, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    # The stitched, contiguous exposure; its dataset name is the input name
    # with a "_stitched" suffix.
    stitchedCoaddExposure = Output(
        doc="Output stitched coadded exposure, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

1034 

1035 

class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for `ConvertMultipleCellCoaddToExposureTask`.

    Adds no fields beyond those inherited from `PipelineTaskConfig`; it only
    binds the task to `ConvertMultipleCellCoaddToExposureConnections`.
    """

1042 

1043 

class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """An afterburner PipelineTask that converts a cell-based coadd from
    `MultipleCellCoadd` format to `ExposureF` format.

    The run method stitches the cell-based coadd into a contiguous exposure
    and returns it as an `Exposure` object. This is lossy, as it preserves
    only the pixels in the inner bounding box of the cells and discards the
    values in the buffer region.

    Notes
    -----
    This task has no configurable parameters.
    """

    ConfigClass = ConvertMultipleCellCoaddToExposureConfig
    _DefaultName = "convertMultipleCellCoaddToExposure"

    def run(self, cellCoaddExposure: MultipleCellCoadd) -> Struct:
        """Stitch a cell-based coadd into a single contiguous exposure.

        Parameters
        ----------
        cellCoaddExposure : `MultipleCellCoadd`
            Cell-based coadd to convert.

        Returns
        -------
        result : `Struct`
            Struct with a single attribute:

            ``stitchedCoaddExposure``
                The stitched coadd, as returned by
                ``cellCoaddExposure.stitch().asExposure()``.
        """
        # Only the inner region of each cell survives the stitch; the buffer
        # pixels of the outer cell boxes are discarded (see class docstring).
        return Struct(
            stitchedCoaddExposure=cellCoaddExposure.stitch().asExposure(),
        )