Coverage for python / lsst / drp / tasks / assemble_cell_coadd.py: 16%
384 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:49 +0000
1# This file is part of drp_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "AssembleCellCoaddTask",
26 "AssembleCellCoaddConfig",
27 "ConvertMultipleCellCoaddToExposureTask",
28)
30import dataclasses
31import itertools
33import numpy as np
35import lsst.afw.geom as afwGeom
36import lsst.afw.image as afwImage
37import lsst.afw.math as afwMath
38import lsst.geom as geom
39from lsst.afw.detection import InvalidPsfError
40from lsst.afw.geom import SinglePolygonException, makeWcsPairTransform
41from lsst.cell_coadds import (
42 CellIdentifiers,
43 CoaddApCorrMapStacker,
44 CoaddInputs,
45 CoaddUnits,
46 CommonComponents,
47 GridContainer,
48 MultipleCellCoadd,
49 ObservationIdentifiers,
50 OwnedImagePlanes,
51 PatchIdentifiers,
52 SingleCellCoadd,
53 UniformGrid,
54)
55from lsst.daf.butler import DataCoordinate, DeferredDatasetHandle
56from lsst.meas.algorithms import AccumulatorMeanStack
57from lsst.pex.config import ConfigField, ConfigurableField, DictField, Field, ListField, RangeField
58from lsst.pipe.base import (
59 InMemoryDatasetHandle,
60 NoWorkFound,
61 PipelineTask,
62 PipelineTaskConfig,
63 PipelineTaskConnections,
64 Struct,
65)
66from lsst.pipe.base.connectionTypes import Input, Output
67from lsst.pipe.tasks.coaddBase import makeSkyInfo, removeMaskPlanes, setRejectedMaskMapping
68from lsst.pipe.tasks.healSparseMapping import HealSparseInputMapTask
69from lsst.pipe.tasks.interpImage import InterpImageTask
70from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask
71from lsst.skymap import BaseSkyMap
@dataclasses.dataclass
class WarpInputs:
    """Bundle a warp handle together with the handles of its associated inputs."""

    warp: DeferredDatasetHandle | InMemoryDatasetHandle
    """Handle for the warped exposure."""

    masked_fraction: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    """Handle for the masked fraction image, or `None` if there is none."""

    artifact_mask: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    """Handle for the CompareWarp artifact mask, or `None` if there is none."""

    noise_warps: list[DeferredDatasetHandle | InMemoryDatasetHandle] = dataclasses.field(default_factory=list)
    """Handles for the noise warps; a fresh empty list by default."""

    @property
    def dataId(self) -> DataCoordinate:
        """Data ID of the warp, as reported by the warp handle.

        Returns
        -------
        data_id : `~lsst.daf.butler.DataCoordinate`
            The ``dataId`` attribute of ``warp``.
        """
        data_id = self.warp.dataId
        return data_id
