lsst.pipe.tasks g253578fa50+e0a50b457a
Loading...
Searching...
No Matches
_task.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
# Public API of this module. Every pipeline connections class defined here is
# exported, alongside its task and config.
__all__ = (
    "ChannelRGBConfig",
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
    "PrettyPictureBackgroundFixerConnections",
    "PrettyPictureBackgroundFixerConfig",
    "PrettyPictureBackgroundFixerTask",
    "PrettyPictureStarFixerConnections",
    "PrettyPictureStarFixerConfig",
    "PrettyPictureStarFixerTask",
)
37
38import colour
39import copy
40from collections.abc import Iterable, Mapping
41from lsst.afw.image import ExposureF
42import numpy as np
43from typing import TYPE_CHECKING, cast, Any
44from lsst.skymap import BaseSkyMap
45
46from scipy.stats import halfnorm, mode
47from scipy.ndimage import binary_dilation
48from scipy.interpolate import RBFInterpolator
49from skimage.restoration import inpaint_biharmonic
50
51from lsst.daf.butler import Butler, DeferredDatasetHandle
52from lsst.daf.butler import DatasetRef
53from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField
54from lsst.pex.config.configurableActions import ConfigurableActionField
55from lsst.pipe.base import (
56 PipelineTask,
57 PipelineTaskConfig,
58 PipelineTaskConnections,
59 Struct,
60 InMemoryDatasetHandle,
61 NoWorkFound,
62)
63from lsst.rubinoxide import rbf_interpolator
64import cv2
65
66from lsst.pipe.base.connectionTypes import Input, Output
67from lsst.geom import Box2I, Point2I, Extent2I
68from lsst.afw.image import Exposure, Mask
69
70from ._plugins import plugins
71from ._colorMapper import lsstRGB
72from ._utils import FeatheredMosaicCreator
73from ._functors import (
74 BoundsRemapper,
75 ColorScaler,
76 LumCompressor,
77 ExposureBracketer,
78 GamutFixer,
79 LocalContrastEnhancer,
80)
81
82import tempfile
83
84
85if TYPE_CHECKING:
86 from numpy.typing import NDArray
87 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
88 from lsst.skymap import TractInfo, PatchInfo
89
90
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    Several per-band coadds for one (tract, patch) are read, and a single RGB
    array plus the fused mask of all inputs are written.
    """

    inputCoadds = Input(
        doc="Per-band input coadds that are mixed together to create the RGB image",
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    outputRGB = Output(
        doc="An RGB image created from the input data stored as a 3d array",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
120
121
class ChannelRGBConfig(Config):
    """Relative red, green and blue contributions of a single input channel.

    A channel that should appear purely red is described with ``r=1, g=0,
    b=0``; one that should appear cyan would use ``r=0, g=1, b=1``.
    """

    r = Field[float](
        doc="The amount of red contained in this channel",
    )
    g = Field[float](
        doc="The amount of green contained in this channel",
    )
    b = Field[float](
        doc="The amount of blue contained in this channel",
    )
133
134
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`."""

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The dataset type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=10,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    # Misspelled legacy name kept for backwards compatibility; see
    # ``_handle_deprecated``.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lay outside RGB color gamut"
    )

    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        # Default mapping: i -> red, r -> green, g -> blue.
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
        return super().setDefaults()

    def _handle_deprecated(self):
        """Propagate explicitly-set deprecated fields to their replacements.

        A field is treated as explicitly set when its configuration history
        has more than one entry.

        Notes
        -----
        The following values are propagated (source -> destination):

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``;
          setting it to `None` additionally sets ``doExposureBrackets`` to
          `False`.
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        if len(self._history["gamutMethod"]) > 1:
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets
            if self.exposureBrackets is None:
                self.doExposureBrackets = False

        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Handle the misspelled doPSFDeconcovlve field.
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        """Migrate deprecated fields, then freeze the config.

        Migration only runs the first time, while the config is still mutable.
        """
        if self._frozen is not True:
            self._handle_deprecated()
        super().freeze()
271
272
273class PrettyPictureTask(PipelineTask):
274 """Turns inputs into an RGB image."""
275
276 _DefaultName = "prettyPicture"
277 ConfigClass = PrettyPictureConfig
278
279 config: ConfigClass
280
281 def _find_normal_stats(self, array):
282 """Calculate standard deviation from negative values using half-normal distribution.
283
284 Raises
285 ------
286 ValueError
287 Array dimension validation fails.
288
289 Parameters
290 ----------
291 array : `numpy.array`
292 Input array of numerical values.
293
294 Returns
295 -------
296 mean : `float`
297 The central moment of the distribution
298 sigma : `float`
299 Estimated standard deviation from negative values. Returns np.inf if:
300 - No negative values exist in the array
301 - Half-normal fitting fails
302 """
303 # Extract negative values efficiently
304 values_noise = array[array < self.config.noiseSearchThreshold]
305
306 # find the mode
307 center = mode(np.round(values_noise, 2)).mode
308
309 # extract the negative values
310 values_neg = array[array < center]
311
312 # Return infinity if no negative values found
313 if values_neg.size == 0:
314 return 0, np.inf
315
316 try:
317 # Fit half-normal distribution to absolute negative values
318 mu, sigma = halfnorm.fit(np.abs(values_neg))
319 except (ValueError, RuntimeError):
320 # Handle fitting failures (e.g., constant data, optimization issues)
321 return 0, np.inf
322
323 return center, sigma
324
325 def _match_sigmas_and_recenter(self, *arrays, factor=1):
326 """Scale array values to match minimum standard deviation across arrays
327 and recenter noise.
328
329 Adjusts values below each array's sigma by scaling and shifting them to
330 align with the minimum sigma value across all input arrays. This operates
331 in-place for efficiency.
332
333 Parameters
334 ----------
335 *arrays : any number of `numpy.array`
336 Variable number of input arrays to process.
337 factor : float, optional
338 Scaling factor for adjustments (default: 1).
339
340 """
341 # Calculate standard deviations for all arrays
342 sigmas = []
343 mus = []
344 for arr in arrays:
345 m, s = self._find_normal_stats(arr)
346 mus.append(m)
347 sigmas.append(s)
348 mus = np.array(mus)
349 sigmas = np.array(sigmas)
350
351 # If no sigmas could be determined, return the original
352 # arrays.
353 if not np.any(np.isfinite(sigmas)):
354 return
355
356 min_sig = np.min(sigmas)
357
358 for mu, sigma, array in zip(mus, sigmas, arrays):
359 # Identify values below the array's sigma threshold
360 lower_pos = (array - mu) < sigma
361
362 # Skip processing if sigma is invalid
363 if not np.isfinite(sigma):
364 continue
365
366 # Calculate scaling ratio relative to minimum sigma
367 sigma_ratio = min_sig / sigma
368
369 # Apply adjustment to qualifying values
370 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor
371
372 def run(self, images: Mapping[str, Exposure]) -> Struct:
373 """Turns the input arguments in arguments into an RGB array.
374
375 Parameters
376 ----------
377 images : `Mapping` of `str` to `Exposure`
378 A mapping of input images and the band they correspond to.
379
380 Returns
381 -------
382 result : `Struct`
383 A struct with the corresponding RGB image, and mask used in
384 RGB image construction. The struct will have the attributes
385 outputRGBImage and outputRGBMask. Each of the outputs will
386 be a `NDarray` object.
387
388 Notes
389 -----
390 Construction of input images are made easier by use of the
391 makeInputsFrom* methods.
392 """
393 channels = {}
394 shape = (0, 0)
395 jointMask: None | NDArray = None
396 maskDict: Mapping[str, int] = {}
397 doJointMaskInit = False
398 if jointMask is None:
399 doJointMask = True
400 doJointMaskInit = True
401 for channel, imageExposure in images.items():
402 imageArray = imageExposure.image.array
403 # run all the plugins designed for array based interaction
404 for plug in plugins.channel():
405 imageArray = plug(
406 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config
407 ).astype(np.float32)
408 channels[channel] = imageArray
409 # These operations are trivial look-ups and don't matter if they
410 # happen in each loop.
411 shape = imageArray.shape
412 maskDict = imageExposure.mask.getMaskPlaneDict()
413 if doJointMaskInit:
414 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
415 doJointMaskInit = False
416 if doJointMask:
417 jointMask |= imageExposure.mask.array
418
419 # mix the images to RGB
420 imageRArray = np.zeros(shape, dtype=np.float32)
421 imageGArray = np.zeros(shape, dtype=np.float32)
422 imageBArray = np.zeros(shape, dtype=np.float32)
423
424 for band, image in channels.items():
425 if band not in self.config.channelConfig:
426 self.log.info(f"{band} image found but not requested in RGB image, skipping")
427 continue
428 mix = self.config.channelConfig[band]
429 if mix.r:
430 imageRArray += mix.r * image
431 if mix.g:
432 imageGArray += mix.g * image
433 if mix.b:
434 imageBArray += mix.b * image
435
436 exposure = next(iter(images.values()))
437 box: Box2I = exposure.getBBox()
438 boxCenter = box.getCenter()
439 try:
440 psf = exposure.psf.computeImage(boxCenter).array
441 except Exception:
442 psf = None
443
444 if self.config.recenterNoise:
445 self._match_sigmas_and_recenter(
446 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise
447 )
448
449 # assert for typing reasons
450 assert jointMask is not None
451 # Run any image level correction plugins
452 colorImage = np.zeros((*imageRArray.shape, 3))
453 colorImage[:, :, 0] = imageRArray
454 colorImage[:, :, 1] = imageGArray
455 colorImage[:, :, 2] = imageBArray
456 for plug in plugins.partial():
457 colorImage = plug(colorImage, jointMask, maskDict, self.config)
458
459 # Filter the local contrast parameters for diffusion that are None
460 # This is so we only apply key word overrides that are specifically set.
461 local_contrast_config = self.config.localContrastConfig.toDict()
462 to_remove = []
463 for k, v in local_contrast_config["diffusionFunction"].items():
464 if v is None:
465 to_remove.append(k)
466 for item in to_remove:
467 local_contrast_config["diffusionControl"].pop(item)
468
469 colorImage = lsstRGB(
470 colorImage[:, :, 0],
471 colorImage[:, :, 1],
472 colorImage[:, :, 2],
473 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None,
474 scale_lum=self.config.luminanceConfig,
475 scale_color=self.config.colorConfig,
476 remap_bounds=self.config.imageRemappingConfig,
477 bracketing_function=(
478 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None
479 ),
480 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None,
481 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore
482 psf=psf if self.config.doPsfDeconvolve else None,
483 )
484
485 # Find the dataset type and thus the maximum values as well
486 maxVal: int | float
487 match self.config.arrayType:
488 case "uint8":
489 dtype = np.uint8
490 maxVal = 255
491 case "uint16":
492 dtype = np.uint16
493 maxVal = 65535
494 case "half":
495 dtype = np.half
496 maxVal = 1.0
497 case "float":
498 dtype = np.float32
499 maxVal = 1.0
500 case _:
501 assert True, "This code path should be unreachable"
502
503 # lsstRGB returns an image in 0-1 scale it to the maximum value
504 colorImage *= maxVal # type: ignore
505
506 # pack the joint mask back into a mask object
507 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
508 lsstMask.array = jointMask # type: ignore
509 return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask) # type: ignore
510
512 self,
513 butlerQC: QuantumContext,
514 inputRefs: InputQuantizedConnection,
515 outputRefs: OutputQuantizedConnection,
516 ) -> None:
517 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
518 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
519 if not sortedImages:
520 requested = ", ".join(self.config.channelConfig.keys())
521 raise NoWorkFound(f"No input images of band(s) {requested}")
522 outputs = self.run(sortedImages)
523 butlerQC.put(outputs, outputRefs)
524
526 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
527 ) -> dict[str, Exposure]:
528 r"""Make valid inputs for the run method from butler references.
529
530 Parameters
531 ----------
532 refs : `Iterable` of `DatasetRef`
533 Some `Iterable` container of `Butler` `DatasetRef`\ s
534 butler : `Butler` or `QuantumContext`
535 This is the object that fetches the input data.
536
537 Returns
538 -------
539 sortedImages : `dict` of `str` to `Exposure`
540 A dictionary of `Exposure`\ s keyed by the band they
541 correspond to.
542 """
543 sortedImages: dict[str, Exposure] = {}
544 for ref in refs:
545 key: str = cast(str, ref.dataId["band"])
546 image = butler.get(ref)
547 sortedImages[key] = image
548 return sortedImages
549
550 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]:
551 r"""Make valid inputs for the run method from numpy arrays.
552
553 Parameters
554 ----------
555 kwargs : `numpy.ndarray`
556 This is standard python kwargs where the left side of the equals
557 is the data band, and the right side is the corresponding `numpy.ndarray`
558 array.
559
560 Returns
561 -------
562 sortedImages : `dict` of `str` to \
563 `~lsst.daf.butler.DeferredDatasetHandle`
564 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed
565 by the band they correspond to.
566 """
567 # ignore type because there aren't proper stubs for afw
568 temp = {}
569 for key, array in kwargs.items():
570 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
571 temp[key].image.array[:] = array
572
573 return self.makeInputsFromExposures(**temp)
574
575 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
576 r"""Make valid inputs for the run method from `Exposure` objects.
577
578 Parameters
579 ----------
580 kwargs : `Exposure`
581 This is standard python kwargs where the left side of the equals
582 is the data band, and the right side is the corresponding
583 `Exposure`.
584
585 Returns
586 -------
587 sortedImages : `dict` of `int` to \
588 `~lsst.daf.butler.DeferredDatasetHandle`
589 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed
590 by the band they correspond to.
591 """
592 sortedImages = {}
593 for key, value in kwargs.items():
594 sortedImages[key] = value
595 return sortedImages
596
597
class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureBackgroundFixerTask`.

    One single-band coadd is read and a copy with its background subtracted
    is written with the same dimensions.
    """

    inputCoadd = Input(
        doc="Input coadd for which the background is to be removed",
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
    outputCoadd = Output(
        doc="The coadd with the background fixed and subtracted",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
    )
615
616
class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyPictureBackgroundFixerConnections,
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    use_detection_mask = Field[bool](
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
        default=False,
    )
    num_background_bins = Field[int](
        doc="The number of bins along each axis when determining background",
        default=5,
    )
    min_bin_fraction = Field[float](
        doc="Bins with fewer pixels than this fraction of the total will be ignored",
        default=0.1,
    )

    pos_sigma_multiplier = Field[float](
        doc="How many sigma to consider as background in the positive direction",
        default=2,
    )
634
635
class PrettyPictureBackgroundFixerTask(PipelineTask):
    """Empirically flatten an images background.

    Many astrophysical images have backgrounds with imperfections in them.
    This Task attempts to determine control points which are considered
    background values, and fits a radial basis function model to those
    points. This model is then subtracted off the image.

    """

    _DefaultName = "prettyPictureBackgroundFixer"
    ConfigClass = PrettyPictureBackgroundFixerConfig

    config: ConfigClass

    def _tile_slices(self, arr, R, C):
        """Generate slices for tiling an array.

        The array is divided into an ``R`` x ``C`` grid of tiles. When a
        dimension is not evenly divisible, the remainder is spread over the
        first tiles along that dimension, one extra pixel each.

        Parameters
        ----------
        arr : `numpy.ndarray`
            The input array to be tiled. Used only to determine the array's shape.
        R : `int`
            The number of tiles in the row dimension.
        C : `int`
            The number of tiles in the column dimension.

        Returns
        -------
        slices : `list` of `tuple`
            A list of ``(row_slice, column_slice)`` tuples, one per tile,
            ordered row-major.
        """
        M, N = arr.shape[0], arr.shape[1]

        def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]:
            """Split ``total_size`` into ``num_divisions`` (start, end) ranges
            whose lengths differ by at most one.
            """
            base, remainder = divmod(total_size, num_divisions)
            slices = []
            start = 0
            for i in range(num_divisions):
                end = start + base + (1 if i < remainder else 0)
                slices.append((start, end))
                start = end
            return slices

        row_slices = get_slices(M, R)
        col_slices = get_slices(N, C)

        return [
            (slice(r_start, r_end), slice(c_start, c_end))
            for r_start, r_end in row_slices
            for c_start, c_end in col_slices
        ]

    @staticmethod
    def findBackgroundPixels(image, pos_sigma_mult=1):
        """Find pixels that are likely to be background based on image statistics.

        The median of the image is used as a rough background level. A
        half-normal distribution is fit to the absolute deviations of the
        pixels below that median. Pixels with values in
        ``(mu_hat - 5 * sigma_hat, mu_hat + pos_sigma_mult * sigma_hat)`` are
        classified as background, where ``(mu_hat, sigma_hat)`` are the
        fitted location and scale.

        Parameters
        ----------
        image : `numpy.ndarray`
            Input image array for which to find background pixels.
        pos_sigma_mult : `float`
            How many sigma to consider as background in the positive direction

        Returns
        -------
        result : `numpy.ndarray`
            Boolean mask array where True indicates background pixels. All
            False when no pixel lies below the median (e.g. a constant image).

        Notes
        -----
        This method works best for images with relatively uniform background.
        It may not perform well in fields with high density or diffuse flux.
        """
        # The median is likely close to the average background. Note this
        # doesn't work well in fields with high density or diffuse flux.
        maxLikely = np.median(image, axis=None)

        mask = image < maxLikely
        if not np.any(mask):
            # Nothing to fit; the thresholds would be NaN and select nothing.
            return np.zeros(image.shape, dtype=bool)

        # Fit a half-normal to the deviations below the median, using only
        # the side of the distribution that is free of source flux.
        # NOTE(review): mu_hat is the fitted *location of the deviations*
        # (close to 0), not maxLikely, so the thresholds below are relative to
        # zero flux. This is fine for images that are already roughly
        # background-subtracted; confirm intent.
        mu_hat, sigma_hat = halfnorm.fit(np.abs(image[mask] - maxLikely))

        threshold = mu_hat + pos_sigma_mult * sigma_hat
        return (image < threshold) & (image > (mu_hat - 5 * sigma_hat))

    def fixBackground(self, image, detection_mask=None):
        """Estimate the background of an image.

        The image is divided into a grid of tiles and the median of the
        background pixels in each sufficiently-filled tile becomes a control
        point. A thin-plate-spline radial basis function fit through those
        control points is then evaluated over the whole image. The image
        itself is not modified.

        Parameters
        ----------
        image : `numpy.ndarray`
            The input image as a NumPy array.
        detection_mask : `numpy.ndarray`, optional
            Boolean array with the same shape as ``image`` that is True for
            pixels usable as background. If `None`, background pixels are
            found with `findBackgroundPixels`.

        Returns
        -------
        backgrounds : `numpy.ndarray`
            The estimated background level at every pixel. All zeros if fewer
            than 15 tiles provide a control point.
        """
        if detection_mask is None:
            image_mask = self.findBackgroundPixels(image, self.config.pos_sigma_multiplier)
        else:
            image_mask = detection_mask

        tiles = self._tile_slices(image, self.config.num_background_bins, self.config.num_background_bins)

        yloc = []
        xloc = []
        values = []

        # _tile_slices yields (row, column) slices, i.e. (y, x).
        for yslice, xslice in tiles:
            window = image[yslice, xslice][image_mask[yslice, xslice]]
            # Require at least min_bin_fraction of the tile's pixels to be
            # background for the tile to contribute a control point.
            tile_area = (yslice.stop - yslice.start) * (xslice.stop - xslice.start)
            min_fill = int(tile_area * self.config.min_bin_fraction)
            if window.size <= min_fill:
                continue
            values.append(np.median(window))
            yloc.append((yslice.stop - yslice.start) / 2 + yslice.start)
            xloc.append((xslice.stop - xslice.start) / 2 + xslice.start)

        # A degree-4 polynomial in 2d has 15 terms, so at least 15 points are
        # required for the thin plate spline fit.
        if len(yloc) < 15:
            return np.zeros(image.shape)

        inter = RBFInterpolator(
            np.vstack((yloc, xloc)).T,
            values,
            kernel="thin_plate_spline",
            degree=4,
            smoothing=0.05,
            neighbors=None,
        )

        backgrounds = rbf_interpolator.fast_rbf_interpolation_on_grid(inter, image.shape)

        return backgrounds

    def run(self, inputCoadd: Exposure):
        """Estimate a background for an input Exposure and remove it.

        Parameters
        ----------
        inputCoadd : `Exposure`
            The exposure the background will be removed from. It is not
            modified.

        Returns
        -------
        result : `Struct`
            A `Struct` that contains the exposure with the background removed.
            This `Struct` will have an attribute named ``outputCoadd``.

        """
        if self.config.use_detection_mask:
            detected_bit = inputCoadd.mask.getMaskPlaneDict()["DETECTED"]
            # Background pixels are the ones WITHOUT the DETECTED bit. The
            # comparison yields a boolean array; a bitwise NOT of the integer
            # mask would yield integers that index, rather than mask, pixels.
            detection_mask = (inputCoadd.mask.array & 2**detected_bit) == 0
        else:
            detection_mask = None
        background = self.fixBackground(inputCoadd.image.array, detection_mask=detection_mask)
        # Create a copy to mutate.
        output = ExposureF(inputCoadd, deep=True)
        output.image.array -= background
        return Struct(outputCoadd=output)
