Coverage for python / lsst / pipe / tasks / prettyPictureMaker / _task.py: 21%
421 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:08 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:08 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
# Names exported by ``from <module> import *``; every entry must be bound
# at module level in this file for a star import to succeed.
__all__ = (
    "ChannelRGBConfig",
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
    "PrettyPictureBackgroundFixerConfig",
    "PrettyPictureBackgroundFixerTask",
    "PrettyPictureStarFixerConfig",
    "PrettyPictureStarFixerTask",
)
38import colour
39import copy
40from collections.abc import Iterable, Mapping
41from lsst.afw.image import ExposureF
42import numpy as np
43from typing import TYPE_CHECKING, cast, Any
44from lsst.skymap import BaseSkyMap
46from scipy.stats import halfnorm, mode
47from scipy.ndimage import binary_dilation
48from scipy.interpolate import RBFInterpolator
49from skimage.restoration import inpaint_biharmonic
51from lsst.daf.butler import Butler, DeferredDatasetHandle
52from lsst.daf.butler import DatasetRef
53from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField
54from lsst.pex.config.configurableActions import ConfigurableActionField
55from lsst.pipe.base import (
56 PipelineTask,
57 PipelineTaskConfig,
58 PipelineTaskConnections,
59 Struct,
60 InMemoryDatasetHandle,
61 NoWorkFound,
62)
63from lsst.rubinoxide import rbf_interpolator
64import cv2
66from lsst.pipe.base.connectionTypes import Input, Output
67from lsst.geom import Box2I, Point2I, Extent2I
68from lsst.afw.image import Exposure, Mask
70from ._plugins import plugins
71from ._colorMapper import lsstRGB
72from ._utils import FeatheredMosaicCreator
73from ._functors import (
74 BoundsRemapper,
75 ColorScaler,
76 LumCompressor,
77 ExposureBracketer,
78 GamutFixer,
79 LocalContrastEnhancer,
80)
82import tempfile
85if TYPE_CHECKING:
86 from numpy.typing import NDArray
87 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
88 from lsst.skymap import TractInfo, PatchInfo
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    Reads one coadd per band for a patch and writes the combined RGB array
    together with a mask built from the masks of the input channels.
    """

    inputCoadds = Input(
        doc="Per-band coadds of a patch that are combined into the RGB image",
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
class ChannelRGBConfig(Config):
    """Weights describing how one input channel maps onto red, green and blue.

    A channel rendered as pure red uses ``r=1, g=0, b=0``; a channel rendered
    as cyan uses ``r=0, g=1, b=1``.
    """

    r = Field[float](doc="The amount of red contained in this channel")
    g = Field[float](doc="The amount of green contained in this channel")
    b = Field[float](doc="The amount of blue contained in this channel")
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`."""

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The dataset type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=10,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    # Misspelled legacy name kept for backwards compatibility; migrated to
    # doPsfDeconvolve in _handle_deprecated.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lay outside RGB color gamut"
    )

    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        """Map the i, r and g bands onto the red, green and blue channels."""
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
        return super().setDefaults()

    def _handle_deprecated(self):
        """Handle deprecated configuration migration.

        A field whose history holds more than one entry has been assigned
        after its default, and its value is copied to the replacement
        location.

        Notes
        -----
        The following values are migrated (source -> destination):

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``;
          if it was explicitly set to `None`, ``doExposureBrackets`` is
          also set to `False`.
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        if len(self._history["gamutMethod"]) > 1:
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets
            if self.exposureBrackets is None:
                self.doExposureBrackets = False

        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Legacy misspelled field name.
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        """Migrate deprecated fields (once) and then freeze the config."""
        # Only migrate before the first freeze; later calls must not mutate.
        if self._frozen is not True:
            self._handle_deprecated()
        super().freeze()
273class PrettyPictureTask(PipelineTask):
274 """Turns inputs into an RGB image."""
276 _DefaultName = "prettyPicture"
277 ConfigClass = PrettyPictureConfig
279 config: ConfigClass
281 def _find_normal_stats(self, array):
282 """Calculate standard deviation from negative values using half-normal distribution.
284 Raises
285 ------
286 ValueError
287 Array dimension validation fails.
289 Parameters
290 ----------
291 array : `numpy.array`
292 Input array of numerical values.
294 Returns
295 -------
296 mean : `float`
297 The central moment of the distribution
298 sigma : `float`
299 Estimated standard deviation from negative values. Returns np.inf if:
300 - No negative values exist in the array
301 - Half-normal fitting fails
302 """
303 # Extract negative values efficiently
304 values_noise = array[array < self.config.noiseSearchThreshold]
306 # find the mode
307 center = mode(np.round(values_noise, 2)).mode
309 # extract the negative values
310 values_neg = array[array < center]
312 # Return infinity if no negative values found
313 if values_neg.size == 0:
314 return 0, np.inf
316 try:
317 # Fit half-normal distribution to absolute negative values
318 mu, sigma = halfnorm.fit(np.abs(values_neg))
319 except (ValueError, RuntimeError):
320 # Handle fitting failures (e.g., constant data, optimization issues)
321 return 0, np.inf
323 return center, sigma
325 def _match_sigmas_and_recenter(self, *arrays, factor=1):
326 """Scale array values to match minimum standard deviation across arrays
327 and recenter noise.
329 Adjusts values below each array's sigma by scaling and shifting them to
330 align with the minimum sigma value across all input arrays. This operates
331 in-place for efficiency.
333 Parameters
334 ----------
335 *arrays : any number of `numpy.array`
336 Variable number of input arrays to process.
337 factor : float, optional
338 Scaling factor for adjustments (default: 1).
340 """
341 # Calculate standard deviations for all arrays
342 sigmas = []
343 mus = []
344 for arr in arrays:
345 m, s = self._find_normal_stats(arr)
346 mus.append(m)
347 sigmas.append(s)
348 mus = np.array(mus)
349 sigmas = np.array(sigmas)
351 # If no sigmas could be determined, return the original
352 # arrays.
353 if not np.any(np.isfinite(sigmas)):
354 return
356 min_sig = np.min(sigmas)
358 for mu, sigma, array in zip(mus, sigmas, arrays):
359 # Identify values below the array's sigma threshold
360 lower_pos = (array - mu) < sigma
362 # Skip processing if sigma is invalid
363 if not np.isfinite(sigma):
364 continue
366 # Calculate scaling ratio relative to minimum sigma
367 sigma_ratio = min_sig / sigma
369 # Apply adjustment to qualifying values
370 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor
372 def run(self, images: Mapping[str, Exposure]) -> Struct:
373 """Turns the input arguments in arguments into an RGB array.
375 Parameters
376 ----------
377 images : `Mapping` of `str` to `Exposure`
378 A mapping of input images and the band they correspond to.
380 Returns
381 -------
382 result : `Struct`
383 A struct with the corresponding RGB image, and mask used in
384 RGB image construction. The struct will have the attributes
385 outputRGBImage and outputRGBMask. Each of the outputs will
386 be a `NDarray` object.
388 Notes
389 -----
390 Construction of input images are made easier by use of the
391 makeInputsFrom* methods.
392 """
393 channels = {}
394 shape = (0, 0)
395 jointMask: None | NDArray = None
396 maskDict: Mapping[str, int] = {}
397 doJointMaskInit = False
398 if jointMask is None:
399 doJointMask = True
400 doJointMaskInit = True
401 for channel, imageExposure in images.items():
402 imageArray = imageExposure.image.array
403 # run all the plugins designed for array based interaction
404 for plug in plugins.channel():
405 imageArray = plug(
406 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config
407 ).astype(np.float32)
408 channels[channel] = imageArray
409 # These operations are trivial look-ups and don't matter if they
410 # happen in each loop.
411 shape = imageArray.shape
412 maskDict = imageExposure.mask.getMaskPlaneDict()
413 if doJointMaskInit:
414 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
415 doJointMaskInit = False
416 if doJointMask:
417 jointMask |= imageExposure.mask.array
419 # mix the images to RGB
420 imageRArray = np.zeros(shape, dtype=np.float32)
421 imageGArray = np.zeros(shape, dtype=np.float32)
422 imageBArray = np.zeros(shape, dtype=np.float32)
424 for band, image in channels.items():
425 if band not in self.config.channelConfig:
426 self.log.info(f"{band} image found but not requested in RGB image, skipping")
427 continue
428 mix = self.config.channelConfig[band]
429 if mix.r:
430 imageRArray += mix.r * image
431 if mix.g:
432 imageGArray += mix.g * image
433 if mix.b:
434 imageBArray += mix.b * image
436 exposure = next(iter(images.values()))
437 box: Box2I = exposure.getBBox()
438 boxCenter = box.getCenter()
439 try:
440 psf = exposure.psf.computeImage(boxCenter).array
441 except Exception:
442 psf = None
444 if self.config.recenterNoise:
445 self._match_sigmas_and_recenter(
446 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise
447 )
449 # assert for typing reasons
450 assert jointMask is not None
451 # Run any image level correction plugins
452 colorImage = np.zeros((*imageRArray.shape, 3))
453 colorImage[:, :, 0] = imageRArray
454 colorImage[:, :, 1] = imageGArray
455 colorImage[:, :, 2] = imageBArray
456 for plug in plugins.partial():
457 colorImage = plug(colorImage, jointMask, maskDict, self.config)
459 # Filter the local contrast parameters for diffusion that are None
460 # This is so we only apply key word overrides that are specifically set.
461 local_contrast_config = self.config.localContrastConfig.toDict()
462 to_remove = []
463 for k, v in local_contrast_config["diffusionFunction"].items():
464 if v is None:
465 to_remove.append(k)
466 for item in to_remove:
467 local_contrast_config["diffusionControl"].pop(item)
469 colorImage = lsstRGB(
470 colorImage[:, :, 0],
471 colorImage[:, :, 1],
472 colorImage[:, :, 2],
473 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None,
474 scale_lum=self.config.luminanceConfig,
475 scale_color=self.config.colorConfig,
476 remap_bounds=self.config.imageRemappingConfig,
477 bracketing_function=(
478 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None
479 ),
480 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None,
481 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore
482 psf=psf if self.config.doPsfDeconvolve else None,
483 )
485 # Find the dataset type and thus the maximum values as well
486 maxVal: int | float
487 match self.config.arrayType:
488 case "uint8":
489 dtype = np.uint8
490 maxVal = 255
491 case "uint16":
492 dtype = np.uint16
493 maxVal = 65535
494 case "half":
495 dtype = np.half
496 maxVal = 1.0
497 case "float":
498 dtype = np.float32
499 maxVal = 1.0
500 case _:
501 assert True, "This code path should be unreachable"
503 # lsstRGB returns an image in 0-1 scale it to the maximum value
504 colorImage *= maxVal # type: ignore
506 # pack the joint mask back into a mask object
507 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
508 lsstMask.array = jointMask # type: ignore
509 return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask) # type: ignore
511 def runQuantum(
512 self,
513 butlerQC: QuantumContext,
514 inputRefs: InputQuantizedConnection,
515 outputRefs: OutputQuantizedConnection,
516 ) -> None:
517 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
518 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
519 if not sortedImages:
520 requested = ", ".join(self.config.channelConfig.keys())
521 raise NoWorkFound(f"No input images of band(s) {requested}")
522 outputs = self.run(sortedImages)
523 butlerQC.put(outputs, outputRefs)
525 def makeInputsFromRefs(
526 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
527 ) -> dict[str, Exposure]:
528 r"""Make valid inputs for the run method from butler references.
530 Parameters
531 ----------
532 refs : `Iterable` of `DatasetRef`
533 Some `Iterable` container of `Butler` `DatasetRef`\ s
534 butler : `Butler` or `QuantumContext`
535 This is the object that fetches the input data.
537 Returns
538 -------
539 sortedImages : `dict` of `str` to `Exposure`
540 A dictionary of `Exposure`\ s keyed by the band they
541 correspond to.
542 """
543 sortedImages: dict[str, Exposure] = {}
544 for ref in refs:
545 key: str = cast(str, ref.dataId["band"])
546 image = butler.get(ref)
547 sortedImages[key] = image
548 return sortedImages
550 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]:
551 r"""Make valid inputs for the run method from numpy arrays.
553 Parameters
554 ----------
555 kwargs : `numpy.ndarray`
556 This is standard python kwargs where the left side of the equals
557 is the data band, and the right side is the corresponding `numpy.ndarray`
558 array.
560 Returns
561 -------
562 sortedImages : `dict` of `str` to \
563 `~lsst.daf.butler.DeferredDatasetHandle`
564 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed
565 by the band they correspond to.
566 """
567 # ignore type because there aren't proper stubs for afw
568 temp = {}
569 for key, array in kwargs.items():
570 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
571 temp[key].image.array[:] = array
573 return self.makeInputsFromExposures(**temp)
575 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
576 r"""Make valid inputs for the run method from `Exposure` objects.
578 Parameters
579 ----------
580 kwargs : `Exposure`
581 This is standard python kwargs where the left side of the equals
582 is the data band, and the right side is the corresponding
583 `Exposure`.
585 Returns
586 -------
587 sortedImages : `dict` of `int` to \
588 `~lsst.daf.butler.DeferredDatasetHandle`
589 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed
590 by the band they correspond to.
591 """
592 sortedImages = {}
593 for key, value in kwargs.items():
594 sortedImages[key] = value
595 return sortedImages
class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureBackgroundFixerTask`: one coadd in and
    one coadd out for each band of a patch."""

    inputCoadd = Input(
        doc="Input coadd for which the background is to be removed",
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
    outputCoadd = Output(
        doc="The coadd with the background fixed and subtracted",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig, pipelineConnections=PrettyPictureBackgroundFixerConnections
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    use_detection_mask = Field[bool](
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
        default=False,
    )
    num_background_bins = Field[int](
        doc="The number of bins along each axis when determining background",
        default=5,
    )
    min_bin_fraction = Field[float](
        doc="Bins with fewer pixels than this fraction of the total will be ignored",
        default=0.1,
    )

    pos_sigma_multiplier = Field[float](
        doc="How many sigma to consider as background in the positive direction",
        default=2,
    )
class PrettyPictureBackgroundFixerTask(PipelineTask):
    """Empirically flatten an images background.

    Many astrophysical images have backgrounds with imperfections in them.
    This Task attempts to determine control points which are considered
    background values, and fits a radial basis function model to those
    points. This model is then subtracted off the image.
    """

    _DefaultName = "prettyPictureBackgroundFixer"
    ConfigClass = PrettyPictureBackgroundFixerConfig

    config: ConfigClass

    def _tile_slices(self, arr, R, C):
        """Generate slices for tiling an array.

        The array is divided into an ``R`` x ``C`` grid. When a dimension is
        not evenly divisible, the remainder is spread over the first tiles.

        Parameters
        ----------
        arr : `numpy.ndarray`
            The input array to be tiled. Used only to determine its shape.
        R : `int`
            The number of tiles in the row dimension.
        C : `int`
            The number of tiles in the column dimension.

        Returns
        -------
        slices : `list` of `tuple`
            ``(row_slice, col_slice)`` pairs, one per tile, ordered row by
            row.
        """
        M = arr.shape[0]
        N = arr.shape[1]

        def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]:
            """Split ``range(total_size)`` into ``num_divisions`` contiguous
            ``(start, end)`` ranges, giving one extra element to each of the
            first ``total_size % num_divisions`` ranges."""
            base, remainder = divmod(total_size, num_divisions)
            slices = []
            start = 0
            for i in range(num_divisions):
                end = start + base + (1 if i < remainder else 0)
                slices.append((start, end))
                start = end
            return slices

        row_slices = get_slices(M, R)
        col_slices = get_slices(N, C)
        return [
            (slice(r_start, r_end), slice(c_start, c_end))
            for r_start, r_end in row_slices
            for c_start, c_end in col_slices
        ]

    @staticmethod
    def findBackgroundPixels(image, pos_sigma_mult=1):
        """Find pixels that are likely to be background based on image statistics.

        The median is used as a first estimate of the background level, and
        a half-normal distribution is fit to the distances from the median of
        the pixels below it. Pixels between ``mu - 5 * sigma`` and
        ``mu + pos_sigma_mult * sigma`` (``mu`` and ``sigma`` from that fit)
        are classified as background.

        Parameters
        ----------
        image : `numpy.ndarray`
            Input image array for which to find background pixels.
        pos_sigma_mult : `float`
            How many sigma to consider as background in the positive direction

        Returns
        -------
        result : `numpy.ndarray`
            Boolean mask array where True indicates background pixels. All
            entries are False when no pixel lies below the median.

        Notes
        -----
        This method works best for images with relatively uniform background.
        It may not perform well in fields with high density or diffuse flux.
        """
        # The median is likely close to the average background. Note this
        # doesn't work well in fields with high density or diffuse flux.
        maxLikely = np.median(image, axis=None)

        below = image < maxLikely
        if not np.any(below):
            # No noise width can be estimated (e.g. constant image), so no
            # pixel is classified as background.
            return np.zeros(image.shape, dtype=bool)

        # Fit a half-normal to the distances below the median.
        mu_hat, sigma_hat = halfnorm.fit(np.abs(image[below] - maxLikely))
        # NOTE(review): mu_hat is the location of the fit to the *distances*
        # from the median, so the window below is centered near zero rather
        # than on the median. This presumably assumes a roughly
        # background-subtracted input — confirm before using on images with a
        # large pedestal.
        threshold = mu_hat + pos_sigma_mult * sigma_hat
        return (image < threshold) * (image > (mu_hat - 5 * sigma_hat))

    def fixBackground(self, image, detection_mask=None):
        """Estimate the smooth background of an image.

        The image is tiled into ``num_background_bins`` x
        ``num_background_bins`` tiles. The median of the background pixels in
        each sufficiently populated tile becomes a control point, and a
        thin-plate-spline radial basis function is fit through the control
        points and evaluated over the whole image. The background is returned,
        not subtracted.

        Parameters
        ----------
        image : `numpy.ndarray`
            The 2-d input image.
        detection_mask : `numpy.ndarray`, optional
            Boolean array with the shape of ``image`` in which True marks
            pixels usable for background estimation. If `None`, they are
            found with `findBackgroundPixels`.

        Returns
        -------
        backgrounds : `numpy.ndarray`
            The estimated background level across the image. All zeros if
            fewer than 15 tiles provide a control point.
        """
        if detection_mask is None:
            image_mask = self.findBackgroundPixels(image, self.config.pos_sigma_multiplier)
        else:
            image_mask = detection_mask

        nbins = self.config.num_background_bins
        tiles = self._tile_slices(image, nbins, nbins)

        row_locs = []
        col_locs = []
        values = []

        # For each tile find its center and the median background value.
        for row_slice, col_slice in tiles:
            n_rows = row_slice.stop - row_slice.start
            n_cols = col_slice.stop - col_slice.start
            window = image[row_slice, col_slice][image_mask[row_slice, col_slice]]
            # Skip tiles whose background fraction is at or below min_bin_fraction.
            min_fill = int(n_rows * n_cols * self.config.min_bin_fraction)
            if window.size <= min_fill:
                continue
            values.append(np.median(window))
            row_locs.append(n_rows / 2 + row_slice.start)
            col_locs.append(n_cols / 2 + col_slice.start)

        # At least 15 points are required for TPS with a 4th order polynomial.
        if len(values) < 15:
            return np.zeros(image.shape)

        # Control points are (row, col); assumed to match the grid ordering
        # used by fast_rbf_interpolation_on_grid — TODO confirm.
        inter = RBFInterpolator(
            np.vstack((row_locs, col_locs)).T,
            values,
            kernel="thin_plate_spline",
            degree=4,
            smoothing=0.05,
            neighbors=None,
        )

        return rbf_interpolator.fast_rbf_interpolation_on_grid(inter, image.shape)

    def run(self, inputCoadd: Exposure):
        """Estimate a background for an input Exposure and remove it.

        Parameters
        ----------
        inputCoadd : `Exposure`
            The exposure the background will be removed from. It is not
            modified; a deep copy is returned.

        Returns
        -------
        result : `Struct`
            A `Struct` that contains the exposure with the background removed.
            This `Struct` will have an attribute named ``outputCoadd``.
        """
        if self.config.use_detection_mask:
            mask_plane_dict = inputCoadd.mask.getMaskPlaneDict()
            detected_bit = 2 ** mask_plane_dict["DETECTED"]
            # Background = pixels without the DETECTED bit. This must be a
            # boolean array: an integer array would be treated as fancy
            # indices by fixBackground instead of as a selection mask.
            detection_mask = (inputCoadd.mask.array & detected_bit) == 0
        else:
            detection_mask = None
        background = self.fixBackground(inputCoadd.image.array, detection_mask=detection_mask)
        # Create a copy to mutate.
        output = ExposureF(inputCoadd, deep=True)
        output.image.array -= background
        return Struct(outputCoadd=output)
class PrettyPictureStarFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
):
    """Connections for `PrettyPictureStarFixerTask`: all bands of a patch in,
    all bands out."""

    inputCoadd = Input(
        doc=(
            "Background-subtracted coadds, one per band, in which bright saturated or "
            "no-data regions are to be filled"
        ),
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadds with bright saturated or no-data regions filled in by inpainting",
        name="pretty_picture_coadd_fixed_stars",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
class PrettyPictureStarFixerConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureStarFixerConnections):
    """Configuration for `PrettyPictureStarFixerTask`."""

    # No default is given, so this must be set explicitly.
    brightnessThresh = Field[float](
        doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored",
    )