class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Butler connections for the cell-coadd assembly task.

    The constructor removes the connections that the supplied config turns
    off and adds one input connection per requested noise realization.
    """

    inputWarps = Input(
        name="{inputWarpName}Coadd_directWarp",
        doc="Input warps",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        multiple=True,
        deferLoad=True,
    )

    maskedFractionWarps = Input(
        name="{inputWarpName}Coadd_directWarp_maskedFraction",
        doc="Mask fraction warps",
        storageClass="ImageF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        multiple=True,
        deferLoad=True,
    )

    artifactMasks = Input(
        name="compare_warp_artifact_mask",
        doc="Artifact masks to be applied to the input warps",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        multiple=True,
        deferLoad=True,
    )

    visitSummaryList = Input(
        name="finalVisitSummary",
        doc="Input visit-summary catalogs with updated calibration objects. Mainly used for coadd weights.",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    multipleCellCoadd = Output(
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        doc="Output multiple cell coadd",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    inputMap = Output(
        name="{inputWarpName}Coadd_inputMap",
        doc="Output healsparse map of input images",
        storageClass="HealSparseMap",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        # Without a config there is nothing to tailor.
        if not config:
            return

        # Each pair is (should this connection be dropped?, connection name).
        prunable = (
            (config.do_calculate_weight_from_warp, "visitSummaryList"),
            (not config.do_use_artifact_mask, "artifactMasks"),
            (not config.do_input_map, "inputMap"),
        )
        for should_drop, connection_name in prunable:
            if should_drop:
                delattr(self, connection_name)

        # The noise-warp inputs are created on the fly: one connection for
        # every noise realization requested in the config.
        for index in range(config.num_noise_realizations):
            setattr(
                self,
                f"noise{index}_warps",
                Input(
                    name=f"direct_warp_noise{index}",
                    doc="Input noise warps",
                    storageClass="MaskedImageF",
                    dimensions=("tract", "patch", "skymap", "visit", "instrument"),
                    multiple=True,
                    deferLoad=True,
                ),
            )
class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration options for the cell-coadd assembly task."""

    do_interpolate_coadd = Field[bool](
        doc="Interpolate over pixels with NO_DATA mask set?",
        default=True,
    )
    interpolate_coadd = ConfigurableField(
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
        target=InterpImageTask,
    )
    do_scale_zero_point = Field[bool](
        doc="Scale warps to a common zero point? This is not needed if they have absolute flux calibration.",
        default=False,
        deprecated=(
            "Now that visits are scaled to nJy it is no longer necessary or "
            "recommended to scale the zero point, so this will be removed "
            "after v29."
        ),
    )
    scale_zero_point = ConfigurableField(
        doc="Task to scale warps to a common zero point",
        target=ScaleZeroPointTask,
        deprecated=(
            "Now that visits are scaled to nJy it is no longer necessary or "
            "recommended to scale the zero point, so this will be removed "
            "after v29."
        ),
    )
    do_calculate_weight_from_warp = Field[bool](
        doc=(
            "Calculate coadd weight from the input warp? Otherwise, the weight is obtained from the "
            "visitSummaryList connection. This is meant as a fallback when run outside the pipeline."
        ),
        default=False,
    )
    do_use_artifact_mask = Field[bool](
        doc="Substitute the mask planes input warp with an alternative artifact mask?",
        default=True,
    )
    do_coadd_inverse_aperture_corrections = Field[bool](
        doc=(
            "Coadd the inverse aperture corrections for each cell? This is formally the more accurate way "
            "but may be turned off for parity with deepCoadd."
        ),
        default=False,
    )
    # Notes on the value of ``min_overlap_fraction``:
    # - 1.0 corresponds to ideal, edge-free cells.
    # - 0.0 corresponds to the deep_coadd style coadds.
    # - It has to be at least 0.5 to ensure that an input overlaps the cell
    #   center. Inputs with an overlap fraction less than 0.25 will
    #   definitely not overlap the cell center.
    min_overlap_fraction = RangeField[float](
        doc=(
            "The minimum overlap fraction required for a single (visit, detector) input to be included in a "
            "cell."
        ),
        default=1.0,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=True,
    )
    bad_mask_planes = ListField[str](
        doc="Mask planes that count towards the masked fraction within a cell.",
        default=("BAD", "NO_DATA", "SAT", "CLIPPED"),
    )
    remove_mask_planes = ListField[str](
        doc="Mask planes to remove before coadding",
        default=["EDGE", "NOT_DEBLENDED"],
    )
    calc_error_from_input_variance = Field[bool](
        doc=(
            "Calculate coadd variance from input variance by stacking "
            "statistic. Passed to AccumulatorMeanStack."
        ),
        default=True,
    )
    mask_propagation_thresholds = DictField[str, float](
        doc=(
            "Threshold (in fractional weight) of rejection at which we "
            "propagate a mask plane to the coadd; that is, we set the mask "
            "bit on the coadd if the fraction the rejected frames "
            "would have contributed exceeds this value."
        ),
        default={"SAT": 0.1},
    )
    max_maskfrac = RangeField[float](
        doc=(
            "Maximum fraction of masked pixels in a cell. This is currently "
            "just a placeholder and is not used now"
        ),
        default=0.99,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=False,
    )
    num_noise_realizations = Field[int](
        doc=(
            "Number of noise planes to include in the coadd. "
            "This should not exceed the corresponding config parameter "
            "specified in `MakeDirectWarpConfig`. "
        ),
        default=0,
        # Negative counts make no sense.
        check=lambda x: x >= 0,
    )
    psf_warper = ConfigField(
        doc=(
            "Configuration for the warper that warps the PSFs. It must have the same configuration used to "
            "warp the images."
        ),
        dtype=afwMath.Warper.ConfigClass,
    )
    psf_dimensions = Field[int](
        doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
        default=35,
        # Accept only positive, odd stamp sizes.
        check=lambda x: (x > 0) and (x % 2 == 1),
    )
    require_artifact_mask = Field[bool](
        doc=(
            "Require presence of artifact mask for each warp? Use true if using artifact rejection outputs"
            " from CompareWarpTask"
        ),
        default=True,
    )
    do_input_map = Field[bool](
        doc="Create a bitwise map of coadd inputs.",
        default=False,
    )
    input_mapper = ConfigurableField(
        doc="Input map creation subtask.",
        target=HealSparseInputMapTask,
    )
