Coverage for python / lsst / pipe / tasks / prettyPictureMaker / _task.py: 21%
427 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:40 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:40 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = (
    "ChannelRGBConfig",
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
    "PrettyPictureBackgroundFixerConfig",
    "PrettyPictureBackgroundFixerTask",
    "PrettyPictureStarFixerConfig",
    "PrettyPictureStarFixerTask",
)
38import colour
39import copy
40from collections.abc import Iterable, Mapping
41from lsst.afw.image import ExposureF
42import numpy as np
43from typing import TYPE_CHECKING, cast, Any
44from lsst.skymap import BaseSkyMap
46from scipy.stats import halfnorm, mode
47from scipy.ndimage import binary_dilation
48from scipy.interpolate import RBFInterpolator
49from skimage.restoration import inpaint_biharmonic
51from lsst.daf.butler import Butler, DeferredDatasetHandle
52from lsst.daf.butler import DatasetRef
53from lsst.images import ColorImage, Projection, Box, TractFrame
54from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField
55from lsst.pex.config.configurableActions import ConfigurableActionField
56from lsst.pipe.base import (
57 PipelineTask,
58 PipelineTaskConfig,
59 PipelineTaskConnections,
60 Struct,
61 InMemoryDatasetHandle,
62 NoWorkFound,
63)
64from lsst.rubinoxide import rbf_interpolator
65import cv2
67from lsst.pipe.base.connectionTypes import Input, Output
68from lsst.geom import Box2I, Point2I, Extent2I
69from lsst.afw.image import Exposure, Mask
71from ._plugins import plugins
72from ._colorMapper import lsstRGB
73from ._utils import FeatheredMosaicCreator
74from ._functors import (
75 BoundsRemapper,
76 ColorScaler,
77 LumCompressor,
78 ExposureBracketer,
79 GamutFixer,
80 LocalContrastEnhancer,
81)
83import tempfile
86if TYPE_CHECKING:
87 from numpy.typing import NDArray
88 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
89 from lsst.skymap import TractInfo, PatchInfo
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    One coadd per band of a single patch is read, together with the skymap
    (used to build the output projection), and an RGB image plus the fused
    mask of the inputs are written.
    """

    inputCoadds = Input(
        doc=(
            "Per-band coadds of a single patch that are mixed into an RGB image according to "
            "PrettyPictureConfig.channelConfig"
        ),
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
class ChannelRGBConfig(Config):
    """Weights giving how much of one input band goes into each RGB channel.

    A band rendered as pure red would use ``r=1, g=0, b=0``; a band rendered
    as cyan would use ``r=0, g=1, b=1``.
    """

    r = Field[float](
        doc="The amount of red contained in this channel",
    )
    g = Field[float](
        doc="The amount of green contained in this channel",
    )
    b = Field[float](
        doc="The amount of blue contained in this channel",
    )
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`."""

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The dataset type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=10,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lay outside RGB color gamut"
    )

    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        """Set the default band-to-color mapping (i -> red, r -> green, g -> blue)."""
        super().setDefaults()
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)

    def _handle_deprecated(self):
        """Handle deprecated configuration migration.

        A field is treated as explicitly set when its history holds more than
        one entry, i.e. it was assigned after its default was recorded. Only
        explicitly set values are copied, so untouched deprecated fields never
        override the new locations.

        Notes
        -----
        The following values are migrated (source -> destination):

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``;
          an explicit ``None`` additionally sets ``doExposureBrackets`` to
          `False`.
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        if len(self._history["gamutMethod"]) > 1:
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets
            # An explicit None used to mean "do not bracket exposures".
            if self.exposureBrackets is None:
                self.doExposureBrackets = False

        # The task reads the top-level flag, so a value set on the action is
        # copied up to it.
        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Migrate the misspelled PSF deconvolution flag.
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        """Migrate deprecated fields, then freeze the config.

        The migration only runs the first time, because a frozen config can
        no longer be assigned to.
        """
        if self._frozen is not True:
            self._handle_deprecated()
        super().freeze()
281class PrettyPictureTask(PipelineTask):
282 """Turns inputs into an RGB image."""
284 _DefaultName = "prettyPicture"
285 ConfigClass = PrettyPictureConfig
287 config: ConfigClass
289 def _find_normal_stats(self, array):
290 """Calculate standard deviation from negative values using half-normal distribution.
292 Raises
293 ------
294 ValueError
295 Array dimension validation fails.
297 Parameters
298 ----------
299 array : `numpy.array`
300 Input array of numerical values.
302 Returns
303 -------
304 mean : `float`
305 The central moment of the distribution
306 sigma : `float`
307 Estimated standard deviation from negative values. Returns np.inf if:
308 - No negative values exist in the array
309 - Half-normal fitting fails
310 """
311 # Extract negative values efficiently
312 values_noise = array[array < self.config.noiseSearchThreshold]
314 # find the mode
315 center = mode(np.round(values_noise, 2)).mode
317 # extract the negative values
318 values_neg = array[array < center]
320 # Return infinity if no negative values found
321 if values_neg.size == 0:
322 return 0, np.inf
324 try:
325 # Fit half-normal distribution to absolute negative values
326 mu, sigma = halfnorm.fit(np.abs(values_neg))
327 except (ValueError, RuntimeError):
328 # Handle fitting failures (e.g., constant data, optimization issues)
329 return 0, np.inf
331 return center, sigma
333 def _match_sigmas_and_recenter(self, *arrays, factor=1):
334 """Scale array values to match minimum standard deviation across arrays
335 and recenter noise.
337 Adjusts values below each array's sigma by scaling and shifting them to
338 align with the minimum sigma value across all input arrays. This operates
339 in-place for efficiency.
341 Parameters
342 ----------
343 *arrays : any number of `numpy.array`
344 Variable number of input arrays to process.
345 factor : float, optional
346 Scaling factor for adjustments (default: 1).
348 """
349 # Calculate standard deviations for all arrays
350 sigmas = []
351 mus = []
352 for arr in arrays:
353 m, s = self._find_normal_stats(arr)
354 mus.append(m)
355 sigmas.append(s)
356 mus = np.array(mus)
357 sigmas = np.array(sigmas)
359 # If no sigmas could be determined, return the original
360 # arrays.
361 if not np.any(np.isfinite(sigmas)):
362 return
364 min_sig = np.min(sigmas)
366 for mu, sigma, array in zip(mus, sigmas, arrays):
367 # Identify values below the array's sigma threshold
368 lower_pos = (array - mu) < sigma
370 # Skip processing if sigma is invalid
371 if not np.isfinite(sigma):
372 continue
374 # Calculate scaling ratio relative to minimum sigma
375 sigma_ratio = min_sig / sigma
377 # Apply adjustment to qualifying values
378 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor
380 def run(
381 self,
382 images: Mapping[str, Exposure],
383 image_wcs: Projection[Any] | None = None,
384 image_box: Box | None = None,
385 ) -> Struct:
386 """Turns the input arguments in arguments into an RGB array.
388 Parameters
389 ----------
390 images : `Mapping` of `str` to `Exposure`
391 A mapping of input images and the band they correspond to.
392 image_wcs : `~lsst.images.Projection`, optional
393 A projection describing the sky coordinate of each pixel.
394 image_box : `~lsst.images.Box`, optional
395 A box that defines this image as part of a larger region.
397 Returns
398 -------
399 result : `Struct`
400 A struct with the corresponding RGB image, and mask used in
401 RGB image construction. The struct will have the attributes
402 outputRGB and outputRGBMask. Each of the outputs will
403 be a `~lsst.images.ColorImage` object.
405 Notes
406 -----
407 Construction of input images are made easier by use of the
408 makeInputsFrom* methods.
409 """
410 channels = {}
411 shape = (0, 0)
412 jointMask: None | NDArray = None
413 maskDict: Mapping[str, int] = {}
414 doJointMaskInit = False
415 if jointMask is None:
416 doJointMask = True
417 doJointMaskInit = True
418 for channel, imageExposure in images.items():
419 imageArray = imageExposure.image.array
420 # run all the plugins designed for array based interaction
421 for plug in plugins.channel():
422 imageArray = plug(
423 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config
424 ).astype(np.float32)
425 channels[channel] = imageArray
426 # These operations are trivial look-ups and don't matter if they
427 # happen in each loop.
428 shape = imageArray.shape
429 maskDict = imageExposure.mask.getMaskPlaneDict()
430 if doJointMaskInit:
431 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
432 doJointMaskInit = False
433 if doJointMask:
434 jointMask |= imageExposure.mask.array
436 # mix the images to RGB
437 imageRArray = np.zeros(shape, dtype=np.float32)
438 imageGArray = np.zeros(shape, dtype=np.float32)
439 imageBArray = np.zeros(shape, dtype=np.float32)
441 for band, image in channels.items():
442 if band not in self.config.channelConfig:
443 self.log.info(f"{band} image found but not requested in RGB image, skipping")
444 continue
445 mix = self.config.channelConfig[band]
446 if mix.r:
447 imageRArray += mix.r * image
448 if mix.g:
449 imageGArray += mix.g * image
450 if mix.b:
451 imageBArray += mix.b * image
453 exposure = next(iter(images.values()))
454 box: Box2I = exposure.getBBox()
455 boxCenter = box.getCenter()
456 try:
457 psf = exposure.psf.computeImage(boxCenter).array
458 except Exception:
459 psf = None
461 if self.config.recenterNoise:
462 self._match_sigmas_and_recenter(
463 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise
464 )
466 # assert for typing reasons
467 assert jointMask is not None
468 # Run any image level correction plugins
469 colorImage = np.zeros((*imageRArray.shape, 3))
470 colorImage[:, :, 0] = imageRArray
471 colorImage[:, :, 1] = imageGArray
472 colorImage[:, :, 2] = imageBArray
473 for plug in plugins.partial():
474 colorImage = plug(colorImage, jointMask, maskDict, self.config)
476 # Filter the local contrast parameters for diffusion that are None
477 # This is so we only apply key word overrides that are specifically set.
478 local_contrast_config = self.config.localContrastConfig.toDict()
479 to_remove = []
480 for k, v in local_contrast_config["diffusionFunction"].items():
481 if v is None:
482 to_remove.append(k)
483 for item in to_remove:
484 local_contrast_config["diffusionControl"].pop(item)
486 colorImage = lsstRGB(
487 colorImage[:, :, 0],
488 colorImage[:, :, 1],
489 colorImage[:, :, 2],
490 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None,
491 scale_lum=self.config.luminanceConfig,
492 scale_color=self.config.colorConfig,
493 remap_bounds=self.config.imageRemappingConfig,
494 bracketing_function=(
495 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None
496 ),
497 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None,
498 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore
499 psf=psf if self.config.doPsfDeconvolve else None,
500 )
502 # Find the dataset type and thus the maximum values as well
503 maxVal: int | float
504 match self.config.arrayType:
505 case "uint8":
506 dtype = np.uint8
507 maxVal = 255
508 case "uint16":
509 dtype = np.uint16
510 maxVal = 65535
511 case "half":
512 dtype = np.half
513 maxVal = 1.0
514 case "float":
515 dtype = np.float32
516 maxVal = 1.0
517 case _:
518 assert True, "This code path should be unreachable"
520 # lsstRGB returns an image in 0-1 scale it to the maximum value
521 colorImage *= maxVal # type: ignore
523 # pack the joint mask back into a mask object
524 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
525 lsstMask.array = jointMask # type: ignore
526 return Struct(
527 outputRGB=ColorImage(colorImage.astype(dtype), bbox=image_box, projection=image_wcs),
528 outputRGBMask=lsstMask,
529 ) # type: ignore
531 def runQuantum(
532 self,
533 butlerQC: QuantumContext,
534 inputRefs: InputQuantizedConnection,
535 outputRefs: OutputQuantizedConnection,
536 ) -> None:
537 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
538 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
539 if not sortedImages:
540 requested = ", ".join(self.config.channelConfig.keys())
541 raise NoWorkFound(f"No input images of band(s) {requested}")
543 # get the patch tract bounding box and wcs
544 skymap = butlerQC.get(inputRefs.skyMap)
545 quantumDataId = butlerQC.quantum.dataId
546 tractInfo = skymap[quantumDataId["tract"]]
547 patchInfo = tractInfo[quantumDataId["patch"]]
548 outputs = self.run(
549 images=sortedImages,
550 image_wcs=Projection.from_legacy(
551 patchInfo.wcs,
552 TractFrame(
553 skymap=quantumDataId["skymap"],
554 tract=quantumDataId["tract"],
555 bbox=Box.from_legacy(tractInfo.bbox),
556 ),
557 ),
558 image_box=Box.from_legacy(patchInfo.getOuterBBox()),
559 )
560 butlerQC.put(outputs, outputRefs)
562 def makeInputsFromRefs(
563 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
564 ) -> dict[str, Exposure]:
565 r"""Make valid inputs for the run method from butler references.
567 Parameters
568 ----------
569 refs : `Iterable` of `DatasetRef`
570 Some `Iterable` container of `Butler` `DatasetRef`\ s
571 butler : `Butler` or `QuantumContext`
572 This is the object that fetches the input data.
574 Returns
575 -------
576 sortedImages : `dict` of `str` to `Exposure`
577 A dictionary of `Exposure`\ s keyed by the band they
578 correspond to.
579 """
580 sortedImages: dict[str, Exposure] = {}
581 for ref in refs:
582 key: str = cast(str, ref.dataId["band"])
583 image = butler.get(ref)
584 sortedImages[key] = image
585 return sortedImages
587 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]:
588 r"""Make valid inputs for the run method from numpy arrays.
590 Parameters
591 ----------
592 kwargs : `numpy.ndarray`
593 This is standard python kwargs where the left side of the equals
594 is the data band, and the right side is the corresponding `numpy.ndarray`
595 array.
597 Returns
598 -------
599 sortedImages : `dict` of `str` to \
600 `~lsst.daf.butler.DeferredDatasetHandle`
601 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed
602 by the band they correspond to.
603 """
604 # ignore type because there aren't proper stubs for afw
605 temp = {}
606 for key, array in kwargs.items():
607 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
608 temp[key].image.array[:] = array
610 return self.makeInputsFromExposures(**temp)
612 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
613 r"""Make valid inputs for the run method from `Exposure` objects.
615 Parameters
616 ----------
617 kwargs : `Exposure`
618 This is standard python kwargs where the left side of the equals
619 is the data band, and the right side is the corresponding
620 `Exposure`.
622 Returns
623 -------
624 sortedImages : `dict` of `int` to \
625 `~lsst.daf.butler.DeferredDatasetHandle`
626 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed
627 by the band they correspond to.
628 """
629 sortedImages = {}
630 for key, value in kwargs.items():
631 sortedImages[key] = value
632 return sortedImages
class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureBackgroundFixerTask`: one coadd in, one out."""

    inputCoadd = Input(
        doc="Input coadd for which the background is to be removed",
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )

    outputCoadd = Output(
        doc="The coadd with the background fixed and subtracted",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyPictureBackgroundFixerConnections,
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    use_detection_mask = Field[bool](
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
        default=False,
    )
    num_background_bins = Field[int](
        doc="The number of bins along each axis when determining background",
        default=5,
    )
    min_bin_fraction = Field[float](
        doc="Bins with fewer pixels than this fraction of the total will be ignored",
        default=0.1,
    )

    pos_sigma_multiplier = Field[float](
        doc="How many sigma to consider as background in the positive direction",
        default=2,
    )
class PrettyPictureBackgroundFixerTask(PipelineTask):
    """Empirically flatten an images background.

    Many astrophysical images have backgrounds with imperfections in them.
    This Task attempts to determine control points which are considered
    background values, and fits a radial basis function model to those
    points. This model is then subtracted off the image.
    """

    _DefaultName = "prettyPictureBackgroundFixer"
    ConfigClass = PrettyPictureBackgroundFixerConfig

    config: ConfigClass

    def _tile_slices(self, arr, R, C):
        """Generate slices for tiling an array.

        The array is divided into an ``R`` x ``C`` grid of tiles. When a
        dimension is not evenly divisible, the remainder is spread one pixel
        at a time over the leading tiles.

        Parameters
        ----------
        arr : `numpy.ndarray`
            The input array to be tiled. Used only to determine its shape.
        R : `int`
            The number of tiles in the row dimension.
        C : `int`
            The number of tiles in the column dimension.

        Returns
        -------
        slices : `list` of `tuple` of `slice`
            ``(row_slice, column_slice)`` for every tile, in row-major order.
        """
        M = arr.shape[0]
        N = arr.shape[1]

        def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]:
            """Split ``range(total_size)`` into ``num_divisions`` contiguous parts.

            The first ``total_size % num_divisions`` parts are one element
            longer than the rest.

            Returns
            -------
            `list` of `tuple` of `int`
                ``(start, end)`` pairs for each part.
            """
            base, remainder = divmod(total_size, num_divisions)
            slices = []
            start = 0
            for i in range(num_divisions):
                end = start + base + (1 if i < remainder else 0)
                slices.append((start, end))
                start = end
            return slices

        row_slices = get_slices(M, R)
        col_slices = get_slices(N, C)

        return [
            (slice(r_start, r_end), slice(c_start, c_end))
            for r_start, r_end in row_slices
            for c_start, c_end in col_slices
        ]

    @staticmethod
    def findBackgroundPixels(image, pos_sigma_mult=1):
        """Find pixels that are likely to be background based on image statistics.

        The median is used as the background level. A half-normal distribution
        is fit to the distances below the median of the pixels fainter than
        it, giving the background sigma. Pixels within
        ``(median - 5 * sigma, median + pos_sigma_mult * sigma)`` are
        classified as background.

        Parameters
        ----------
        image : `numpy.ndarray`
            Input image array for which to find background pixels.
        pos_sigma_mult : `float`
            How many sigma to consider as background in the positive direction

        Returns
        -------
        result : `numpy.ndarray`
            Boolean mask array where True indicates background pixels.

        Notes
        -----
        This method works best for images with relatively uniform background.
        The median is a poor background estimate in fields with high source
        density or diffuse flux.
        """
        maxLikely = np.median(image, axis=None)

        # Pixels fainter than the median sample the lower half of the noise.
        mask = image < maxLikely
        if np.any(mask):
            # Fit the (positive) distances below the median; the scale of the
            # half-normal is the background sigma.
            _, sigma_hat = halfnorm.fit(maxLikely - image[mask])
        else:
            # No pixel is below the median (e.g. a constant image), so no width
            # can be estimated and no pixel qualifies as background.
            sigma_hat = 0.0
        mu_hat = maxLikely

        threshold = mu_hat + pos_sigma_mult * sigma_hat
        return (image < threshold) & (image > (mu_hat - 5 * sigma_hat))

    def fixBackground(self, image, detection_mask=None):
        """Estimate the background of an image.

        The image is split into ``num_background_bins`` x
        ``num_background_bins`` tiles. The median of the background pixels in
        each sufficiently filled tile becomes a control point at the tile
        center. A thin-plate-spline radial basis function is fit through the
        control points and evaluated over the whole image.

        Parameters
        ----------
        image : `numpy.ndarray`
            The input image as a NumPy array.
        detection_mask : `numpy.ndarray`, optional
            Boolean array, True where a pixel may be used as background. If
            `None`, background pixels are found with `findBackgroundPixels`.

        Returns
        -------
        background : `numpy.ndarray`
            The estimated background with the same shape as ``image``. It is
            all zeros when fewer than 15 control points could be measured.
            The input image is not modified.
        """
        if detection_mask is None:
            image_mask = self.findBackgroundPixels(image, self.config.pos_sigma_multiplier)
        else:
            image_mask = detection_mask

        tiles = self._tile_slices(image, self.config.num_background_bins, self.config.num_background_bins)

        yloc = []
        xloc = []
        values = []

        # For each tile find its center and the median background value.
        for row_slice, col_slice in tiles:
            n_rows = row_slice.stop - row_slice.start
            n_cols = col_slice.stop - col_slice.start
            window = image[row_slice, col_slice][image_mask[row_slice, col_slice]]
            # Skip tiles where too small a fraction of the pixels is background.
            min_fill = int(n_rows * n_cols * self.config.min_bin_fraction)
            if window.size <= min_fill:
                continue
            values.append(np.median(window))
            yloc.append(row_slice.start + n_rows / 2)
            xloc.append(col_slice.start + n_cols / 2)

        # A degree-4 polynomial in 2-d has 15 terms, so at least 15 points are
        # required for the thin plate spline fit.
        if len(yloc) < 15:
            return np.zeros(image.shape)

        inter = RBFInterpolator(
            np.vstack((yloc, xloc)).T,
            values,
            kernel="thin_plate_spline",
            degree=4,
            smoothing=0.05,
            neighbors=None,
        )

        backgrounds = rbf_interpolator.fast_rbf_interpolation_on_grid(inter, image.shape)

        return backgrounds

    def run(self, inputCoadd: Exposure):
        """Estimate a background for an input Exposure and remove it.

        Parameters
        ----------
        inputCoadd : `Exposure`
            The exposure the background will be removed from. It is not
            modified.

        Returns
        -------
        result : `Struct`
            A `Struct` that contains the exposure with the background removed.
            This `Struct` will have an attribute named ``outputCoadd``.
        """
        if self.config.use_detection_mask:
            mask_plane_dict = inputCoadd.mask.getMaskPlaneDict()
            detected_bit = 2 ** mask_plane_dict["DETECTED"]
            # Boolean mask: background pixels are those without DETECTED set.
            detection_mask = (inputCoadd.mask.array & detected_bit) == 0
        else:
            detection_mask = None
        background = self.fixBackground(inputCoadd.image.array, detection_mask=detection_mask)
        # Subtract from a deep copy so the input exposure is left untouched.
        output = ExposureF(inputCoadd, deep=True)
        output.image.array -= background
        return Struct(outputCoadd=output)
class PrettyPictureStarFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
):
    """Connections for `PrettyPictureStarFixerTask`.

    All bands of a patch are read together so that the same pixels are
    repaired in every band.
    """

    inputCoadd = Input(
        doc=(
            "Background-subtracted coadds, one per band, in which saturated or missing "
            "bright-star cores will be filled in"
        ),
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadds, one per band, with saturated or missing bright-star cores inpainted",
        name="pretty_picture_coadd_fixed_stars",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
class PrettyPictureStarFixerConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyPictureStarFixerConnections,
):
    """Configuration for `PrettyPictureStarFixerTask`."""

    brightnessThresh = Field[float](
        doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored",
    )
class PrettyPictureStarFixerTask(PipelineTask):
    """This class fixes up regions in an image where there is no, or bad data.

    The fixes done by this task are overwhelmingly comprised of the cores of
    bright stars for which there is no data.
    """

    _DefaultName = "prettyPictureStarFixer"
    ConfigClass = PrettyPictureStarFixerConfig

    config: ConfigClass

    def run(self, inputs: Mapping[str, ExposureF]) -> Struct:
        """Fix areas in an image where there is no data, most likely to be
        the cores of bright stars.

        Because we want to have consistent fixes across bands, this method
        relies on supplying all bands and fixing pixels that are marked
        as having a defect in any band even if within one band there is
        no issue.

        The image arrays of the supplied exposures are modified in place, and
        the same exposure objects are returned.

        Parameters
        ----------
        inputs : `Mapping` of `str` to `ExposureF`
            This mapping has keys of band as a `str` and the corresponding
            ExposureF as a value. An empty mapping yields empty results.

        Returns
        -------
        results : `Struct` of `Mapping` of `str` to `ExposureF`
            A `Struct` that has a mapping of band to `ExposureF`. The `Struct`
            has an attribute named ``results``.
        """
        exposures = list(inputs.values())
        if not exposures:
            # Nothing to fix. Without this guard the joint mask below would
            # never be created and the method would fail.
            return Struct(results={})

        # The last exposure supplies the mask-plane definitions and the
        # brightness estimate; it is likely close enough across all bands.
        reference = exposures[-1]

        # Make the joint mask of all the channels.
        first = exposures[0].mask.array
        jointMask = np.zeros(first.shape, dtype=first.dtype)
        for exposure in exposures:
            jointMask |= exposure.mask.array

        # A pixel is flagged when SAT or NO_DATA is set in any band.
        maskDict = reference.mask.getMaskPlaneDict()
        badBits = 2 ** maskDict["SAT"] | 2 ** maskDict["NO_DATA"]
        together = (jointMask & badBits).astype(bool)

        bright_mask = reference.image.array > self.config.brightnessThresh

        # Dilate the mask a bit. This helps capture a slightly fainter region
        # without starting to include pixels in an irregular shape, as only
        # the star cores should be fixed.
        both = together & bright_mask
        struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
        both = binary_dilation(both, struct, iterations=4).astype(bool)

        # The same mask is used for every band, so decide once whether any
        # inpainting is needed.
        needsFix = bool(both.any())

        # Do the actual fixing of values.
        results = {}
        for band, exposure in inputs.items():
            if needsFix:
                inpainted = inpaint_biharmonic(exposure.image.array, both, split_into_regions=True)
                exposure.image.array[both] = inpainted[both]
            results[band] = exposure
        return Struct(results=results)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Load every band of the patch, keyed by band name.
        sortedImages: dict[str, ExposureF] = {}
        for ref in inputRefs.inputCoadd:
            key: str = cast(str, ref.dataId["band"])
            sortedImages[key] = butlerQC.get(ref)

        outputs = self.run(sortedImages).results

        # Route each repaired exposure to the output ref for the same band.
        sortedOutputs = {ref.dataId["band"]: ref for ref in outputRefs.outputCoadd}
        for band, data in outputs.items():
            butlerQC.put(data, sortedOutputs[band])