class PrettyPictureStarFixerTask(PipelineTask):
    """This class fixes up regions in an image where there is no, or bad data.

    The fixes done by this task are overwhelmingly comprised of the cores of
    bright stars for which there is no data.
    """

    _DefaultName = "prettyPictureStarFixer"
    ConfigClass = PrettyPictureStarFixerConfig

    config: ConfigClass

    def run(self, inputs: Mapping[str, ExposureF]) -> Struct:
        """Fix areas in an image where there is no data, most likely to be
        the cores of bright stars.

        Because we want to have consistent fixes across bands, this method
        relies on supplying all bands and fixing pixels that are marked
        as having a defect in any band even if within one band there is
        no issue.

        A pixel is fixed when its SAT or NO_DATA bit is set in any band and
        the last exposure in ``inputs`` is brighter than
        ``config.brightnessThresh`` there. That selection is dilated by four
        iterations of a cross-shaped structuring element, and the selected
        pixels are filled by biharmonic inpainting.

        Parameters
        ----------
        inputs : `Mapping` of `str` to `ExposureF`
            This mapping has keys of band as a `str` and the corresponding
            ExposureF as a value. The exposures are modified in place.

        Returns
        -------
        results : `Struct` of `Mapping` of `str` to `ExposureF`
            A `Struct` that has a mapping of band to `ExposureF`. The `Struct`
            has an attribute named ``results``.

        Raises
        ------
        ValueError
            Raised if ``inputs`` is empty.
        """
        if not inputs:
            raise ValueError("At least one input exposure is required")

        exposures = list(inputs.values())
        # The last exposure supplies the mask plane layout and the brightness
        # reference; it is likely close enough across all bands.
        reference = exposures[-1]

        # Make the joint mask of all the channels.
        firstMask = exposures[0].mask.array
        jointMask = np.zeros(firstMask.shape, dtype=firstMask.dtype)
        for imageExposure in exposures:
            jointMask |= imageExposure.mask.array

        maskDict = reference.mask.getMaskPlaneDict()
        badBits = (2 ** maskDict["SAT"]) | (2 ** maskDict["NO_DATA"])
        together = (jointMask & badBits).astype(bool)

        bright_mask = reference.image.array > self.config.brightnessThresh

        # Dilate the mask a bit; this helps get a slightly fainter mask without
        # including pixels in an irregular shape, as only the star cores should
        # be fixed.
        both = together & bright_mask
        struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
        both = binary_dilation(both, struct, iterations=4).astype(bool)

        # Do the actual fixing of values.
        needsFix = bool(np.any(both))
        results = {}
        for band, imageExposure in inputs.items():
            if needsFix:
                inpainted = inpaint_biharmonic(imageExposure.image.array, both, split_into_regions=True)
                imageExposure.image.array[both] = inpainted[both]
            results[band] = imageExposure
        return Struct(results=results)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        sortedImages: dict[str, Exposure] = {
            cast(str, ref.dataId["band"]): butlerQC.get(ref) for ref in inputRefs.inputCoadd
        }

        outputs = self.run(sortedImages).results
        outputRefsByBand = {ref.dataId["band"]: ref for ref in outputRefs.outputCoadd}

        for band, data in outputs.items():
            butlerQC.put(data, outputRefsByBand[band])