306class AssembleCellCoaddTask(PipelineTask):
307 """Assemble a cell-based coadded image from a set of warps.
309 This task reads in the warp one at a time, and accumulates it in all the
310 cells that it completely overlaps with. This is the optimal I/O pattern but
311 this also implies that it is not possible to build one or only a few cells.
313 Each cell coadds is guaranteed to have a well-defined PSF. This is done by
314 1) excluding warps that only partially overlap a cell from that cell coadd;
315 2) interpolating bad pixels in the warps rather than excluding them;
316 3) by computing the coadd as a weighted mean of the warps without clipping;
317 4) by computing the coadd PSF as the weighted mean of the PSF of the warps
318 with the same weights.
320 The cells are (and must be) defined in the skymap, and cannot be configured
321 or redefined here. The cells are assumed to be small enough that the PSF is
322 assumed to be spatially constant within a cell.
324 Raises
325 ------
326 NoWorkFound
327 Raised if no input warps are provided, or no cells could be populated.
328 RuntimeError
329 Raised if the skymap is not cell-based.
331 Notes
332 -----
333 This is not yet a part of the standard DRP pipeline. As such, the Task and
334 especially its Config and Connections are experimental and subject to
335 change any time without a formal RFC or standard deprecation procedures
336 until it is included in the DRP pipeline.
337 """
339 ConfigClass = AssembleCellCoaddConfig
340 _DefaultName = "assembleCellCoadd"
def __init__(self, *args, **kwargs):
    """Create the enabled subtasks, the PSF warper and the PSF padding.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent class constructor.
    """
    super().__init__(*args, **kwargs)
    if self.config.do_interpolate_coadd:
        self.makeSubtask("interpolate_coadd")
    if self.config.do_scale_zero_point:
        self.makeSubtask("scale_zero_point")
    if self.config.do_input_map:
        self.makeSubtask("input_mapper")

    self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper)
    warping_kernel_name = self.config.psf_warper.warpingKernelName.lower()
    if warping_kernel_name.startswith("lanczos"):
        # Take the kernel order from whatever follows the "lanczos" prefix.
        # ``removeprefix`` drops exactly that prefix, whereas ``lstrip``
        # would strip any leading run of the characters in "lanczos".
        kernel_order = int(warping_kernel_name.removeprefix("lanczos"))
        psf_padding = 2 * kernel_order - 1
        self.log.debug(
            "Padding PSF image by %d pixels since the warping kernel is %s.",
            psf_padding,
            self.config.psf_warper.warpingKernelName,
        )
    else:
        # Fixed fallback padding for non-Lanczos kernels.
        psf_padding = 10
        self.log.info(
            "Padding PSF image by %d pixels since the warping kernel is not Lanczos.",
            psf_padding,
        )
    self.psf_padding = psf_padding
def runQuantum(self, butlerQC, inputRefs, outputRefs):
    # Docstring inherited.
    warp_refs = inputRefs.inputWarps
    if not warp_refs:
        raise NoWorkFound("No input warps provided for co-addition")
    self.log.info("Found %d input warps", len(warp_refs))

    # The skymap is needed to build the skyInfo struct that ``run`` expects,
    # and it must have been built with the "cells" tract builder.
    skyMap = butlerQC.get(inputRefs.skyMap)
    if skyMap.config.tractBuilder.name != "cells":
        raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.")

    outputDataId = butlerQC.quantum.dataId
    skyInfo = makeSkyInfo(skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"])

    # The visit summary connection may have been removed by the config, in
    # which case an empty list is fetched.
    visitSummaryList = butlerQC.get(getattr(inputRefs, "visitSummaryList", []))

    if self.config.do_scale_zero_point:
        units = CoaddUnits.legacy
    else:
        units = CoaddUnits.nJy
    self.common = CommonComponents(
        units=units,
        wcs=skyInfo.patchInfo.wcs,
        band=outputDataId.get("band", None),
        identifiers=PatchIdentifiers.from_data_id(outputDataId),
    )

    # Key everything on the data ID so the auxiliary inputs can be attached
    # to the warp they belong to.
    inputs: dict[DataCoordinate, WarpInputs] = {
        handle.dataId: WarpInputs(warp=handle, noise_warps=[]) for handle in butlerQC.get(warp_refs)
    }

    for mask_ref in getattr(inputRefs, "artifactMasks", []):
        inputs[mask_ref.dataId].artifact_mask = butlerQC.get(mask_ref)
    for fraction_ref in getattr(inputRefs, "maskedFractionWarps", []):
        inputs[fraction_ref.dataId].masked_fraction = butlerQC.get(fraction_ref)
    for n in range(self.config.num_noise_realizations):
        for noise_ref in getattr(inputRefs, f"noise{n}_warps"):
            inputs[noise_ref.dataId].noise_warps.append(butlerQC.get(noise_ref))

    returnStruct = self.run(inputs=inputs, skyInfo=skyInfo, visitSummaryList=visitSummaryList)
    butlerQC.put(returnStruct, outputRefs)
    return returnStruct
@staticmethod
def _compute_weight(maskedImage, statsCtrl):
    """Compute a weight for a masked image.

    Parameters
    ----------
    maskedImage : `~lsst.afw.image.MaskedImage`
        The masked image to compute the weight.
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control (config-like) object for StatisticsStack.

    Returns
    -------
    weight : `float`
        Inverse of the clipped mean variance of the masked image. If the
        clipped mean variance is exactly zero, ``inf`` is returned instead
        of raising `ZeroDivisionError`, so that a caller checking the
        weight for finiteness can skip the input.
    """
    statObj = afwMath.makeStatistics(
        maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl
    )
    meanVar, _ = statObj.getResult(afwMath.MEANCLIP)
    meanVar = float(meanVar)
    if meanVar == 0.0:
        # Python float division by zero raises; report a non-finite weight
        # instead (a NaN variance already propagates as NaN below).
        return float("inf")
    return 1.0 / meanVar
@staticmethod
def _construct_grid(skyInfo):
    """Build the UniformGrid describing the cells of a patch.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        A Struct object whose ``patchInfo`` attribute supplies the cell
        border, the outer bounding box and the inner cell dimensions.

    Returns
    -------
    grid : `~lsst.cell_coadds.UniformGrid`
        A UniformGrid object.
    """
    patch_info = skyInfo.patchInfo
    border = patch_info.getCellBorder()
    # The grid covers the patch outer bbox shrunk by the cell border; the
    # same border is then passed on as the grid padding.
    return UniformGrid.from_bbox_cell_size(
        patch_info.outer_bbox.erodedBy(border),
        patch_info.getCellInnerDimensions(),
        padding=border,
    )
def _construct_grid_container(self, skyInfo, statsCtrl):
    """Build a grid holding one AccumulatorMeanStack per cell.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        A Struct object
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control (config-like) object for StatisticsStack.

    Returns
    -------
    gc : `~lsst.cell_coadds.GridContainer`
        A GridContainer object containing one AccumulatorMeanStack per cell.
    """
    grid = self._construct_grid(skyInfo)

    rejected_mask_map = setRejectedMaskMapping(statsCtrl)
    self.log.debug("Obtained maskMap = %s for %s", rejected_mask_map, skyInfo.patchInfo)
    threshold_dict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl)

    # Fill the container cell by cell.
    stackers = GridContainer[AccumulatorMeanStack](grid.shape)
    for cell in skyInfo.patchInfo:
        stackers[cell.index] = AccumulatorMeanStack(
            # numpy arrays are indexed (row, column), hence (height, width).
            shape=(cell.outer_bbox.height, cell.outer_bbox.width),
            bit_mask_value=statsCtrl.getAndMask(),
            mask_threshold_dict=threshold_dict,
            calc_error_from_input_variance=self.config.calc_error_from_input_variance,
            compute_n_image=False,
            mask_map=rejected_mask_map,
            no_good_pixels_mask=statsCtrl.getNoGoodPixelsMask(),
        )

    return stackers