class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`.

    Per-patch RGB images and their masks are combined into one per-tract
    mosaic. The inputs are loaded lazily (``deferLoad=True``).
    """

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    inputRGBMask = Input(
        doc="Masks for the individual RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic",
        storageClass="ColorImage",
        dimensions=("tract", "skymap"),
    )
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for `PrettyMosaicTask`."""

    # Integer downsampling factor; values > 1 trigger resizing of each patch.
    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
    )
    # When True, each patch is colour-converted with ``colour.RGB_to_RGB``
    # before being placed into the mosaic.
    doDCID65Convert = Field[bool](
        doc="Force the output to be converted from display p3 to DCI-D65 colorspace.",
    )
    # Selects the directory for the memory-mapped scratch files: the current
    # directory when True, otherwise the ``tempfile`` default location.
    useLocalTemp = Field[bool](
        doc="Use the current directory when creating local temp files.",
        default=False,
    )
1062class PrettyMosaicTask(PipelineTask):
1063 """Combines multiple RGB arrays into one mosaic."""
1065 _DefaultName = "prettyMosaic"
1066 ConfigClass = PrettyMosaicConfig
1068 config: ConfigClass
1070 def run(
1071 self,
1072 inputRGB: Iterable[DeferredDatasetHandle],
1073 skyMap: BaseSkyMap,
1074 inputRGBMask: Iterable[DeferredDatasetHandle],
1075 ) -> Struct:
1076 r"""Assemble individual `numpy.ndarrays` into a mosaic.
1078 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because
1079 they're loaded in one at a time to be placed into the mosaic to save
1080 memory.
1082 Parameters
1083 ----------
1084 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1085 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB
1086 `numpy.ndarrays`.
1087 skyMap : `BaseSkyMap`
1088 The skymap that defines the relative position of each of the input
1089 images.
1090 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1091 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for
1092 each of the corresponding images.
1094 Returns
1095 -------
1096 result : `Struct`
1097 The `Struct` containing the combined mosaic. The `Struct` has
1098 and attribute named ``outputRGBMosaic``.
1099 """
1100 # create the bounding region
1101 newBox = Box2I()
1102 # store the bounds as they are retrieved from the skymap
1103 boxes = []
1104 tractMaps = []
1105 for handle in inputRGB:
1106 dataId = handle.dataId
1107 tractInfo: TractInfo = skyMap[dataId["tract"]]
1108 patchInfo: PatchInfo = tractInfo[dataId["patch"]]
1109 bbox = patchInfo.getOuterBBox()
1110 boxes.append(bbox)
1111 newBox.include(bbox)
1112 tractMaps.append(tractInfo)
1113 # This will be overwritten in the loop, but that is ok, because
1114 # it is the same for each patch.
1115 patch_grow: int = patchInfo.getCellInnerDimensions().getX()
1117 # fixup the boxes to be smaller if needed, and put the origin at zero,
1118 # this must be done after constructing the complete outer box
1119 modifiedBoxes = []
1120 origin = newBox.getBegin()
1121 for iterBox in boxes:
1122 localOrigin = iterBox.getBegin() - origin
1123 localOrigin = Point2I(
1124 x=int(np.floor(localOrigin.x / self.config.binFactor)),
1125 y=int(np.floor(localOrigin.y / self.config.binFactor)),
1126 )
1127 localExtent = Extent2I(
1128 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
1129 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
1130 )
1131 tmpBox = Box2I(localOrigin, localExtent)
1132 modifiedBoxes.append(tmpBox)
1133 boxes = modifiedBoxes
1135 # scale the container box
1136 newBoxOrigin = Point2I(0, 0)
1137 newBoxExtent = Extent2I(
1138 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
1139 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
1140 )
1141 newBox = Box2I(newBoxOrigin, newBoxExtent)
1143 # Allocate storage for the mosaic
1144 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1145 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1146 consolidatedImage = None
1147 consolidatedMask = None
1149 # Setup color space conversion in case they are used.
1150 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3)
1151 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3)
1152 d65.whitepoint = dp3.whitepoint
1153 d65.whitepoint_name = dp3.whitepoint_name
1155 # Actually assemble the mosaic
1156 maskDict = {}
1157 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor)
1158 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps):
1159 rgb = handle.get().array
1160 # convert to the dci-d65 colorspace
1161 if self.config.doDCID65Convert:
1162 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65)
1163 rgbMask = handleMask.get()
1164 maskDict = rgbMask.getMaskPlaneDict()
1165 # allocate the memory for the mosaic
1166 if consolidatedImage is None:
1167 consolidatedImage = np.memmap(
1168 self.imageHandle.name,
1169 mode="w+",
1170 shape=(newBox.getHeight(), newBox.getWidth(), 3),
1171 dtype=rgb.dtype,
1172 )
1173 if consolidatedMask is None:
1174 consolidatedMask = np.memmap(
1175 self.maskHandle.name,
1176 mode="w+",
1177 shape=(newBox.getHeight(), newBox.getWidth()),
1178 dtype=rgbMask.array.dtype,
1179 )
1181 if self.config.binFactor > 1:
1182 # opencv wants things in x, y dimensions
1183 shape = tuple(box.getDimensions())[::-1]
1184 rgb = cv2.resize(
1185 rgb,
1186 dst=None,
1187 dsize=shape,
1188 fx=shape[0] / self.config.binFactor,
1189 fy=shape[1] / self.config.binFactor,
1190 )
1191 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor]
1192 rgbMask = Mask(*(mask_array.shape[::-1]))
1193 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box)
1195 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array)
1197 for plugin in plugins.full():
1198 if consolidatedImage is not None and consolidatedMask is not None:
1199 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
1200 # If consolidated image still None, that means there was no work to do.
1201 # Return an empty image instead of letting this task fail.
1202 if consolidatedImage is None:
1203 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
1205 return Struct(outputRGBMosaic=ColorImage(consolidatedImage))
1207 def runQuantum(
1208 self,
1209 butlerQC: QuantumContext,
1210 inputRefs: InputQuantizedConnection,
1211 outputRefs: OutputQuantizedConnection,
1212 ) -> None:
1213 inputs = butlerQC.get(inputRefs)
1214 outputs = self.run(**inputs)
1215 butlerQC.put(outputs, outputRefs)
1216 if hasattr(self, "imageHandle"):
1217 self.imageHandle.close()
1218 if hasattr(self, "maskHandle"):
1219 self.maskHandle.close()
1221 def makeInputsFromArrays(
1222 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
1223 ) -> Iterable[DeferredDatasetHandle]:
1224 r"""Make valid inputs for the run method from numpy arrays.
1226 Parameters
1227 ----------
1228 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray`
1229 An iterable where each element is a tuple with the first
1230 element is a mapping that corresponds to an arrays dataId,
1231 and the second is an `numpy.ndarray`.
1233 Returns
1234 -------
1235 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1236 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s
1237 containing the input data.
1238 """
1239 structuredInputs = []
1240 for dataId, array in inputs:
1241 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
1243 return structuredInputs