869
870
class PrettyPictureStarFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
):
    """Connections for `PrettyPictureStarFixerTask`.

    All bands of one (tract, patch) are read together so that the same
    pixels are repaired in every band, and one output is written per band.
    """

    inputCoadd = Input(
        doc="Background-subtracted coadds, one per band, in which regions lacking good data will be fixed",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadds with saturated or missing regions (mostly bright star cores) filled in",
        name="pretty_picture_coadd_fixed_stars",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
889
890
class PrettyPictureStarFixerConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyPictureStarFixerConnections,
):
    """Configuration for `PrettyPictureStarFixerTask`."""

    # No default is provided, so this must be set explicitly.
    brightnessThresh = Field[float](
        doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored",
    )
895
896
class PrettyPictureStarFixerTask(PipelineTask):
    """This class fixes up regions in an image where there is no, or bad data.

    The fixes done by this task are overwhelmingly comprised of the cores of
    bright stars for which there is no data.
    """

    _DefaultName = "prettyPictureStarFixer"
    ConfigClass = PrettyPictureStarFixerConfig

    config: ConfigClass

    def run(self, inputs: Mapping[str, ExposureF]) -> Struct:
        """Fix areas in an image where there is no data, most likely to be
        the cores of bright stars.

        Because we want to have consistent fixes across bands, this method
        relies on supplying all bands and fixing pixels that are marked
        as having a defect in any band even if within one band there is
        no issue.

        Parameters
        ----------
        inputs : `Mapping` of `str` to `ExposureF`
            This mapping has keys of band as a `str` and the corresponding
            ExposureF as a value. The exposures are modified in place.

        Returns
        -------
        results : `Struct` of `Mapping` of `str` to `ExposureF`
            A `Struct` with an attribute named ``results``: a mapping of band
            to the (in-place fixed) input `ExposureF`. Empty if ``inputs`` is
            empty.
        """
        if not inputs:
            return Struct(results={})

        exposures = list(inputs.values())

        # Make the joint mask of all the channels.
        firstMask = exposures[0].mask.array
        jointMask = np.zeros(firstMask.shape, dtype=firstMask.dtype)
        for imageExposure in exposures:
            jointMask |= imageExposure.mask.array

        # Mask plane definitions and the brightness test both use the last
        # exposure, which is assumed to be representative of all bands.
        referenceExposure = exposures[-1]
        maskDict = referenceExposure.mask.getMaskPlaneDict()
        badBits = (2 ** maskDict["SAT"]) | (2 ** maskDict["NO_DATA"])
        together = (jointMask & badBits).astype(bool)

        bright_mask = referenceExposure.image.array > self.config.brightnessThresh

        # Dilate the mask a bit. This reaches slightly fainter pixels without
        # growing irregular shapes, so only the star cores are fixed.
        both = together & bright_mask
        struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
        both = binary_dilation(both, struct, iterations=4).astype(bool)

        # Do the actual fixing of values, identically in every band.
        needsFix = bool(np.any(both))
        results = {}
        for band, imageExposure in inputs.items():
            if needsFix:
                inpainted = inpaint_biharmonic(imageExposure.image.array, both, split_into_regions=True)
                imageExposure.image.array[both] = inpainted[both]
            results[band] = imageExposure
        return Struct(results=results)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        sortedImages: dict[str, Exposure] = {}
        for ref in inputRefs.inputCoadd:
            key: str = cast(str, ref.dataId["band"])
            sortedImages[key] = butlerQC.get(ref)

        outputs = self.run(sortedImages).results
        sortedOutputs = {ref.dataId["band"]: ref for ref in outputRefs.outputCoadd}

        for band, data in outputs.items():
            butlerQC.put(data, sortedOutputs[band])