def _construct_stats_control(self):
    """Build the StatisticsControl object used for the coadd.

    Only the bad mask planes and the optional mask propagation thresholds
    come from the config; everything else is fixed here.

    Returns
    -------
    statsCtrl : `~lsst.afw.math.StatisticsControl`
        A control object for StatisticsStack.
    """
    control = afwMath.StatisticsControl()
    # numIter is deliberately hardcoded to the default config value set in
    # CompareWarpAssembleCoaddTask so that the weights come out consistent.
    # It is NOT a config parameter because it only matters for a fallback
    # path that is not recommended for production.
    control.setNumIter(2)
    bad_pixel_bits = afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes)
    control.setAndMask(bad_pixel_bits)
    control.setNanSafe(True)
    for plane_name, fraction in self.config.mask_propagation_thresholds.items():
        control.setMaskPropagationThreshold(afwImage.Mask.getMaskPlane(plane_name), fraction)
    return control
def _construct_ap_corr_grid_container(self, skyInfo):
    """Build a grid holding one CoaddApCorrMapStacker per cell.

    Parameters
    ----------
    skyInfo : `~lsst.pipe.base.Struct`
        A Struct object

    Returns
    -------
    gc : `~lsst.cell_coadds.GridContainer`
        A GridContainer object containing one CoaddApCorrMapStacker per
        cell.
    """
    grid = self._construct_grid(skyInfo)

    # Each stacker is evaluated at the center of its cell's inner bbox.
    stackers = GridContainer[CoaddApCorrMapStacker](grid.shape)
    for cell in skyInfo.patchInfo:
        stackers[cell.index] = CoaddApCorrMapStacker(
            evaluation_point=cell.inner_bbox.getCenter(),
            do_coadd_inverse_ap_corr=self.config.do_coadd_inverse_aperture_corrections,
        )

    return stackers
