Coverage for python / lsst / drp / tasks / assemble_cell_coadd.py: 16%
386 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:32 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:32 +0000
1# This file is part of drp_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = (
25 "AssembleCellCoaddTask",
26 "AssembleCellCoaddConfig",
27 "ConvertMultipleCellCoaddToExposureTask",
28)
30import dataclasses
31import itertools
32import logging
34import numpy as np
36import lsst.afw.geom as afwGeom
37import lsst.afw.image as afwImage
38import lsst.afw.math as afwMath
39import lsst.geom as geom
40from lsst.afw.detection import InvalidPsfError
41from lsst.afw.geom import SinglePolygonException, makeWcsPairTransform
42from lsst.cell_coadds import (
43 CellIdentifiers,
44 CoaddApCorrMapStacker,
45 CoaddInputs,
46 CoaddUnits,
47 CommonComponents,
48 GridContainer,
49 MultipleCellCoadd,
50 ObservationIdentifiers,
51 OwnedImagePlanes,
52 PatchIdentifiers,
53 SingleCellCoadd,
54 UniformGrid,
55)
56from lsst.daf.butler import DataCoordinate, DeferredDatasetHandle
57from lsst.meas.algorithms import AccumulatorMeanStack
58from lsst.pex.config import ConfigField, ConfigurableField, DictField, Field, ListField, RangeField
59from lsst.pipe.base import (
60 InMemoryDatasetHandle,
61 NoWorkFound,
62 PipelineTask,
63 PipelineTaskConfig,
64 PipelineTaskConnections,
65 Struct,
66)
67from lsst.pipe.base.connectionTypes import Input, Output
68from lsst.pipe.tasks.coaddBase import makeSkyInfo, removeMaskPlanes, setRejectedMaskMapping
69from lsst.pipe.tasks.healSparseMapping import HealSparseInputMapTask
70from lsst.pipe.tasks.interpImage import InterpImageTask
71from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask
72from lsst.skymap import BaseSkyMap
@dataclasses.dataclass
class WarpInputs:
    """Group a warp handle with the auxiliary inputs associated with it.

    Only ``warp`` is mandatory; the masked-fraction image and the artifact
    mask default to `None`, and the noise warps default to an empty list
    that is created independently for every instance.
    """

    warp: DeferredDatasetHandle | InMemoryDatasetHandle
    """Handle to the warped exposure."""

    masked_fraction: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    """Handle to the masked-fraction image, if any."""

    artifact_mask: DeferredDatasetHandle | InMemoryDatasetHandle | None = None
    """Handle to the CompareWarp artifact mask, if any."""

    noise_warps: list[DeferredDatasetHandle | InMemoryDatasetHandle] = dataclasses.field(default_factory=list)
    """Handles to the noise-realization warps, one per realization."""

    @property
    def dataId(self) -> DataCoordinate:
        """Data ID of the wrapped warp.

        Returns
        -------
        data_id : `~lsst.daf.butler.DataCoordinate`
            The ``dataId`` attribute of the ``warp`` handle.
        """
        warp_handle = self.warp
        return warp_handle.dataId
class AssembleCellCoaddConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"},
):
    """Connections for `AssembleCellCoaddTask`.

    ``visitSummaryList``, ``artifactMasks`` and ``inputMap`` are removed
    when the config does not need them. One ``noise{n}_warps`` input
    connection is added per requested noise realization.
    """

    inputWarps = Input(
        doc="Input warps",
        name="{inputWarpName}Coadd_directWarp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    maskedFractionWarps = Input(
        doc="Mask fraction warps",
        name="{inputWarpName}Coadd_directWarp_maskedFraction",
        storageClass="ImageF",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    artifactMasks = Input(
        doc="Artifact masks to be applied to the input warps",
        name="compare_warp_artifact_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
        deferLoad=True,
        multiple=True,
    )

    visitSummaryList = Input(
        doc="Input visit-summary catalogs with updated calibration objects. Mainly used for coadd weights.",
        name="finalVisitSummary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        deferLoad=True,
        multiple=True,
    )

    skyMap = Input(
        doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    multipleCellCoadd = Output(
        doc="Output multiple cell coadd",
        name="{inputWarpName}Coadd{outputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    inputMap = Output(
        doc="Output healsparse map of input images",
        name="{inputWarpName}Coadd_inputMap",
        storageClass="HealSparseMap",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if config:
            # Each optional connection is paired with the condition under
            # which the configuration makes it unnecessary.
            removable_connections = (
                ("visitSummaryList", config.do_calculate_weight_from_warp),
                ("artifactMasks", not config.do_use_artifact_mask),
                ("inputMap", not config.do_input_map),
            )
            for connection_name, should_remove in removable_connections:
                if should_remove:
                    delattr(self, connection_name)

            # The number of noise-warp inputs is only known from the config,
            # so these connections are attached dynamically.
            for realization in range(config.num_noise_realizations):
                setattr(
                    self,
                    f"noise{realization}_warps",
                    Input(
                        doc="Input noise warps",
                        name=f"direct_warp_noise{realization}",
                        storageClass="MaskedImageF",
                        dimensions=("tract", "patch", "skymap", "visit", "instrument"),
                        deferLoad=True,
                        multiple=True,
                    ),
                )
class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections):
    """Configuration for `AssembleCellCoaddTask`."""

    do_interpolate_coadd = Field[bool](doc="Interpolate over pixels with NO_DATA mask set?", default=True)
    interpolate_coadd = ConfigurableField(
        target=InterpImageTask,
        doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds",
    )
    do_scale_zero_point = Field[bool](
        doc="Scale warps to a common zero point? This is not needed if they have absolute flux calibration.",
        default=False,
        deprecated="Now that visits are scaled to nJy it is no longer necessary or "
        "recommended to scale the zero point, so this will be removed "
        "after v29.",
    )
    scale_zero_point = ConfigurableField(
        target=ScaleZeroPointTask,
        doc="Task to scale warps to a common zero point",
        deprecated="Now that visits are scaled to nJy it is no longer necessary or "
        "recommended to scale the zero point, so this will be removed "
        "after v29.",
    )
    do_calculate_weight_from_warp = Field[bool](
        doc="Calculate coadd weight from the input warp? Otherwise, the weight is obtained from the "
        "visitSummaryList connection. This is meant as a fallback when run outside the pipeline.",
        default=False,
    )
    do_use_artifact_mask = Field[bool](
        doc="Substitute the mask planes input warp with an alternative artifact mask?",
        default=True,
    )
    do_coadd_inverse_aperture_corrections = Field[bool](
        doc="Coadd the inverse aperture corrections for each cell? This is formally the more accurate way "
        "but may be turned off for parity with deepCoadd.",
        default=False,
    )
    min_overlap_fraction = RangeField[float](
        doc="The minimum overlap fraction required for a single (visit, detector) input to be included in a "
        "cell.",
        # A value of 1.0 corresponds to ideal, edge-free cells.
        # A value of 0.0 corresponds to the deep_coadd style coadds.
        # This has to be at least 0.5 to ensure that an input overlaps the
        # cell center. Inputs with an overlap fraction less than 0.25 will
        # definitely not overlap the cell center.
        default=1.0,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=True,
    )
    bad_mask_planes = ListField[str](
        doc="Mask planes that count towards the masked fraction within a cell.",
        default=("BAD", "NO_DATA", "SAT", "CLIPPED"),
    )
    remove_mask_planes = ListField[str](
        doc="Mask planes to remove before coadding",
        default=["EDGE", "NOT_DEBLENDED"],
    )
    calc_error_from_input_variance = Field[bool](
        doc="Calculate coadd variance from input variance by stacking "
        "statistic. Passed to AccumulatorMeanStack.",
        default=True,
    )
    mask_propagation_thresholds = DictField[str, float](
        doc=(
            "Threshold (in fractional weight) of rejection at which we "
            "propagate a mask plane to the coadd; that is, we set the mask "
            "bit on the coadd if the fraction the rejected frames "
            "would have contributed exceeds this value."
        ),
        default={"SAT": 0.1},
    )
    max_maskfrac = RangeField[float](
        doc="Maximum fraction of masked pixels in a cell. This is currently "
        "just a placeholder and is not used now",
        default=0.99,
        min=0.0,
        max=1.0,
        inclusiveMin=True,
        inclusiveMax=False,
    )
    num_noise_realizations = Field[int](
        default=0,
        doc=(
            "Number of noise planes to include in the coadd. "
            "This should not exceed the corresponding config parameter "
            "specified in `MakeDirectWarpConfig`. "
        ),
        check=lambda x: x >= 0,
    )
    psf_warper = ConfigField(
        doc="Configuration for the warper that warps the PSFs. It must have the same configuration used to "
        "warp the images.",
        dtype=afwMath.Warper.ConfigClass,
    )
    psf_dimensions = Field[int](
        default=35,
        doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
        check=lambda x: (x > 0) and (x % 2 == 1),
    )
    require_artifact_mask = Field[bool](
        default=True,
        doc="Require presence of artifact mask for each warp? Use true if using artifact rejection outputs"
        " from CompareWarpTask",
    )
    do_input_map = Field[bool](
        default=False,
        doc="Create a bitwise map of coadd inputs.",
    )
    input_mapper = ConfigurableField(
        target=HealSparseInputMapTask,
        doc="Input map creation subtask.",
    )

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        ValueError
            Raised if ``require_artifact_mask`` is True while
            ``do_use_artifact_mask`` is False. With that combination the
            ``artifactMasks`` connection is removed, so no warp ever
            receives an artifact mask when run through ``runQuantum``,
            and every warp would be silently excluded from the coadd.
        """
        super().validate()
        if self.require_artifact_mask and not self.do_use_artifact_mask:
            raise ValueError(
                "require_artifact_mask=True is incompatible with do_use_artifact_mask=False: "
                "artifact masks are not read when do_use_artifact_mask is False, so every "
                "input warp would be excluded from the cell coadd. Set require_artifact_mask=False "
                "or do_use_artifact_mask=True."
            )
307class AssembleCellCoaddTask(PipelineTask):
308 """Assemble a cell-based coadded image from a set of warps.
310 This task reads in the warp one at a time, and accumulates it in all the
311 cells that it completely overlaps with. This is the optimal I/O pattern but
312 this also implies that it is not possible to build one or only a few cells.
314 Each cell coadds is guaranteed to have a well-defined PSF. This is done by
315 1) excluding warps that only partially overlap a cell from that cell coadd;
316 2) interpolating bad pixels in the warps rather than excluding them;
317 3) by computing the coadd as a weighted mean of the warps without clipping;
318 4) by computing the coadd PSF as the weighted mean of the PSF of the warps
319 with the same weights.
321 The cells are (and must be) defined in the skymap, and cannot be configured
322 or redefined here. The cells are assumed to be small enough that the PSF is
323 assumed to be spatially constant within a cell.
325 Raises
326 ------
327 NoWorkFound
328 Raised if no input warps are provided, or no cells could be populated.
329 RuntimeError
330 Raised if the skymap is not cell-based.
332 Notes
333 -----
334 This is not yet a part of the standard DRP pipeline. As such, the Task and
335 especially its Config and Connections are experimental and subject to
336 change any time without a formal RFC or standard deprecation procedures
337 until it is included in the DRP pipeline.
338 """
340 ConfigClass = AssembleCellCoaddConfig
341 _DefaultName = "assembleCellCoadd"
343 def __init__(self, *args, **kwargs):
344 super().__init__(*args, **kwargs)
345 if self.config.do_interpolate_coadd:
346 self.makeSubtask("interpolate_coadd")
347 # Suppress the warning message about fallback.
348 self.interpolate_coadd.log.setLevel(logging.ERROR)
349 if self.config.do_scale_zero_point:
350 self.makeSubtask("scale_zero_point")
351 if self.config.do_input_map:
352 self.makeSubtask("input_mapper")
354 self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper)
355 if (warping_kernel_name := self.config.psf_warper.warpingKernelName.lower()).startswith("lanczos"):
356 psf_padding = 2 * int(warping_kernel_name.lstrip("lanczos")) - 1
357 self.log.debug(
358 "Padding PSF image by %d pixels since the warping kernel is %s.",
359 psf_padding,
360 self.config.psf_warper.warpingKernelName,
361 )
362 else:
363 psf_padding = 10
364 self.log.info(
365 "Padding PSF image by %d pixels since the warping kernel is not Lanczos.",
366 psf_padding,
367 )
368 self.psf_padding = psf_padding
370 def runQuantum(self, butlerQC, inputRefs, outputRefs):
371 # Docstring inherited.
372 if not inputRefs.inputWarps:
373 raise NoWorkFound("No input warps provided for co-addition")
374 self.log.info("Found %d input warps", len(inputRefs.inputWarps))
376 # Construct skyInfo expected by run
377 # Do not remove skyMap from inputData in case _makeSupplementaryData
378 # needs it
379 skyMap = butlerQC.get(inputRefs.skyMap)
381 if not skyMap.config.tractBuilder.name == "cells":
382 raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.")
384 outputDataId = butlerQC.quantum.dataId
386 skyInfo = makeSkyInfo(skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"])
387 visitSummaryList = butlerQC.get(getattr(inputRefs, "visitSummaryList", []))
389 units = CoaddUnits.legacy if self.config.do_scale_zero_point else CoaddUnits.nJy
390 self.common = CommonComponents(
391 units=units,
392 wcs=skyInfo.patchInfo.wcs,
393 band=outputDataId.get("band", None),
394 identifiers=PatchIdentifiers.from_data_id(outputDataId),
395 )
397 inputs: dict[DataCoordinate, WarpInputs] = {}
398 for handle in butlerQC.get(inputRefs.inputWarps):
399 inputs[handle.dataId] = WarpInputs(warp=handle, noise_warps=[])
401 for ref in getattr(inputRefs, "artifactMasks", []):
402 inputs[ref.dataId].artifact_mask = butlerQC.get(ref)
403 for ref in getattr(inputRefs, "maskedFractionWarps", []):
404 inputs[ref.dataId].masked_fraction = butlerQC.get(ref)
405 for n in range(self.config.num_noise_realizations):
406 for ref in getattr(inputRefs, f"noise{n}_warps"):
407 inputs[ref.dataId].noise_warps.append(butlerQC.get(ref))
409 returnStruct = self.run(inputs=inputs, skyInfo=skyInfo, visitSummaryList=visitSummaryList)
410 butlerQC.put(returnStruct, outputRefs)
411 return returnStruct
413 @staticmethod
414 def _compute_weight(maskedImage, statsCtrl):
415 """Compute a weight for a masked image.
417 Parameters
418 ----------
419 maskedImage : `~lsst.afw.image.MaskedImage`
420 The masked image to compute the weight.
421 statsCtrl : `~lsst.afw.math.StatisticsControl`
422 A control (config-like) object for StatisticsStack.
424 Returns
425 -------
426 weight : `float`
427 Inverse of the clipped mean variance of the masked image.
428 """
429 statObj = afwMath.makeStatistics(
430 maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl
431 )
432 meanVar, _ = statObj.getResult(afwMath.MEANCLIP)
433 weight = 1.0 / float(meanVar)
434 return weight
436 @staticmethod
437 def _construct_grid(skyInfo):
438 """Construct a UniformGrid object from a SkyInfo struct.
440 Parameters
441 ----------
442 skyInfo : `~lsst.pipe.base.Struct`
443 A Struct object
445 Returns
446 -------
447 grid : `~lsst.cell_coadds.UniformGrid`
448 A UniformGrid object.
449 """
450 padding = skyInfo.patchInfo.getCellBorder()
451 grid_bbox = skyInfo.patchInfo.outer_bbox.erodedBy(padding)
452 grid = UniformGrid.from_bbox_cell_size(
453 grid_bbox,
454 skyInfo.patchInfo.getCellInnerDimensions(),
455 padding=padding,
456 )
457 return grid
459 def _construct_grid_container(self, skyInfo, statsCtrl):
460 """Construct a grid of AccumulatorMeanStack instances.
462 Parameters
463 ----------
464 skyInfo : `~lsst.pipe.base.Struct`
465 A Struct object
466 statsCtrl : `~lsst.afw.math.StatisticsControl`
467 A control (config-like) object for StatisticsStack.
469 Returns
470 -------
471 gc : `~lsst.cell_coadds.GridContainer`
472 A GridContainer object container one AccumulatorMeanStack per cell.
473 """
474 grid = self._construct_grid(skyInfo)
476 maskMap = setRejectedMaskMapping(statsCtrl)
477 self.log.debug("Obtained maskMap = %s for %s", maskMap, skyInfo.patchInfo)
478 thresholdDict = AccumulatorMeanStack.stats_ctrl_to_threshold_dict(statsCtrl)
480 # Initialize the grid container with AccumulatorMeanStacks
481 gc = GridContainer[AccumulatorMeanStack](grid.shape)
482 for cellInfo in skyInfo.patchInfo:
483 stacker = AccumulatorMeanStack(
484 # The shape is for the numpy arrays, hence transposed.
485 shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width),
486 bit_mask_value=statsCtrl.getAndMask(),
487 mask_threshold_dict=thresholdDict,
488 calc_error_from_input_variance=self.config.calc_error_from_input_variance,
489 compute_n_image=False,
490 mask_map=maskMap,
491 no_good_pixels_mask=statsCtrl.getNoGoodPixelsMask(),
492 )
493 gc[cellInfo.index] = stacker
495 return gc
497 def _construct_stats_control(self):
498 """Construct a StatisticsControl object for coadd.
500 Unlike AssembleCoaddTask or CompareWarpAssembleCoaddTask, there is
501 very little to be configured apart from setting the mask planes and
502 optionally mask propagation thresholds.
504 Returns
505 -------
506 statsCtrl : `~lsst.afw.math.StatisticsControl`
507 A control object for StatisticsStack.
508 """
509 statsCtrl = afwMath.StatisticsControl()
510 # Hardcode the numIter parameter to the default config value set in
511 # CompareWarpAssembleCoaddTask to get consistent weights. This is NOT
512 # exposed as a config parameter, since this is only meant to be a
513 # fallback option that is not recommended for production.
514 statsCtrl.setNumIter(2)
515 statsCtrl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes))
516 statsCtrl.setNanSafe(True)
517 for plane, threshold in self.config.mask_propagation_thresholds.items():
518 bit = afwImage.Mask.getMaskPlane(plane)
519 statsCtrl.setMaskPropagationThreshold(bit, threshold)
520 return statsCtrl
522 def _construct_ap_corr_grid_container(self, skyInfo):
523 """Construct a grid of CoaddApCorrMapStacker instances.
525 Parameters
526 ----------
527 skyInfo : `~lsst.pipe.base.Struct`
528 A Struct object
530 Returns
531 -------
532 gc : `~lsst.cell_coadds.GridContainer`
533 A GridContainer object container one CoaddApCorrMapStacker per
534 cell.
535 """
536 grid = self._construct_grid(skyInfo)
538 # Initialize the grid container with CoaddApCorrMapStacker.
539 gc = GridContainer[CoaddApCorrMapStacker](grid.shape)
540 for cellInfo in skyInfo.patchInfo:
541 stacker = CoaddApCorrMapStacker(
542 evaluation_point=cellInfo.inner_bbox.getCenter(),
543 do_coadd_inverse_ap_corr=self.config.do_coadd_inverse_aperture_corrections,
544 )
545 gc[cellInfo.index] = stacker
547 return gc
549 def run(
550 self,
551 *,
552 inputs: dict[DataCoordinate, WarpInputs],
553 skyInfo,
554 visitSummaryList: list | None = None,
555 ):
556 for mask_plane in self.config.bad_mask_planes:
557 afwImage.Mask.addMaskPlane(mask_plane)
558 for mask_plane in self.config.mask_propagation_thresholds:
559 afwImage.Mask.addMaskPlane(mask_plane)
561 statsCtrl = self._construct_stats_control()
563 warp_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl)
564 maskfrac_stacker_gc = self._construct_grid_container(skyInfo, statsCtrl)
565 noise_stacker_gc_list = [
566 self._construct_grid_container(skyInfo, statsCtrl)
567 for n in range(self.config.num_noise_realizations)
568 ]
569 psf_stacker_gc = GridContainer[AccumulatorMeanStack](warp_stacker_gc.shape)
570 psf_bbox_gc = GridContainer[geom.Box2I](warp_stacker_gc.shape)
571 ap_corr_stacker_gc = self._construct_ap_corr_grid_container(skyInfo)
573 # Make a container to hold the cell centers in sky coordinates now,
574 # so we don't have to recompute them for each warp
575 # (they share a common WCS). These are needed to find the various
576 # warp + detector combinations that contributed to each cell, and later
577 # get the corresponding PSFs as well.
578 cell_centers_sky = GridContainer[geom.SpherePoint](warp_stacker_gc.shape)
579 # Make a container to hold the observation identifiers for each cell.
580 observation_identifiers_gc = GridContainer[dict](warp_stacker_gc.shape)
582 if self.config.do_input_map:
583 # We need to know all the visit + detector pairs in the inputs.
584 warp_input_list = [warp_ref.warp.get(component="coaddInputs") for warp_ref in inputs.values()]
585 visit_detectors = []
586 for warp_input in warp_input_list:
587 for row in warp_input.ccds:
588 visit_detectors.append((int(row["visit"]), int(row["ccd"])))
590 self.input_mapper.initialize_cell_input_map(
591 skyInfo.patchInfo.getOuterBBox(),
592 skyInfo.patchInfo.wcs,
593 visit_detectors,
594 )
596 # Populate them.
597 for cellInfo in skyInfo.patchInfo:
598 # Make a list to hold the observation identifiers for each cell.
599 observation_identifiers_gc[cellInfo.index] = {}
600 cell_center_pixel = geom.Point2D(geom.Point2I(cellInfo.inner_bbox.getCenter()))
601 cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cell_center_pixel)
602 psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox(
603 cell_center_pixel,
604 geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions),
605 )
606 psf_stacker_gc[cellInfo.index] = AccumulatorMeanStack(
607 # The shape is for the numpy arrays, hence transposed.
608 shape=(self.config.psf_dimensions, self.config.psf_dimensions),
609 bit_mask_value=0,
610 calc_error_from_input_variance=self.config.calc_error_from_input_variance,
611 compute_n_image=False,
612 )
614 if self.config.do_input_map:
615 self.input_mapper.build_cell_input_map(cellInfo)
617 # visit_summary do not have (tract, patch, band, skymap) dimensions.
618 if not visitSummaryList:
619 visitSummaryList = []
620 visitSummaryRefDict = {
621 visitSummaryRef.dataId["visit"]: visitSummaryRef for visitSummaryRef in visitSummaryList
622 }
624 # Keep track of the polygons corresponding to each (visit, detector).
625 visit_polygons: dict[ObservationIdentifiers, afwGeom.Polygon] = {}
627 # Read in one warp at a time, and accumulate it in all the cells that
628 # it completely overlaps.
629 for _, warp_input in inputs.items():
630 # warps that have been excluded from CompareWarp via visit
631 # selection from SelectVisitsTasks will not have artifact masks.
632 # Exclude them from the cell coadds too.
633 if self.config.require_artifact_mask and warp_input.artifact_mask is None:
634 self.log.info(
635 "Excluding warp %s from cell coadds because it has no artifact mask",
636 warp_input.dataId["visit"],
637 )
638 continue
640 warp = warp_input.warp.get(parameters={"bbox": skyInfo.bbox})
641 masked_fraction_image = (
642 warp_input.masked_fraction.get(parameters={"bbox": skyInfo.bbox})
643 if warp_input.masked_fraction
644 else None
645 )
647 # Pre-process the warp before coadding.
648 # TODO: Can we get these mask names from artifactMask?
649 warp.mask.addMaskPlane("CLIPPED")
650 warp.mask.addMaskPlane("REJECTED")
651 warp.mask.addMaskPlane("SENSOR_EDGE")
652 warp.mask.addMaskPlane("INEXACT_PSF")
654 if artifact_mask_ref := warp_input.artifact_mask:
655 # Apply the artifact mask to the warp.
656 artifact_mask = artifact_mask_ref.get()
657 assert (
658 warp.mask.getMaskPlaneDict() == artifact_mask.getMaskPlaneDict()
659 ), "Mask dicts do not agree."
660 warp.mask.array = artifact_mask.array
661 del artifact_mask
663 if self.config.do_scale_zero_point:
664 # Each Warp that goes into a coadd will typically have an
665 # independent photometric zero-point. Therefore, we must scale
666 # each Warp to set it to a common photometric zeropoint.
667 imageScaler = self.scale_zero_point.run(exposure=warp, dataRef=warp_input.warp).imageScaler
668 zero_point_scale_factor = imageScaler.scale
669 self.log.debug(
670 "Scaled the warp %s by %f to match zero points",
671 warp_input.dataId,
672 zero_point_scale_factor,
673 )
674 else:
675 zero_point_scale_factor = 1.0
676 if "BUNIT" not in warp.metadata:
677 raise ValueError(f"Warp {warp_input.dataId} has no BUNIT metadata")
678 if warp.metadata["BUNIT"] != "nJy":
679 raise ValueError(
680 f"Warp {warp_input.dataId} has BUNIT {warp.metadata['BUNIT']}, expected nJy"
681 )
683 # Only try to remove maks planes that have been registered.
684 to_remove = []
685 for plane in self.config.remove_mask_planes:
686 if plane in warp.mask.getMaskPlaneDict():
687 to_remove.append(plane)
688 removeMaskPlanes(warp.mask, to_remove, self.log)
689 # Instead of using self.config.bad_mask_planes, we explicitly
690 # ask statsCtrl which pixels are going to be ignored/rejected.
691 rejected = afwImage.Mask.getPlaneBitMask(
692 ["CLIPPED", "REJECTED"] + afwImage.Mask.interpret(statsCtrl.getAndMask()).split(",")
693 )
695 # Compute the weight for each CCD in the warp from the visitSummary
696 # or from the warp itself, if not provided. Computing the weight
697 # from the warp is not recommended, and in that case we compute one
698 # weight per warp and not bother with per-detector weights.
699 full_ccd_table = warp.getInfo().getCoaddInputs().ccds
700 weights: dict[int, float] = dict.fromkeys(
701 full_ccd_table["ccd"].tolist(),
702 0.0,
703 ) # Mapping from detector to weight.
705 if visitSummaryRef := visitSummaryRefDict.get(warp_input.dataId["visit"]):
706 visitSummary = visitSummaryRef.get()
707 for detector in full_ccd_table["ccd"].tolist():
708 visitSummaryRow = visitSummary.find(detector)
709 mean_variance = visitSummaryRow["meanVar"]
710 mean_variance *= zero_point_scale_factor**2
711 if warp.metadata.get("BUNIT", None) == "nJy":
712 mean_variance *= visitSummaryRow.photoCalib.getCalibrationMean() ** 2
713 weights[detector] = 1.0 / mean_variance
714 del visitSummary
715 else:
716 self.log.debug("No visit summary found for %s; using warp-based weights", warp_input.dataId)
717 weight = self._compute_weight(warp, statsCtrl)
718 if not np.isfinite(weight):
719 self.log.warn("Non-finite weight for %s: skipping", warp_input.dataId)
720 continue
722 for detector in weights:
723 weights[detector] = weight
725 noise_warps = [ref.get(parameters={"bbox": skyInfo.bbox}) for ref in warp_input.noise_warps]
727 # Create an image where each pixel value corresponds to the
728 # detector ID that pixel comes from.
729 detector_map = afwImage.ImageI(bbox=warp.getBBox(), initialValue=-1)
730 for row in full_ccd_table:
731 transform = makeWcsPairTransform(row.wcs, warp.wcs)
732 if (src_polygon := row.validPolygon) is None:
733 src_polygon = afwGeom.Polygon(geom.Box2D(row.getBBox()))
734 try:
735 dest_polygon = src_polygon.transform(transform).intersectionSingle(
736 geom.Box2D(warp.getBBox())
737 )
738 except SinglePolygonException:
739 continue
741 observation_identifier = ObservationIdentifiers.from_data_id(
742 warp_input.dataId,
743 backup_detector=row["ccd"],
744 )
745 visit_polygons[observation_identifier] = dest_polygon
747 detector_map_slice = dest_polygon.createImage(detector_map.getBBox()).array > 0
748 if not (detector_map.array[detector_map_slice] < 0).all():
749 self.log.warning("Multiple detectors from visit %s are overlapping", warp_input.dataId)
750 detector_map.array[detector_map_slice] = row["ccd"]
752 if (detector_map.array < 0).all():
753 self.log.warning("Unable to split the warp %s into single-detector warps.", warp_input.dataId)
754 detector_map.array[:, :] = 0
756 for cellInfo, ccd_row in itertools.product(skyInfo.patchInfo, full_ccd_table):
757 bbox = cellInfo.outer_bbox
758 inner_bbox = cellInfo.inner_bbox
760 overlap_fraction = (detector_map[inner_bbox].array == ccd_row["ccd"]).mean()
761 assert -1e-4 < overlap_fraction < 1.0001, "Overlap fraction is not within [0, 1]."
762 if (overlap_fraction < self.config.min_overlap_fraction) or (overlap_fraction <= 0.0):
763 self.log.debug(
764 "Skipping %s in cell %s because it had only %.3f < %.3f fractional overlap.",
765 warp_input.dataId,
766 cellInfo.index,
767 overlap_fraction,
768 self.config.min_overlap_fraction,
769 )
770 continue
772 weight = weights[int(ccd_row["ccd"])]
773 if not np.isfinite(weight):
774 self.log.warn(
775 "Non-finite weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index
776 )
777 continue
779 if weight == 0:
780 self.log.info(
781 "Zero weight for %s in cell %s: skipping", warp_input.dataId, cellInfo.index
782 )
783 continue
785 # Decide if a deep copy is necessary to apply the single
786 # detector cuts since it involves modifying the image in-place.
787 # If within the inner cell, there are three or more different
788 # values that detector map takes, then there are definitely
789 # multiple detectors (one for chip gaps, two for two detectors)
790 deep_copy = len(set(detector_map[inner_bbox].array.ravel())) >= 3
791 if deep_copy:
792 single_detector_mask_array = detector_map[bbox].array != ccd_row["ccd"]
794 mi = afwImage.MaskedImageF(warp[bbox].maskedImage, deep=deep_copy)
795 if deep_copy:
796 mi.image.array[single_detector_mask_array] = 0.0
797 mi.variance.array[single_detector_mask_array] = np.inf
798 nodata_or_mask = (single_detector_mask_array) * afwImage.Mask.getPlaneBitMask("NO_DATA")
799 mi.mask[bbox].array |= nodata_or_mask
800 warp_stacker_gc[cellInfo.index].add_masked_image(mi, weight=weight)
802 if masked_fraction_image:
803 mi = afwImage.ImageF(masked_fraction_image[bbox], deep=deep_copy)
804 if deep_copy:
805 mi.array[single_detector_mask_array] = 0.0
806 maskfrac_stacker_gc[cellInfo.index].add_image(masked_fraction_image[bbox], weight=weight)
808 for n in range(self.config.num_noise_realizations):
809 mi = afwImage.MaskedImageF(noise_warps[n][bbox], deep=deep_copy)
810 if deep_copy:
811 mi.image.array[single_detector_mask_array] = 0.0
812 mi.variance.array[single_detector_mask_array] = np.inf
813 mi.mask[bbox].array |= nodata_or_mask
814 noise_stacker_gc_list[n][cellInfo.index].add_masked_image(mi, weight=weight)
816 # Set the defaults for PSF shape quantities.
817 psf_shape = afwGeom.Quadrupole()
818 psf_shape_flag = True
819 psf_eval_point = None
820 try:
821 # The `if` branch is buggy. `dest_polygon` is technically
822 # out of scope, but Python does not raise an error.
823 # TODO: Fix this properly in DM-53479, but sweep it under
824 # the rug for now.
825 if overlap_fraction < 0.5:
826 psf_eval_point = dest_polygon.intersectionSingle(
827 geom.Box2D(inner_bbox)
828 ).calculateCenter()
829 else:
830 psf_eval_point = geom.Point2D(geom.Point2I(inner_bbox.getCenter()))
831 psf_shape = warp.psf.computeShape(psf_eval_point)
832 psf_shape_flag = False
833 except SinglePolygonException:
834 self.log.info(
835 "Unable to find the overlapping polygon between %d detector in %s and cell %s",
836 ccd_row["ccd"],
837 warp_input.dataId,
838 cellInfo.index,
839 )
840 except InvalidPsfError:
841 self.log.info(
842 "Unable to compute PSF shape from %d detector in %s at %s",
843 ccd_row["ccd"],
844 warp_input.dataId,
845 psf_eval_point,
846 )
848 overlaps_center = detector_map[geom.Point2I(bbox.getCenter())] == ccd_row["ccd"]
850 observation_identifier = ObservationIdentifiers.from_data_id(
851 warp_input.dataId,
852 backup_detector=int(ccd_row["ccd"]),
853 )
854 observation_identifiers_gc[cellInfo.index][observation_identifier] = CoaddInputs(
855 overlaps_center=overlaps_center,
856 overlap_fraction=overlap_fraction,
857 weight=weight,
858 psf_shape=psf_shape,
859 psf_shape_flag=psf_shape_flag,
860 )
861 if overlaps_center is False:
862 self.log.debug(
863 "%s does not overlap with the center of the cell %s",
864 warp_input.dataId,
865 cellInfo.index,
866 )
867 continue
869 # Everything below this has to do with the center of the cell
870 calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index])
871 undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point)
873 assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox(
874 calexp_point,
875 undistorted_psf_im.getDimensions(),
876 ), "PSF image does not share the coordinates of the 'calexp'"
878 # Convert the PSF image from Image to MaskedImage and
879 # zero-pad the image.
880 undistorted_psf_bbox = undistorted_psf_im.getBBox()
881 undistorted_psf_maskedImage = afwImage.MaskedImageD(
882 undistorted_psf_bbox.dilatedBy(self.psf_padding)
883 )
884 undistorted_psf_maskedImage.image[undistorted_psf_bbox].array[:, :] = undistorted_psf_im.array
885 # TODO: In DM-43585, use the variance plane value from noise.
886 undistorted_psf_maskedImage.variance += 1.0 # Set variance to 1
888 warped_psf_maskedImage = self.psf_warper.warpImage(
889 destWcs=skyInfo.wcs,
890 srcImage=undistorted_psf_maskedImage,
891 srcWcs=ccd_row.getWcs(),
892 destBBox=psf_bbox_gc[cellInfo.index],
893 )
895 # There may be NaNs in the PSF image. Set them to 0.0
896 warped_psf_maskedImage.variance.array[np.isnan(warped_psf_maskedImage.image.array)] = 1.0
897 warped_psf_maskedImage.image.array[np.isnan(warped_psf_maskedImage.image.array)] = 0.0
899 psf_stacker = psf_stacker_gc[cellInfo.index]
900 psf_stacker.add_masked_image(warped_psf_maskedImage, weight=weight)
902 if not (0.995 < (psf_normalization := warped_psf_maskedImage.image.array.sum()) < 1.005):
903 self.log.warning(
904 "PSF image for %s in %s is not normalized to 1.0, but instead %f",
905 warp_input.dataId,
906 cellInfo.index,
907 psf_normalization,
908 )
910 if (ap_corr_map := warp.getInfo().getApCorrMap()) is not None:
911 ap_corr_stacker_gc[cellInfo.index].add(ap_corr_map, weight=weight)
913 if self.config.do_input_map:
914 self.input_mapper.add_warp_to_cell_input_map(
915 ccd_row,
916 weight,
917 cellInfo,
918 )
920 del warp
922 if self.config.do_input_map:
923 inputMap = self.input_mapper.cell_input_map
924 else:
925 inputMap = None
927 # Update common with the visit polygons.
928 self.common = dataclasses.replace(
929 self.common,
930 visit_polygons=visit_polygons,
931 )
933 cells: list[SingleCellCoadd] = []
934 for cellInfo in skyInfo.patchInfo:
935 if len(observation_identifiers_gc[cellInfo.index]) == 0:
936 self.log.debug("Skipping cell %s because it has no input warps", cellInfo.index)
937 continue
939 cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox)
940 cell_maskfrac_image = afwImage.ImageF(cellInfo.outer_bbox)
941 cell_noise_images = [
942 afwImage.MaskedImageF(cellInfo.outer_bbox) for n in range(self.config.num_noise_realizations)
943 ]
944 psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index])
946 warp_stacker_gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image)
947 maskfrac_stacker_gc[cellInfo.index].fill_stacked_image(cell_maskfrac_image)
948 for n in range(self.config.num_noise_realizations):
949 noise_stacker_gc_list[n][cellInfo.index].fill_stacked_masked_image(cell_noise_images[n])
950 psf_stacker_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image)
952 if ap_corr_stacker_gc[cellInfo.index].ap_corr_names:
953 ap_corr_map = ap_corr_stacker_gc[cellInfo.index].final_ap_corr_map
954 else:
955 ap_corr_map = None
957 # Post-process the coadd before converting to new data structures.
958 if np.isnan(cell_masked_image.image.array).all():
959 cell_masked_image.image.array[:, :] = 0.0
960 cell_masked_image.variance.array[:, :] = np.inf
961 elif self.config.do_interpolate_coadd:
962 self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA")
963 for noise_image in cell_noise_images:
964 self.interpolate_coadd.run(noise_image, planeName="NO_DATA")
965 # The variance must be positive; work around for DM-3201.
966 varArray = cell_masked_image.variance.array
967 with np.errstate(invalid="ignore"):
968 varArray[:] = np.where(varArray > 0, varArray, np.inf)
970 afwImage.Mask.addMaskPlane("INEXACT_PSF")
971 cell_masked_image.mask.array[
972 (cell_masked_image.mask.array & rejected) > 0
973 ] |= cell_masked_image.mask.getPlaneBitMask("INEXACT_PSF")
975 image_planes = OwnedImagePlanes.from_masked_image(
976 masked_image=cell_masked_image,
977 mask_fractions=cell_maskfrac_image,
978 noise_realizations=[noise_image.image for noise_image in cell_noise_images],
979 )
980 identifiers = CellIdentifiers(
981 cell=cellInfo.index,
982 skymap=self.common.identifiers.skymap,
983 tract=self.common.identifiers.tract,
984 patch=self.common.identifiers.patch,
985 band=self.common.identifiers.band,
986 )
988 singleCellCoadd = SingleCellCoadd(
989 outer=image_planes,
990 psf=psf_masked_image.image,
991 inner_bbox=cellInfo.inner_bbox,
992 inputs=observation_identifiers_gc[cellInfo.index],
993 common=self.common,
994 identifiers=identifiers,
995 aperture_correction_map=ap_corr_map,
996 )
997 # TODO: Attach transmission curve when they become available.
998 cells.append(singleCellCoadd)
1000 if not cells:
1001 raise NoWorkFound("No cells could be populated for the cell coadd.")
1003 grid = self._construct_grid(skyInfo)
1004 multipleCellCoadd = MultipleCellCoadd(
1005 cells,
1006 grid=grid,
1007 outer_cell_size=cellInfo.outer_bbox.getDimensions(),
1008 inner_bbox=None,
1009 common=self.common,
1010 psf_image_size=cells[0].psf_image.getDimensions(),
1011 )
1013 return Struct(
1014 multipleCellCoadd=multipleCellCoadd,
1015 inputMap=inputMap,
1016 )
class ConvertMultipleCellCoaddToExposureConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "band", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "inputCoaddSuffix": "Cell"},
):
    """Connections for `ConvertMultipleCellCoaddToExposureTask`.

    Reads one `MultipleCellCoadd` per (tract, patch, skymap, band) and writes
    the stitched `ExposureF` for the same data ID. The output dataset type
    name is the input name with a ``_stitched`` suffix.
    """

    cellCoaddExposure = Input(
        doc="Input cell-based coadd, produced by stacking input warps",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    stitchedCoaddExposure = Output(
        doc="Output coadded exposure, stitched together from the inner regions of the input cell-based coadd",
        name="{inputCoaddName}Coadd{inputCoaddSuffix}_stitched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
class ConvertMultipleCellCoaddToExposureConfig(
    PipelineTaskConfig,
    pipelineConnections=ConvertMultipleCellCoaddToExposureConnections,
):
    """Configuration for `ConvertMultipleCellCoaddToExposureTask`.

    This class declares no fields of its own. It only binds the task to
    `ConvertMultipleCellCoaddToExposureConnections`.
    """
class ConvertMultipleCellCoaddToExposureTask(PipelineTask):
    """An afterburner PipelineTask that converts a cell-based coadd from
    `MultipleCellCoadd` format to `ExposureF` format.

    The run method stitches the cell-based coadd into a contiguous exposure
    and returns it as an `Exposure` object. This is lossy: it preserves only
    the pixels in the inner bounding box of each cell and discards the
    values in the buffer region.

    Notes
    -----
    This task has no configurable parameters.
    """

    ConfigClass = ConvertMultipleCellCoaddToExposureConfig
    _DefaultName = "convertMultipleCellCoaddToExposure"

    def run(self, cellCoaddExposure: MultipleCellCoadd) -> Struct:
        """Stitch a cell-based coadd into a single exposure.

        Parameters
        ----------
        cellCoaddExposure : `MultipleCellCoadd`
            Cell-based coadd to convert.

        Returns
        -------
        result : `Struct`
            Results struct with attribute:

            ``stitchedCoaddExposure``
                The stitched coadd, as an `Exposure`.
        """
        # stitch() assembles the cells' inner regions; asExposure() converts
        # that stitched representation into an afw Exposure.
        return Struct(
            stitchedCoaddExposure=cellCoaddExposure.stitch().asExposure(),
        )