class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`, which combines per-patch RGB
    images (and their masks) into a single per-tract mosaic.
    """

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    inputRGBMask = Input(
        doc="Masks corresponding to the individual RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic_array",
        storageClass="NumpyArray",
        dimensions=("tract", "skymap"),
    )
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for `PrettyMosaicTask`."""

    # No default: the bin factor must be chosen explicitly. It is used as an
    # integer divisor and slice stride, so it should be a positive integer.
    binFactor = Field[int](doc="The factor to bin by when producing the mosaic")
    # Opt-in colour space conversion; off unless explicitly requested.
    doDCID65Convert = Field[bool](
        doc="Force the output to be converted from display p3 to DCI-D65 colorspace.",
        default=False,
    )
    useLocalTemp = Field[bool](doc="Use the current directory when creating local temp files.", default=False)
1025class PrettyMosaicTask(PipelineTask):
1026 """Combines multiple RGB arrays into one mosaic."""
1028 _DefaultName = "prettyMosaic"
1029 ConfigClass = PrettyMosaicConfig
1031 config: ConfigClass
1033 def run(
1034 self,
1035 inputRGB: Iterable[DeferredDatasetHandle],
1036 skyMap: BaseSkyMap,
1037 inputRGBMask: Iterable[DeferredDatasetHandle],
1038 ) -> Struct:
1039 r"""Assemble individual `numpy.ndarrays` into a mosaic.
1041 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because
1042 they're loaded in one at a time to be placed into the mosaic to save
1043 memory.
1045 Parameters
1046 ----------
1047 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1048 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB
1049 `numpy.ndarrays`.
1050 skyMap : `BaseSkyMap`
1051 The skymap that defines the relative position of each of the input
1052 images.
1053 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1054 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for
1055 each of the corresponding images.
1057 Returns
1058 -------
1059 result : `Struct`
1060 The `Struct` containing the combined mosaic. The `Struct` has
1061 and attribute named ``outputRGBMosaic``.
1062 """
1063 # create the bounding region
1064 newBox = Box2I()
1065 # store the bounds as they are retrieved from the skymap
1066 boxes = []
1067 tractMaps = []
1068 for handle in inputRGB:
1069 dataId = handle.dataId
1070 tractInfo: TractInfo = skyMap[dataId["tract"]]
1071 patchInfo: PatchInfo = tractInfo[dataId["patch"]]
1072 bbox = patchInfo.getOuterBBox()
1073 boxes.append(bbox)
1074 newBox.include(bbox)
1075 tractMaps.append(tractInfo)
1076 # This will be overwritten in the loop, but that is ok, because
1077 # it is the same for each patch.
1078 patch_grow: int = patchInfo.getCellInnerDimensions().getX()
1080 # fixup the boxes to be smaller if needed, and put the origin at zero,
1081 # this must be done after constructing the complete outer box
1082 modifiedBoxes = []
1083 origin = newBox.getBegin()
1084 for iterBox in boxes:
1085 localOrigin = iterBox.getBegin() - origin
1086 localOrigin = Point2I(
1087 x=int(np.floor(localOrigin.x / self.config.binFactor)),
1088 y=int(np.floor(localOrigin.y / self.config.binFactor)),
1089 )
1090 localExtent = Extent2I(
1091 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
1092 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
1093 )
1094 tmpBox = Box2I(localOrigin, localExtent)
1095 modifiedBoxes.append(tmpBox)
1096 boxes = modifiedBoxes
1098 # scale the container box
1099 newBoxOrigin = Point2I(0, 0)
1100 newBoxExtent = Extent2I(
1101 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
1102 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
1103 )
1104 newBox = Box2I(newBoxOrigin, newBoxExtent)
1106 # Allocate storage for the mosaic
1107 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1108 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1109 consolidatedImage = None
1110 consolidatedMask = None
1112 # Setup color space conversion in case they are used.
1113 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3)
1114 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3)
1115 d65.whitepoint = dp3.whitepoint
1116 d65.whitepoint_name = dp3.whitepoint_name
1118 # Actually assemble the mosaic
1119 maskDict = {}
1120 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor)
1121 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps):
1122 rgb = handle.get()
1123 # convert to the dci-d65 colorspace
1124 if self.config.doDCID65Convert:
1125 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65)
1126 rgbMask = handleMask.get()
1127 maskDict = rgbMask.getMaskPlaneDict()
1128 # allocate the memory for the mosaic
1129 if consolidatedImage is None:
1130 consolidatedImage = np.memmap(
1131 self.imageHandle.name,
1132 mode="w+",
1133 shape=(newBox.getHeight(), newBox.getWidth(), 3),
1134 dtype=rgb.dtype,
1135 )
1136 if consolidatedMask is None:
1137 consolidatedMask = np.memmap(
1138 self.maskHandle.name,
1139 mode="w+",
1140 shape=(newBox.getHeight(), newBox.getWidth()),
1141 dtype=rgbMask.array.dtype,
1142 )
1144 if self.config.binFactor > 1:
1145 # opencv wants things in x, y dimensions
1146 shape = tuple(box.getDimensions())[::-1]
1147 rgb = cv2.resize(
1148 rgb,
1149 dst=None,
1150 dsize=shape,
1151 fx=shape[0] / self.config.binFactor,
1152 fy=shape[1] / self.config.binFactor,
1153 )
1154 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor]
1155 rgbMask = Mask(*(mask_array.shape[::-1]))
1156 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box)
1158 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array)
1160 for plugin in plugins.full():
1161 if consolidatedImage is not None and consolidatedMask is not None:
1162 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
1163 # If consolidated image still None, that means there was no work to do.
1164 # Return an empty image instead of letting this task fail.
1165 if consolidatedImage is None:
1166 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
1168 return Struct(outputRGBMosaic=consolidatedImage)
1170 def runQuantum(
1171 self,
1172 butlerQC: QuantumContext,
1173 inputRefs: InputQuantizedConnection,
1174 outputRefs: OutputQuantizedConnection,
1175 ) -> None:
1176 inputs = butlerQC.get(inputRefs)
1177 outputs = self.run(**inputs)
1178 butlerQC.put(outputs, outputRefs)
1179 if hasattr(self, "imageHandle"):
1180 self.imageHandle.close()
1181 if hasattr(self, "maskHandle"):
1182 self.maskHandle.close()
1184 def makeInputsFromArrays(
1185 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
1186 ) -> Iterable[DeferredDatasetHandle]:
1187 r"""Make valid inputs for the run method from numpy arrays.
1189 Parameters
1190 ----------
1191 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray`
1192 An iterable where each element is a tuple with the first
1193 element is a mapping that corresponds to an arrays dataId,
1194 and the second is an `numpy.ndarray`.
1196 Returns
1197 -------
1198 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1199 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s
1200 containing the input data.
1201 """
1202 structuredInputs = []
1203 for dataId, array in inputs:
1204 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
1206 return structuredInputs