983
984
class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`: per-patch RGB arrays and their
    masks in, one tract-level RGB mosaic out."""

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    # Previously documented as "Individual RGB images ..." (copy-paste of
    # inputRGB); these are the masks that accompany those images.
    inputRGBMask = Input(
        doc="Masks for the individual RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic_array",
        storageClass="NumpyArray",
        dimensions=("tract", "skymap"),
    )
1017
1018
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for `PrettyMosaicTask`."""

    # PrettyMosaicTask divides patch geometry by this factor and uses it as a
    # slice step, so only positive integers are meaningful.
    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
        check=lambda factor: factor >= 1,
    )
    doDCID65Convert = Field[bool](
        doc="Force the output to be converted from display p3 to DCI-D65 colorspace.",
    )
    useLocalTemp = Field[bool](doc="Use the current directory when creating local temp files.", default=False)
1023
1024
1025class PrettyMosaicTask(PipelineTask):
1026 """Combines multiple RGB arrays into one mosaic."""
1027
1028 _DefaultName = "prettyMosaic"
1029 ConfigClass = PrettyMosaicConfig
1030
1031 config: ConfigClass
1032
1033 def run(
1034 self,
1035 inputRGB: Iterable[DeferredDatasetHandle],
1036 skyMap: BaseSkyMap,
1037 inputRGBMask: Iterable[DeferredDatasetHandle],
1038 ) -> Struct:
1039 r"""Assemble individual `numpy.ndarrays` into a mosaic.
1040
1041 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because
1042 they're loaded in one at a time to be placed into the mosaic to save
1043 memory.
1044
1045 Parameters
1046 ----------
1047 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1048 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB
1049 `numpy.ndarrays`.
1050 skyMap : `BaseSkyMap`
1051 The skymap that defines the relative position of each of the input
1052 images.
1053 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1054 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for
1055 each of the corresponding images.
1056
1057 Returns
1058 -------
1059 result : `Struct`
1060 The `Struct` containing the combined mosaic. The `Struct` has
1061 and attribute named ``outputRGBMosaic``.
1062 """
1063 # create the bounding region
1064 newBox = Box2I()
1065 # store the bounds as they are retrieved from the skymap
1066 boxes = []
1067 tractMaps = []
1068 for handle in inputRGB:
1069 dataId = handle.dataId
1070 tractInfo: TractInfo = skyMap[dataId["tract"]]
1071 patchInfo: PatchInfo = tractInfo[dataId["patch"]]
1072 bbox = patchInfo.getOuterBBox()
1073 boxes.append(bbox)
1074 newBox.include(bbox)
1075 tractMaps.append(tractInfo)
1076 # This will be overwritten in the loop, but that is ok, because
1077 # it is the same for each patch.
1078 patch_grow: int = patchInfo.getCellInnerDimensions().getX()
1079
1080 # fixup the boxes to be smaller if needed, and put the origin at zero,
1081 # this must be done after constructing the complete outer box
1082 modifiedBoxes = []
1083 origin = newBox.getBegin()
1084 for iterBox in boxes:
1085 localOrigin = iterBox.getBegin() - origin
1086 localOrigin = Point2I(
1087 x=int(np.floor(localOrigin.x / self.config.binFactor)),
1088 y=int(np.floor(localOrigin.y / self.config.binFactor)),
1089 )
1090 localExtent = Extent2I(
1091 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
1092 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
1093 )
1094 tmpBox = Box2I(localOrigin, localExtent)
1095 modifiedBoxes.append(tmpBox)
1096 boxes = modifiedBoxes
1097
1098 # scale the container box
1099 newBoxOrigin = Point2I(0, 0)
1100 newBoxExtent = Extent2I(
1101 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
1102 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
1103 )
1104 newBox = Box2I(newBoxOrigin, newBoxExtent)
1105
1106 # Allocate storage for the mosaic
1107 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1108 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1109 consolidatedImage = None
1110 consolidatedMask = None
1111
1112 # Setup color space conversion in case they are used.
1113 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3)
1114 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3)
1115 d65.whitepoint = dp3.whitepoint
1116 d65.whitepoint_name = dp3.whitepoint_name
1117
1118 # Actually assemble the mosaic
1119 maskDict = {}
1120 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor)
1121 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps):
1122 rgb = handle.get()
1123 # convert to the dci-d65 colorspace
1124 if self.config.doDCID65Convert:
1125 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65)
1126 rgbMask = handleMask.get()
1127 maskDict = rgbMask.getMaskPlaneDict()
1128 # allocate the memory for the mosaic
1129 if consolidatedImage is None:
1130 consolidatedImage = np.memmap(
1131 self.imageHandle.name,
1132 mode="w+",
1133 shape=(newBox.getHeight(), newBox.getWidth(), 3),
1134 dtype=rgb.dtype,
1135 )
1136 if consolidatedMask is None:
1137 consolidatedMask = np.memmap(
1138 self.maskHandle.name,
1139 mode="w+",
1140 shape=(newBox.getHeight(), newBox.getWidth()),
1141 dtype=rgbMask.array.dtype,
1142 )
1143
1144 if self.config.binFactor > 1:
1145 # opencv wants things in x, y dimensions
1146 shape = tuple(box.getDimensions())[::-1]
1147 rgb = cv2.resize(
1148 rgb,
1149 dst=None,
1150 dsize=shape,
1151 fx=shape[0] / self.config.binFactor,
1152 fy=shape[1] / self.config.binFactor,
1153 )
1154 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor]
1155 rgbMask = Mask(*(mask_array.shape[::-1]))
1156 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box)
1157
1158 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array)
1159
1160 for plugin in plugins.full():
1161 if consolidatedImage is not None and consolidatedMask is not None:
1162 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
1163 # If consolidated image still None, that means there was no work to do.
1164 # Return an empty image instead of letting this task fail.
1165 if consolidatedImage is None:
1166 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
1167
1168 return Struct(outputRGBMosaic=consolidatedImage)
1169
1171 self,
1172 butlerQC: QuantumContext,
1173 inputRefs: InputQuantizedConnection,
1174 outputRefs: OutputQuantizedConnection,
1175 ) -> None:
1176 inputs = butlerQC.get(inputRefs)
1177 outputs = self.run(**inputs)
1178 butlerQC.put(outputs, outputRefs)
1179 if hasattr(self, "imageHandle"):
1180 self.imageHandle.close()
1181 if hasattr(self, "maskHandle"):
1182 self.maskHandle.close()
1183
1185 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
1186 ) -> Iterable[DeferredDatasetHandle]:
1187 r"""Make valid inputs for the run method from numpy arrays.
1188
1189 Parameters
1190 ----------
1191 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray`
1192 An iterable where each element is a tuple with the first
1193 element is a mapping that corresponds to an arrays dataId,
1194 and the second is an `numpy.ndarray`.
1195
1196 Returns
1197 -------
1198 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1199 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s
1200 containing the input data.
1201 """
1202 structuredInputs = []
1203 for dataId, array in inputs:
1204 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
1205
1206 return structuredInputs
Iterable[DeferredDatasetHandle] makeInputsFromArrays(self, Iterable[tuple[Mapping[str, Any], NDArray]] inputs)
Definition _task.py:1186
Struct run(self, Iterable[DeferredDatasetHandle] inputRGB, BaseSkyMap skyMap, Iterable[DeferredDatasetHandle] inputRGBMask)
Definition _task.py:1038
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:1175
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:968
Struct run(self, Mapping[str, ExposureF] inputs)
Definition _task.py:909
RGBImage lsstRGB(FloatImagePlane rArray, FloatImagePlane gArray, FloatImagePlane bArray, LocalContrastFunction|None|_SentinalDefault local_contrast=DEFAULT_FUNCTION, ScaleLumFunction|None|_SentinalDefault scale_lum=DEFAULT_FUNCTION, ScaleColorFunction|None|_SentinalDefault scale_color=DEFAULT_FUNCTION, RemapBoundsFunction|None|_SentinalDefault remap_bounds=DEFAULT_FUNCTION, BracketingFunction|None|_SentinalDefault bracketing_function=DEFAULT_FUNCTION, GamutRemappingFunction|None|_SentinalDefault gamut_remapping_function=DEFAULT_FUNCTION, FloatImagePlane|None psf=None, tuple[float, float] cieWhitePoint=(0.28, 0.28))
dict[str, Exposure] makeInputsFromRefs(self, Iterable[DatasetRef] refs, Butler|QuantumContext butler)
Definition _task.py:527
Struct run(self, Mapping[str, Exposure] images)
Definition _task.py:372
dict[int, DeferredDatasetHandle] makeInputsFromExposures(self, **kwargs)
Definition _task.py:575
fixBackground(self, image, detection_mask=None)
Definition _task.py:780
dict[str, DeferredDatasetHandle] makeInputsFromArrays(self, **kwargs)
Definition _task.py:550
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:516
STL namespace.