546 def run(
547 self,
548 *,
549 inputs: dict[DataCoordinate, WarpInputs],
550 skyInfo,
551 visitSummaryList: list | None = None,
552 ):
553 for mask_plane in self.config.bad_mask_planes:
554 afwImage.Mask.addMaskPlane(mask_plane)
555 for mask_plane in self.config.mask_propagation_thresholds:
556 afwImage.Mask.addMaskPlane(mask_plane)
558 statsCtrl = self._construct_stats_control()
560 warp_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl)
561 maskfrac_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl)
562 noise_stacker_gc_list = [
563 self._construct_grid_container(skyInfo, statsCtrl)
564 for n in range(self.config.num_noise_realizations)
565 ]
566 psf_stacker_gc = GridContainer[AccumulatorMeanStack](warp_stacker_gc.shape)
567 psf_bbox_gc = GridContainer[geom.Box2I](warp_stacker_gc.shape)
568 ap_corr_stacker_gc = self._construct_ap_corr_grid_container(skyInfo)
570 # Make a container to hold the cell centers in sky coordinates now,
571 # so we don't have to recompute them for each warp
572 # (they share a common WCS). These are needed to find the various
573 # warp + detector combinations that contributed to each cell, and later
574 # get the corresponding PSFs as well.
575 cell_centers_sky = GridContainer[geom.SpherePoint](warp_stacker_gc.shape)
576 # Make a container to hold the observation identifiers for each cell.
577 observation_identifiers_gc = GridContainer[dict](warp_stacker_gc.shape)
579 if self.config.do_input_map:
580 # We need to know all the visit + detector pairs in the inputs.
581 warp_input_list = [warp_ref.warp.get(component="coaddInputs") for warp_ref in inputs.values()]
582 visit_detectors = []
583 for warp_input in warp_input_list:
584 for row in warp_input.ccds:
585 visit_detectors.append((int(row["visit"]), int(row["ccd"])))
587 self.input_mapper.initialize_cell_input_map(
588 skyInfo.patchInfo.getOuterBBox(),
589 skyInfo.patchInfo.wcs,
590 visit_detectors,
591 )
593 # Populate them.
594 for cellInfo in skyInfo.patchInfo:
595 # Make a list to hold the observation identifiers for each cell.
596 observation_identifiers_gc[cellInfo.index] = {}
597 cell_center_pixel = geom.Point2D(geom.Point2I(cellInfo.inner_bbox.getCenter()))
598 cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cell_center_pixel)
599 psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox(
600 cell_center_pixel,
601 geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions),
602 )
603 psf_stacker_gc[cellInfo.index] = AccumulatorMeanStack(
604 # The shape is for the numpy arrays, hence transposed.
605 shape=(self.config.psf_dimensions, self.config.psf_dimensions),
606 bit_mask_value=0,
607 calc_error_from_input_variance=self.config.calc_error_from_input_variance,
608 compute_n_image=False,
609 )
611 if self.config.do_input_map:
612 self.input_mapper.build_cell_input_map(cellInfo)
614 # visit_summary do not have (tract, patch, band, skymap) dimensions.
615 if not visitSummaryList:
616 visitSummaryList = []
617 visitSummaryRefDict = {
618 visitSummaryRef.dataId["visit"]: visitSummaryRef for visitSummaryRef in visitSummaryList
619 }
621 # Keep track of the polygons corresponding to each (visit, detector).
622 visit_polygons: dict[ObservationIdentifiers, afwGeom.Polygon] = {}
624 # Read in one warp at a time, and accumulate it in all the cells that
625 # it completely overlaps.
626 for _, warp_input in inputs.items():
627 # warps that have been excluded from CompareWarp via visit
628 # selection from SelectVisitsTasks will not have artifact masks.
629 # Exclude them from the cell coadds too.
630 if self.config.require_artifact_mask and warp_input.artifact_mask is None:
631 self.log.info(
632 "Excluding warp %s from cell coadds because it has no artifact mask",
633 warp_input.dataId["visit"],
634 )
635 continue
637 warp = warp_input.warp.get(parameters={"bbox": skyInfo.bbox})
638 masked_fraction_image = (
639 warp_input.masked_fraction.get(parameters={"bbox": skyInfo.bbox})
640 if warp_input.masked_fraction
641 else None
642 )
644 # Pre-process the warp before coadding.
645 # TODO: Can we get these mask names from artifactMask?
646 warp.mask.addMaskPlane("CLIPPED")
647 warp.mask.addMaskPlane("REJECTED")
648 warp.mask.addMaskPlane("SENSOR_EDGE")
649 warp.mask.addMaskPlane("INEXACT_PSF")
651 if artifact_mask_ref := warp_input.artifact_mask:
652 # Apply the artifact mask to the warp.
653 artifact_mask = artifact_mask_ref.get()
654 assert (
655 warp.mask.getMaskPlaneDict() == artifact_mask.getMaskPlaneDict()
656 ), "Mask dicts do not agree."
657 warp.mask.array = artifact_mask.array
658 del artifact_mask
660 if self.config.do_scale_zero_point:
661 # Each Warp that goes into a coadd will typically have an
662 # independent photometric zero-point. Therefore, we must scale
663 # each Warp to set it to a common photometric zeropoint.
664 imageScaler = self.scale_zero_point.run(exposure=warp, dataRef=warp_input.warp).imageScaler
665 zero_point_scale_factor = imageScaler.scale
666 self.log.debug(
667 "Scaled the warp %s by %f to match zero points",
668 warp_input.dataId,
669 zero_point_scale_factor,
670 )
671 else:
672 zero_point_scale_factor = 1.0
673 if "BUNIT" not in warp.metadata:
674 raise ValueError(f"Warp {warp_input.dataId} has no BUNIT metadata")
675 if warp.metadata["BUNIT"] != "nJy":
676 raise ValueError(
677 f"Warp {warp_input.dataId} has BUNIT {warp.metadata['BUNIT']}, expected nJy"
678 )
680 # Only try to remove maks planes that have been registered.
681 to_remove = []
682 for plane in self.config.remove_mask_planes:
683 if plane in warp.mask.getMaskPlaneDict():
684 to_remove.append(plane)
685 removeMaskPlanes(warp.mask, to_remove, self.log)
686 # Instead of using self.config.bad_mask_planes, we explicitly
687 # ask statsCtrl which pixels are going to be ignored/rejected.
688 rejected = afwImage.Mask.getPlaneBitMask(
689 ["CLIPPED", "REJECTED"] + afwImage.Mask.interpret(statsCtrl.getAndMask()).split(",")
690 )
692 # Compute the weight for each CCD in the warp from the visitSummary
693 # or from the warp itself, if not provided. Computing the weight
694 # from the warp is not recommended, and in that case we compute one
695 # weight per warp and not bother with per-detector weights.
696 full_ccd_table = warp.getInfo().getCoaddInputs().ccds
697 weights: dict[int, float] = dict.fromkeys(
698 full_ccd_table["ccd"].tolist(),
699 0.0,
700 ) # Mapping from detector to weight.
702 if visitSummaryRef := visitSummaryRefDict.get(warp_input.dataId["visit"]):
703 visitSummary = visitSummaryRef.get()
704 for detector in full_ccd_table["ccd"].tolist():
705 visitSummaryRow = visitSummary.find(detector)
706 mean_variance = visitSummaryRow["meanVar"]
707 mean_variance *= zero_point_scale_factor**2
708 if warp.metadata.get("BUNIT", None) == "nJy":
709 mean_variance *= visitSummaryRow.photoCalib.getCalibrationMean() ** 2
710 weights[detector] = 1.0 / mean_variance
711 del visitSummary
712 else:
713 self.log.debug("No visit summary found for %s; using warp-based weights", warp_input.dataId)
714 weight = self._compute_weight(warp, statsCtrl)
715 if not np.isfinite(weight):
716 self.log.warn("Non-finite weight for %s: skipping", warp_input.dataId)
717 continue
719 for detector in weights:
720 weights[detector] = weight
722 noise_warps = [ref.get(parameters={"bbox": skyInfo.bbox}) for ref in warp_input.noise_warps]
724 # Create an image where each pixel value corresponds to the
725 # detector ID that pixel comes from.
726 detector_map = afwImage.ImageI(bbox=warp.getBBox(), initialValue=-1)
727 for row in full_ccd_table:
728 transform = makeWcsPairTransform(row.wcs, warp.wcs)
729 if (src_polygon := row.validPolygon) is None:
730 src_polygon = afwGeom.Polygon(geom.Box2D(row.getBBox()))
731 try:
732 dest_polygon = src_polygon.transform(transform).intersectionSingle(
733 geom.Box2D(warp.getBBox())
734 )
735 except SinglePolygonException:
736 continue
738 observation_identifier = ObservationIdentifiers.from_data_id(
739 warp_input.dataId,
740 backup_detector=row["ccd"],
741 )
742 visit_polygons[observation_identifier] = dest_polygon
744 detector_map_slice = dest_polygon.createImage(detector_map.getBBox()).array > 0
745 if not (detector_map.array[detector_map_slice] < 0).all():
746 self.log.warning("Multiple detectors from visit %s are overlapping", warp_input.dataId)
747 detector_map.array[detector_map_slice] = row["ccd"]
749 if (detector_map.array < 0).all():
750 self.log.warning("Unable to split the warp %s into single-detector warps.", warp_input.dataId)
751 detector_map.array[:, :] = 0
753 for cellInfo, ccd_row in itertools.product(skyInfo.patchInfo, full_ccd_table):
754 bbox = cellInfo.outer_bbox
755 inner_bbox = cellInfo.inner_bbox
757 overlap_fraction = (detector_map[inner_bbox].array == ccd_row["ccd"]).mean()
758 assert -1e-4 < overlap_fraction < 1.0001, "Overlap fraction is not within [0, 1]."
759 if (overlap_fraction < self.config.min_overlap_fraction) or (overlap_fraction <= 0.0):
760 self.log.debug(
761 "Skipping %s in cell %s because it had only %.3f < %.3f fractional overlap.",
762 warp_input.dataId,
763 cellInfo.index,
764 overlap_fraction,
765 self.config.min_overlap_fraction,
766 )
767 continue
769 weight = weights[int(ccd_row["ccd"])]
770 if not np.isfinite(weight):
771 self.log.warn(
772 "Non-finite weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index
773 )
774 continue
776 if weight == 0:
777 self.log.info(
778 "Zero weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index
779 )
780 continue
782 # Decide if a deep copy is necessary to apply the single
783 # detector cuts since it involves modifying the image in-place.
784 # If within the inner cell, there are three or more different
785 # values that detector map takes, then there are definitely
786 # multiple detectors (one for chip gaps, two for two detectors)
787 deep_copy = len(set(detector_map[inner_bbox].array.ravel())) >= 3
788 if deep_copy:
789 single_detector_mask_array = detector_map[bbox].array != ccd_row["ccd"]
791 mi = afwImage.MaskedImageF(warp[bbox].maskedImage, deep=deep_copy)
792 if deep_copy:
793 mi.image.array[single_detector_mask_array] = 0.0
794 mi.variance.array[single_detector_mask_array] = np.inf
795 nodata_or_mask = (single_detector_mask_array) * afwImage.Mask.getPlaneBitMask("NO_DATA")
796 mi.mask[bbox].array |= nodata_or_mask
797 warp_stacker_gc[cellInfo.index].add_masked_image(mi, weight=weight)
799 if masked_fraction_image:
800 mi = afwImage.ImageF(masked_fraction_image[bbox], deep=deep_copy)
801 if deep_copy:
802 mi.array[single_detector_mask_array] = 0.0
803 maskfrac_stacker_gc[cellInfo.index].add_image(masked_fraction_image[bbox], weight=weight)
805 for n in range(self.config.num_noise_realizations):
806 mi = afwImage.MaskedImageF(noise_warps[n][bbox], deep=deep_copy)
807 if deep_copy:
808 mi.image.array[single_detector_mask_array] = 0.0
809 mi.variance.array[single_detector_mask_array] = np.inf
810 mi.mask[bbox].array |= nodata_or_mask
811 noise_stacker_gc_list[n][cellInfo.index].add_masked_image(mi, weight=weight)
813 # Set the defaults for PSF shape quantities.
814 psf_shape = afwGeom.Quadrupole()
815 psf_shape_flag = True
816 psf_eval_point = None
817 try:
818 # The `if` branch is buggy. `dest_polygon` is technically
819 # out of scope, but Python does not raise an error.
820 # TODO: Fix this properly in DM-53479, but sweep it under
821 # the rug for now.
822 if overlap_fraction < 0.5:
823 psf_eval_point = dest_polygon.intersectionSingle(
824 geom.Box2D(inner_bbox)
825 ).calculateCenter()
826 else:
827 psf_eval_point = geom.Point2D(geom.Point2I(inner_bbox.getCenter()))
828 psf_shape = warp.psf.computeShape(psf_eval_point)
829 psf_shape_flag = False
830 except SinglePolygonException:
831 self.log.info(
832 "Unable to find the overlapping polygon between %d detector in %s and cell %s",
833 ccd_row["ccd"],
834 warp_input.dataId,
835 cellInfo.index,
836 )
837 except InvalidPsfError:
838 self.log.info(
839 "Unable to compute PSF shape from %d detector in %s at %s",
840 ccd_row["ccd"],
841 warp_input.dataId,
842 psf_eval_point,
843 )
845 overlaps_center = detector_map[geom.Point2I(bbox.getCenter())] == ccd_row["ccd"]
847 observation_identifier = ObservationIdentifiers.from_data_id(
848 warp_input.dataId,
849 backup_detector=int(ccd_row["ccd"]),
850 )
851 observation_identifiers_gc[cellInfo.index][observation_identifier] = CoaddInputs(
852 overlaps_center=overlaps_center,
853 overlap_fraction=overlap_fraction,
854 weight=weight,
855 psf_shape=psf_shape,
856 psf_shape_flag=psf_shape_flag,
857 )
858 if overlaps_center is False:
859 self.log.debug(
860 "%s does not overlap with the center of the cell %s",
861 warp_input.dataId,
862 cellInfo.index,
863 )
864 continue
866 # Everything below this has to do with the center of the cell
867 calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index])
868 undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point)
870 assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox(
871 calexp_point,
872 undistorted_psf_im.getDimensions(),
873 ), "PSF image does not share the coordinates of the 'calexp'"
875 # Convert the PSF image from Image to MaskedImage and
876 # zero-pad the image.
877 undistorted_psf_bbox = undistorted_psf_im.getBBox()
878 undistorted_psf_maskedImage = afwImage.MaskedImageD(
879 undistorted_psf_bbox.dilatedBy(self.psf_padding)
880 )
881 undistorted_psf_maskedImage.image[undistorted_psf_bbox].array[:, :] = undistorted_psf_im.array
882 # TODO: In DM-43585, use the variance plane value from noise.
883 undistorted_psf_maskedImage.variance += 1.0 # Set variance to 1
885 warped_psf_maskedImage = self.psf_warper.warpImage(
886 destWcs=skyInfo.wcs,
887 srcImage=undistorted_psf_maskedImage,
888 srcWcs=ccd_row.getWcs(),
889 destBBox=psf_bbox_gc[cellInfo.index],
890 )
892 # There may be NaNs in the PSF image. Set them to 0.0
893 warped_psf_maskedImage.variance.array[np.isnan(warped_psf_maskedImage.image.array)] = 1.0
894 warped_psf_maskedImage.image.array[np.isnan(warped_psf_maskedImage.image.array)] = 0.0
896 psf_stacker = psf_stacker_gc[cellInfo.index]
897 psf_stacker.add_masked_image(warped_psf_maskedImage, weight=weight)
899 if not (0.995 < (psf_normalization := warped_psf_maskedImage.image.array.sum()) < 1.005):
900 self.log.warning(
901 "PSF image for %s in %s is not normalized to 1.0, but instead %f",
902 warp_input.dataId,
903 cellInfo.index,
904 psf_normalization,
905 )
907 if (ap_corr_map := warp.getInfo().getApCorrMap()) is not None:
908 ap_corr_stacker_gc[cellInfo.index].add(ap_corr_map, weight=weight)
910 if self.config.do_input_map:
911 self.input_mapper.add_warp_to_cell_input_map(
912 ccd_row,
913 weight,
914 cellInfo,
915 )
917 del warp
919 if self.config.do_input_map:
920 inputMap = self.input_mapper.cell_input_map
921 else:
922 inputMap = None
924 # Update common with the visit polygons.
925 self.common = dataclasses.replace(
926 self.common,
927 visit_polygons=visit_polygons,
928 )
930 cells: list[SingleCellCoadd] = []
931 for cellInfo in skyInfo.patchInfo:
932 if len(observation_identifiers_gc[cellInfo.index]) == 0:
933 self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index)
934 continue
936 cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox)
937 cell_maskfrac_image = afwImage.ImageF(cellInfo.outer_bbox)
938 cell_noise_images = [
939 afwImage.MaskedImageF(cellInfo.outer_bbox) for n in range(self.config.num_noise_realizations)
940 ]
941 psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index])
943 warp_stacker_gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image)
944 maskfrac_stacker_gc[cellInfo.index].fill_stacked_image(cell_maskfrac_image)
945 for n in range(self.config.num_noise_realizations):
946 noise_stacker_gc_list[n][cellInfo.index].fill_stacked_masked_image(cell_noise_images[n])
947 psf_stacker_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image)
949 if ap_corr_stacker_gc[cellInfo.index].ap_corr_names:
950 ap_corr_map = ap_corr_stacker_gc[cellInfo.index].final_ap_corr_map
951 else:
952 ap_corr_map = None
954 # Post-process the coadd before converting to new data structures.
955 if np.isnan(cell_masked_image.image.array).all():
956 cell_masked_image.image.array[:, :] = 0.0
957 cell_masked_image.variance.array[:, :] = np.inf
958 elif self.config.do_interpolate_coadd:
959 self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA")
960 for noise_image in cell_noise_images:
961 self.interpolate_coadd.run(noise_image, planeName="NO_DATA")
962 # The variance must be positive; work around for DM-3201.
963 varArray = cell_masked_image.variance.array
964 with np.errstate(invalid="ignore"):
965 varArray[:] = np.where(varArray > 0, varArray, np.inf)
967 afwImage.Mask.addMaskPlane("INEXACT_PSF")
968 cell_masked_image.mask.array[
969 (cell_masked_image.mask.array & rejected) > 0
970 ] |= cell_masked_image.mask.getPlaneBitMask("INEXACT_PSF")
972 image_planes = OwnedImagePlanes.from_masked_image(
973 masked_image=cell_masked_image,
974 mask_fractions=cell_maskfrac_image,
975 noise_realizations=[noise_image.image for noise_image in cell_noise_images],
976 )
977 identifiers = CellIdentifiers(
978 cell=cellInfo.index,
979 skymap=self.common.identifiers.skymap,
980 tract=self.common.identifiers.tract,
981 patch=self.common.identifiers.patch,
982 band=self.common.identifiers.band,
983 )
985 singleCellCoadd = SingleCellCoadd(
986 outer=image_planes,
987 psf=psf_masked_image.image,
988 inner_bbox=cellInfo.inner_bbox,
989 inputs=observation_identifiers_gc[cellInfo.index],
990 common=self.common,
991 identifiers=identifiers,
992 aperture_correction_map=ap_corr_map,
993 )
994 # TODO: Attach transmission curve when they become available.
995 cells.append(singleCellCoadd)
997 if not cells:
998 raise NoWorkFound("No cells could be populated for the cell coadd.")
1000 grid = self._construct_grid(skyInfo)
1001 multipleCellCoadd = MultipleCellCoadd(
1002 cells,
1003 grid=grid,
1004 outer_cell_size=cellInfo.outer_bbox.getDimensions(),
1005 inner_bbox=None,
1006 common=self.common,
1007 psf_image_size=cells[0].psf_image.getDimensions(),
1008 )
1010 return Struct(
1011 multipleCellCoadd=multipleCellCoadd,
1012 inputMap=inputMap,
1013 )
class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Connections for ConvertMultipleCellCoaddToExposureTask.

    One input dataset with the ``MultipleCellCoadd`` storage class is
    declared, along with one output dataset with the ``ExposureF`` storage
    class. The output dataset name is the input dataset name with a
    ``_stitched`` suffix appended; both share the same dimensions.
    """

    # The dataset being read: it is an *input* to this task, so the doc
    # string must not describe it as an output.
    cellCoaddExposure = Input(
        doc="Input cell-based coadd, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    # The dataset being written by the task.
    stitchedCoaddExposure = Output(
        doc="Output stitched coadded exposure, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for ConvertMultipleCellCoaddToExposureTask.

    No configuration fields are declared here; the class only associates
    the task with ConvertMultipleCellCoaddToExposureConnections.
    """
class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """Afterburner PipelineTask that converts a cell-based coadd from the
    `MultipleCellCoadd` format into the `ExposureF` format.

    The cells of the coadd are stitched together into one contiguous image,
    which is then returned as an `Exposure` object. The conversion is lossy:
    only the pixels inside the inner bounding box of each cell are kept, and
    the values in the surrounding buffer region are discarded.

    Notes
    -----
    This task has no configurable parameters.
    """

    ConfigClass = ConvertMultipleCellCoaddToExposureConfig
    _DefaultName = "convertMultipleCellCoaddToExposure"

    def run(self, cellCoaddExposure):
        """Stitch a cell-based coadd and return it as an exposure.

        Parameters
        ----------
        cellCoaddExposure : `MultipleCellCoadd`
            The cell-based coadd to convert.

        Returns
        -------
        result : `Struct`
            A struct with a single attribute, ``stitchedCoaddExposure``,
            holding the result of calling ``asExposure()`` on the stitched
            coadd.
        """
        # Two explicit steps: merge the cells first, then convert the
        # stitched product into an exposure.
        stitched_coadd = cellCoaddExposure.stitch()
        exposure = stitched_coadd.asExposure()
        return Struct(stitchedCoaddExposure=exposure)