lsst.pipe.tasks g540474b770+e939cf0e26
Loading...
Searching...
No Matches
_task.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22from __future__ import annotations
23
# Names exported by ``from <module> import *``; this tuple is the module's
# declared public API.
__all__ = (
    "ChannelRGBConfig",
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
    "PrettyPictureBackgroundFixerConfig",
    "PrettyPictureBackgroundFixerTask",
    "PrettyPictureStarFixerConfig",
    "PrettyPictureStarFixerTask",
)
37
38import colour
39import copy
40from collections.abc import Iterable, Mapping
41from lsst.afw.image import ExposureF
42import numpy as np
43from typing import TYPE_CHECKING, cast, Any
44from lsst.skymap import BaseSkyMap
45
46from scipy.stats import halfnorm, mode
47from scipy.ndimage import binary_dilation
48from scipy.interpolate import RBFInterpolator
49from skimage.restoration import inpaint_biharmonic
50
51from lsst.daf.butler import Butler, DeferredDatasetHandle
52from lsst.daf.butler import DatasetRef
53from lsst.images import ColorImage, Projection, Box, TractFrame
54from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField
55from lsst.pex.config.configurableActions import ConfigurableActionField
56from lsst.pipe.base import (
57 PipelineTask,
58 PipelineTaskConfig,
59 PipelineTaskConnections,
60 Struct,
61 InMemoryDatasetHandle,
62 NoWorkFound,
63)
64from lsst.rubinoxide import rbf_interpolator
65import cv2
66
67from lsst.pipe.base.connectionTypes import Input, Output
68from lsst.geom import Box2I, Point2I, Extent2I
69from lsst.afw.image import Exposure, Mask
70
71from ._plugins import plugins
72from ._colorMapper import lsstRGB
73from ._utils import FeatheredMosaicCreator
74from ._functors import (
75 BoundsRemapper,
76 ColorScaler,
77 LumCompressor,
78 ExposureBracketer,
79 GamutFixer,
80 LocalContrastEnhancer,
81)
82
83import tempfile
84
85
86if TYPE_CHECKING:
87 from numpy.typing import NDArray
88 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection
89 from lsst.skymap import TractInfo, PatchInfo
90
91
class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureTask`.

    One coadd per band is read for a patch, and a single RGB image plus the
    fused mask of its input channels are written for that patch.
    """

    inputCoadds = Input(
        doc=(
            "Per-band coadds of a patch; each band is mixed into the red, green and blue "
            "channels of the output according to PrettyPictureConfig.channelConfig"
        ),
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    # Used to build the tract-relative projection and bounding box of the output.
    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )
128
129
class ChannelRGBConfig(Config):
    """This describes the rgb values of a given input channel.

    For instance if this channel is red the values would be self.r = 1,
    self.g = 0, self.b = 0. If the channel was cyan the values would be
    self.r = 0, self.g = 1, self.b = 1.

    Each value is a multiplicative weight applied to the band's image when
    it is added into the corresponding output channel; a weight of zero
    leaves that output channel untouched by this band.
    """

    # Weight of this band in the red output channel.
    r = Field[float](doc="The amount of red contained in this channel")
    # Weight of this band in the green output channel.
    g = Field[float](doc="The amount of green contained in this channel")
    # Weight of this band in the blue output channel.
    b = Field[float](doc="The amount of blue contained in this channel")
141
142
class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`."""

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The data type (dtype) of the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=10,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    # Misspelled predecessor of doPsfDeconvolve, kept for backwards compatibility;
    # its value is migrated in _handle_deprecated.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lay outside RGB color gamut"
    )

    # Deprecated: migrated to exposureBracketerConfig.exposureBrackets.
    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    # Deprecated: migrated to gamutMapperConfig.gamutMethod.
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        # Default mapping: i -> red, r -> green, g -> blue.
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
        super().setDefaults()

    def _handle_deprecated(self):
        """Handle deprecated configuration migration.

        A field whose history has more than one entry has been set after its
        default was assigned, so its value is copied to the new location.

        Notes
        -----
        The following values are migrated (old location -> new location):

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``
          (setting it to `None` also disables ``doExposureBrackets``)
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        if len(self._history["gamutMethod"]) > 1:
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets
            if self.exposureBrackets is None:
                self.doExposureBrackets = False

        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Migrate the misspelled doPSFDeconcovlve field.
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        # Migrate deprecated values only once, while the config is still mutable.
        if not self._frozen:
            self._handle_deprecated()
        super().freeze()
279
280
281class PrettyPictureTask(PipelineTask):
282 """Turns inputs into an RGB image."""
283
284 _DefaultName = "prettyPicture"
285 ConfigClass = PrettyPictureConfig
286
287 config: ConfigClass
288
289 def _find_normal_stats(self, array):
290 """Calculate standard deviation from negative values using half-normal distribution.
291
292 Raises
293 ------
294 ValueError
295 Array dimension validation fails.
296
297 Parameters
298 ----------
299 array : `numpy.array`
300 Input array of numerical values.
301
302 Returns
303 -------
304 mean : `float`
305 The central moment of the distribution
306 sigma : `float`
307 Estimated standard deviation from negative values. Returns np.inf if:
308 - No negative values exist in the array
309 - Half-normal fitting fails
310 """
311 # Extract negative values efficiently
312 values_noise = array[array < self.config.noiseSearchThreshold]
313
314 # find the mode
315 center = mode(np.round(values_noise, 2)).mode
316
317 # extract the negative values
318 values_neg = array[array < center]
319
320 # Return infinity if no negative values found
321 if values_neg.size == 0:
322 return 0, np.inf
323
324 try:
325 # Fit half-normal distribution to absolute negative values
326 mu, sigma = halfnorm.fit(np.abs(values_neg))
327 except (ValueError, RuntimeError):
328 # Handle fitting failures (e.g., constant data, optimization issues)
329 return 0, np.inf
330
331 return center, sigma
332
333 def _match_sigmas_and_recenter(self, *arrays, factor=1):
334 """Scale array values to match minimum standard deviation across arrays
335 and recenter noise.
336
337 Adjusts values below each array's sigma by scaling and shifting them to
338 align with the minimum sigma value across all input arrays. This operates
339 in-place for efficiency.
340
341 Parameters
342 ----------
343 *arrays : any number of `numpy.array`
344 Variable number of input arrays to process.
345 factor : float, optional
346 Scaling factor for adjustments (default: 1).
347
348 """
349 # Calculate standard deviations for all arrays
350 sigmas = []
351 mus = []
352 for arr in arrays:
353 m, s = self._find_normal_stats(arr)
354 mus.append(m)
355 sigmas.append(s)
356 mus = np.array(mus)
357 sigmas = np.array(sigmas)
358
359 # If no sigmas could be determined, return the original
360 # arrays.
361 if not np.any(np.isfinite(sigmas)):
362 return
363
364 min_sig = np.min(sigmas)
365
366 for mu, sigma, array in zip(mus, sigmas, arrays):
367 # Identify values below the array's sigma threshold
368 lower_pos = (array - mu) < sigma
369
370 # Skip processing if sigma is invalid
371 if not np.isfinite(sigma):
372 continue
373
374 # Calculate scaling ratio relative to minimum sigma
375 sigma_ratio = min_sig / sigma
376
377 # Apply adjustment to qualifying values
378 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor
379
380 def run(
381 self,
382 images: Mapping[str, Exposure],
383 image_wcs: Projection[Any] | None = None,
384 image_box: Box | None = None,
385 ) -> Struct:
386 """Turns the input arguments in arguments into an RGB array.
387
388 Parameters
389 ----------
390 images : `Mapping` of `str` to `Exposure`
391 A mapping of input images and the band they correspond to.
392 image_wcs : `~lsst.images.Projection`, optional
393 A projection describing the sky coordinate of each pixel.
394 image_box : `~lsst.images.Box`, optional
395 A box that defines this image as part of a larger region.
396
397 Returns
398 -------
399 result : `Struct`
400 A struct with the corresponding RGB image, and mask used in
401 RGB image construction. The struct will have the attributes
402 outputRGB and outputRGBMask. Each of the outputs will
403 be a `~lsst.images.ColorImage` object.
404
405 Notes
406 -----
407 Construction of input images are made easier by use of the
408 makeInputsFrom* methods.
409 """
410 channels = {}
411 shape = (0, 0)
412 jointMask: None | NDArray = None
413 maskDict: Mapping[str, int] = {}
414 doJointMaskInit = False
415 if jointMask is None:
416 doJointMask = True
417 doJointMaskInit = True
418 for channel, imageExposure in images.items():
419 imageArray = imageExposure.image.array
420 # run all the plugins designed for array based interaction
421 for plug in plugins.channel():
422 imageArray = plug(
423 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config
424 ).astype(np.float32)
425 channels[channel] = imageArray
426 # These operations are trivial look-ups and don't matter if they
427 # happen in each loop.
428 shape = imageArray.shape
429 maskDict = imageExposure.mask.getMaskPlaneDict()
430 if doJointMaskInit:
431 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype)
432 doJointMaskInit = False
433 if doJointMask:
434 jointMask |= imageExposure.mask.array
435
436 # mix the images to RGB
437 imageRArray = np.zeros(shape, dtype=np.float32)
438 imageGArray = np.zeros(shape, dtype=np.float32)
439 imageBArray = np.zeros(shape, dtype=np.float32)
440
441 for band, image in channels.items():
442 if band not in self.config.channelConfig:
443 self.log.info(f"{band} image found but not requested in RGB image, skipping")
444 continue
445 mix = self.config.channelConfig[band]
446 if mix.r:
447 imageRArray += mix.r * image
448 if mix.g:
449 imageGArray += mix.g * image
450 if mix.b:
451 imageBArray += mix.b * image
452
453 exposure = next(iter(images.values()))
454 box: Box2I = exposure.getBBox()
455 boxCenter = box.getCenter()
456 try:
457 psf = exposure.psf.computeImage(boxCenter).array
458 except Exception:
459 psf = None
460
461 if self.config.recenterNoise:
462 self._match_sigmas_and_recenter(
463 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise
464 )
465
466 # assert for typing reasons
467 assert jointMask is not None
468 # Run any image level correction plugins
469 colorImage = np.zeros((*imageRArray.shape, 3))
470 colorImage[:, :, 0] = imageRArray
471 colorImage[:, :, 1] = imageGArray
472 colorImage[:, :, 2] = imageBArray
473 for plug in plugins.partial():
474 colorImage = plug(colorImage, jointMask, maskDict, self.config)
475
476 # Filter the local contrast parameters for diffusion that are None
477 # This is so we only apply key word overrides that are specifically set.
478 local_contrast_config = self.config.localContrastConfig.toDict()
479 to_remove = []
480 for k, v in local_contrast_config["diffusionFunction"].items():
481 if v is None:
482 to_remove.append(k)
483 for item in to_remove:
484 local_contrast_config["diffusionControl"].pop(item)
485
486 colorImage = lsstRGB(
487 colorImage[:, :, 0],
488 colorImage[:, :, 1],
489 colorImage[:, :, 2],
490 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None,
491 scale_lum=self.config.luminanceConfig,
492 scale_color=self.config.colorConfig,
493 remap_bounds=self.config.imageRemappingConfig,
494 bracketing_function=(
495 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None
496 ),
497 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None,
498 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore
499 psf=psf if self.config.doPsfDeconvolve else None,
500 )
501
502 # Find the dataset type and thus the maximum values as well
503 maxVal: int | float
504 match self.config.arrayType:
505 case "uint8":
506 dtype = np.uint8
507 maxVal = 255
508 case "uint16":
509 dtype = np.uint16
510 maxVal = 65535
511 case "half":
512 dtype = np.half
513 maxVal = 1.0
514 case "float":
515 dtype = np.float32
516 maxVal = 1.0
517 case _:
518 assert True, "This code path should be unreachable"
519
520 # lsstRGB returns an image in 0-1 scale it to the maximum value
521 colorImage *= maxVal # type: ignore
522
523 # pack the joint mask back into a mask object
524 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict)
525 lsstMask.array = jointMask # type: ignore
526 return Struct(
527 outputRGB=ColorImage(colorImage.astype(dtype), bbox=image_box, projection=image_wcs),
528 outputRGBMask=lsstMask,
529 ) # type: ignore
530
532 self,
533 butlerQC: QuantumContext,
534 inputRefs: InputQuantizedConnection,
535 outputRefs: OutputQuantizedConnection,
536 ) -> None:
537 imageRefs: list[DatasetRef] = inputRefs.inputCoadds
538 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC)
539 if not sortedImages:
540 requested = ", ".join(self.config.channelConfig.keys())
541 raise NoWorkFound(f"No input images of band(s) {requested}")
542
543 # get the patch tract bounding box and wcs
544 skymap = butlerQC.get(inputRefs.skyMap)
545 quantumDataId = butlerQC.quantum.dataId
546 tractInfo = skymap[quantumDataId["tract"]]
547 patchInfo = tractInfo[quantumDataId["patch"]]
548 outputs = self.run(
549 images=sortedImages,
550 image_wcs=Projection.from_legacy(
551 patchInfo.wcs,
552 TractFrame(
553 skymap=quantumDataId["skymap"],
554 tract=quantumDataId["tract"],
555 bbox=Box.from_legacy(tractInfo.bbox),
556 ),
557 ),
558 image_box=Box.from_legacy(patchInfo.getOuterBBox()),
559 )
560 butlerQC.put(outputs, outputRefs)
561
563 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext
564 ) -> dict[str, Exposure]:
565 r"""Make valid inputs for the run method from butler references.
566
567 Parameters
568 ----------
569 refs : `Iterable` of `DatasetRef`
570 Some `Iterable` container of `Butler` `DatasetRef`\ s
571 butler : `Butler` or `QuantumContext`
572 This is the object that fetches the input data.
573
574 Returns
575 -------
576 sortedImages : `dict` of `str` to `Exposure`
577 A dictionary of `Exposure`\ s keyed by the band they
578 correspond to.
579 """
580 sortedImages: dict[str, Exposure] = {}
581 for ref in refs:
582 key: str = cast(str, ref.dataId["band"])
583 image = butler.get(ref)
584 sortedImages[key] = image
585 return sortedImages
586
587 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]:
588 r"""Make valid inputs for the run method from numpy arrays.
589
590 Parameters
591 ----------
592 kwargs : `numpy.ndarray`
593 This is standard python kwargs where the left side of the equals
594 is the data band, and the right side is the corresponding `numpy.ndarray`
595 array.
596
597 Returns
598 -------
599 sortedImages : `dict` of `str` to \
600 `~lsst.daf.butler.DeferredDatasetHandle`
601 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed
602 by the band they correspond to.
603 """
604 # ignore type because there aren't proper stubs for afw
605 temp = {}
606 for key, array in kwargs.items():
607 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype)
608 temp[key].image.array[:] = array
609
610 return self.makeInputsFromExposures(**temp)
611
612 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]:
613 r"""Make valid inputs for the run method from `Exposure` objects.
614
615 Parameters
616 ----------
617 kwargs : `Exposure`
618 This is standard python kwargs where the left side of the equals
619 is the data band, and the right side is the corresponding
620 `Exposure`.
621
622 Returns
623 -------
624 sortedImages : `dict` of `int` to \
625 `~lsst.daf.butler.DeferredDatasetHandle`
626 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed
627 by the band they correspond to.
628 """
629 sortedImages = {}
630 for key, value in kwargs.items():
631 sortedImages[key] = value
632 return sortedImages
633
634
class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Per-band connections for `PrettyPictureBackgroundFixerTask`.

    One coadd is read and one background-subtracted copy of it is written
    for every band of a patch.
    """

    # The input name is templated on ``coaddTypeName`` ("deep" by default).
    inputCoadd = Input(
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        doc="Input coadd for which the background is to be removed",
    )
    outputCoadd = Output(
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        doc="The coadd with the background fixed and subtracted",
    )
652
653
class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig, pipelineConnections=PrettyPictureBackgroundFixerConnections
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    # When True, the DETECTED mask plane selects the background pixels;
    # otherwise they are estimated statistically from the image itself.
    use_detection_mask = Field[bool](
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
        default=False,
    )
    # The image is split into num_background_bins x num_background_bins tiles,
    # each contributing at most one control point to the background fit.
    num_background_bins = Field[int](
        doc="The number of bins along each axis when determining background", default=5
    )
    # Compared against a tile's pixel count: tiles with too few background
    # pixels contribute no control point.
    min_bin_fraction = Field[float](
        doc="Bins with fewer pixels than this fraction of the total will be ignored", default=0.1
    )

    # Upper cut, in fitted noise widths, used only when background pixels are
    # found statistically (use_detection_mask is False).
    pos_sigma_multiplier = Field[float](
        doc="How many sigma to consider as background in the positive direction", default=2
    )
671
672
class PrettyPictureBackgroundFixerTask(PipelineTask):
    """Empirically flatten an image's background.

    Many astrophysical images have backgrounds with imperfections in them.
    This Task attempts to determine control points which are considered
    background values, and fits a radial basis function model to those
    points. This model is then subtracted off the image.
    """

    _DefaultName = "prettyPictureBackgroundFixer"
    ConfigClass = PrettyPictureBackgroundFixerConfig

    config: ConfigClass

    def _tile_slices(self, arr, R, C):
        """Generate slices for tiling an array.

        The array is divided into an ``R`` by ``C`` grid of tiles. When a
        dimension is not evenly divisible by its number of tiles, the
        remainder is spread one pixel at a time over the first tiles along
        that dimension.

        Parameters
        ----------
        arr : `numpy.ndarray`
            The input array to be tiled. Used only to determine its shape.
        R : `int`
            The number of tiles in the row dimension.
        C : `int`
            The number of tiles in the column dimension.

        Returns
        -------
        slices : `list` of `tuple`
            One ``(row_slice, column_slice)`` pair per tile, in row-major
            order.
        """

        def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]:
            """Split ``total_size`` into ``num_divisions`` contiguous
            ``(start, end)`` ranges whose lengths differ by at most one.
            """
            base, remainder = divmod(total_size, num_divisions)
            bounds = []
            start = 0
            for i in range(num_divisions):
                end = start + base + (1 if i < remainder else 0)
                bounds.append((start, end))
                start = end
            return bounds

        row_slices = get_slices(arr.shape[0], R)
        col_slices = get_slices(arr.shape[1], C)
        return [
            (slice(r_start, r_end), slice(c_start, c_end))
            for r_start, r_end in row_slices
            for c_start, c_end in col_slices
        ]

    @staticmethod
    def findBackgroundPixels(image, pos_sigma_mult=1):
        """Find pixels that are likely to be background based on image statistics.

        The image median is used as a rough background level, and a
        half-normal distribution is fit to the absolute deviations of the
        pixels below that median. Pixels strictly between ``loc - 5 * scale``
        and ``loc + pos_sigma_mult * scale`` of that fit are classified as
        background.

        Parameters
        ----------
        image : `numpy.ndarray`
            Input image array for which to find background pixels.
        pos_sigma_mult : `float`
            How many sigma to consider as background in the positive direction

        Returns
        -------
        result : `numpy.ndarray`
            Boolean mask array where True indicates background pixels. It is
            all False when no pixel lies below the median.

        Notes
        -----
        This method works best for images with relatively uniform background.
        It may not perform well in fields with high density or diffuse flux.
        """
        # The median is likely close to the average background. This does not
        # hold well in fields with high density or diffuse flux.
        maxLikely = np.median(image, axis=None)

        below = image < maxLikely
        if not np.any(below):
            # Nothing lies below the median (e.g. a constant image), so no
            # noise width can be measured and no pixel is classified as
            # background.
            return np.zeros(image.shape, dtype=bool)

        # Fit a half-normal to the deviations below the median.
        # NOTE(review): ``mu_hat`` is the fitted location of the absolute
        # deviations (expected near zero), not the median, so the thresholds
        # below are effectively relative to zero; this presumes a background
        # near zero — confirm.
        mu_hat, sigma_hat = halfnorm.fit(np.abs(image[below] - maxLikely))

        threshold = mu_hat + pos_sigma_mult * sigma_hat
        return (image < threshold) & (image > (mu_hat - 5 * sigma_hat))

    def fixBackground(self, image, detection_mask=None):
        """Estimate the background level across an image.

        The image is split into ``num_background_bins`` x
        ``num_background_bins`` tiles. The median of the background pixels of
        each sufficiently filled tile becomes a control point at the tile
        center. A thin-plate-spline radial basis function fit through those
        points is then evaluated on the full pixel grid.

        Parameters
        ----------
        image : `numpy.ndarray`
            The 2-d input image.
        detection_mask : `numpy.ndarray`, optional
            Array, True for pixels to be treated as background. If `None`,
            background pixels are found with `findBackgroundPixels`.

        Returns
        -------
        backgrounds : `numpy.ndarray`
            The estimated background; subtracting it is left to the caller.
            An array of zeros is returned when fewer than 15 control points
            are available.
        """
        if detection_mask is None:
            image_mask = self.findBackgroundPixels(image, self.config.pos_sigma_multiplier)
        else:
            image_mask = detection_mask
        # Indexing with an integer array would select pixels by index rather
        # than filter them, so make sure the mask is boolean.
        image_mask = np.asarray(image_mask, dtype=bool)

        tiles = self._tile_slices(image, self.config.num_background_bins, self.config.num_background_bins)

        yloc = []
        xloc = []
        values = []

        # Tiles are (row_slice, column_slice) pairs.
        for yslice, xslice in tiles:
            n_rows = yslice.stop - yslice.start
            n_cols = xslice.stop - xslice.start
            window = image[yslice, xslice][image_mask[yslice, xslice]]
            # Skip tiles where too small a fraction of the pixels are background.
            min_fill = int(n_rows * n_cols * self.config.min_bin_fraction)
            if window.size <= min_fill:
                continue
            values.append(np.median(window))
            yloc.append(yslice.start + n_rows / 2)
            xloc.append(xslice.start + n_cols / 2)

        # A thin-plate spline with a degree-4 polynomial term needs at least
        # 15 points (the number of bivariate monomials of degree <= 4).
        if len(yloc) < 15:
            return np.zeros(image.shape)

        # create an interpolant for the background and interpolate over the image.
        inter = RBFInterpolator(
            np.vstack((yloc, xloc)).T,
            values,
            kernel="thin_plate_spline",
            degree=4,
            smoothing=0.05,
            neighbors=None,
        )

        backgrounds = rbf_interpolator.fast_rbf_interpolation_on_grid(inter, image.shape)

        return backgrounds

    def run(self, inputCoadd: Exposure):
        """Estimate a background for an input Exposure and remove it.

        Parameters
        ----------
        inputCoadd : `Exposure`
            The exposure the background will be removed from. It is not
            modified.

        Returns
        -------
        result : `Struct`
            A `Struct` that contains the exposure with the background removed.
            This `Struct` will have an attribute named ``outputCoadd``.
        """
        if self.config.use_detection_mask:
            mask_plane_dict = inputCoadd.mask.getMaskPlaneDict()
            detected_bit = 2 ** mask_plane_dict["DETECTED"]
            # Background pixels are those WITHOUT the DETECTED bit. Comparing
            # against zero yields a boolean array; a bitwise ``~`` of the
            # integer mask would be nonzero for every pixel.
            detection_mask = (inputCoadd.mask.array & detected_bit) == 0
        else:
            detection_mask = None
        background = self.fixBackground(inputCoadd.image.array, detection_mask=detection_mask)
        # create a copy to mutate
        output = ExposureF(inputCoadd, deep=True)
        output.image.array -= background
        return Struct(outputCoadd=output)
906
907
class PrettyPictureStarFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
):
    """Connections for `PrettyPictureStarFixerTask`.

    All bands of a patch are read and written together so that the same
    pixels can be fixed consistently in every band.
    """

    inputCoadd = Input(
        doc=(
            "Background-subtracted coadds, one per band, in which regions without data "
            "(mostly bright star cores) are to be fixed"
        ),
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadds with regions without data (mostly bright star cores) filled in",
        name="pretty_picture_coadd_fixed_stars",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
926
927
class PrettyPictureStarFixerConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureStarFixerConnections):
    """Configuration for `PrettyPictureStarFixerTask`."""

    # Pixels brighter than this that carry SAT or NO_DATA in any band are
    # selected for fixing. No default is given, so this presumably has to be
    # set explicitly for the config to validate — NOTE(review): confirm intended.
    brightnessThresh = Field[float](
        doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored"
    )
932
933
class PrettyPictureStarFixerTask(PipelineTask):
    """Fix up regions in an image where there is no data or bad data.

    The fixes done by this task are overwhelmingly comprised of the cores of
    bright stars for which there is no data.
    """

    _DefaultName = "prettyPictureStarFixer"
    ConfigClass = PrettyPictureStarFixerConfig

    config: ConfigClass

    def run(self, inputs: Mapping[str, ExposureF]) -> Struct:
        """Fix areas in an image where there is no data, most likely to be
        the cores of bright stars.

        Because we want to have consistent fixes across bands, this method
        relies on supplying all bands and fixing pixels that are marked
        as having a defect in any band even if within one band there is
        no issue.

        Parameters
        ----------
        inputs : `Mapping` of `str` to `ExposureF`
            This mapping has keys of band as a `str` and the corresponding
            ExposureF as a value. The image arrays of these exposures are
            modified in place.

        Returns
        -------
        results : `Struct` of `Mapping` of `str` to `ExposureF`
            A `Struct` that has a mapping of band to `ExposureF` (the same
            objects that were passed in). The `Struct` has an attribute named
            ``results``. An empty ``inputs`` yields an empty mapping.
        """
        if not inputs:
            # Nothing to fix; without this guard the joint-mask variables
            # below would never be bound.
            return Struct(results={})

        exposures = list(inputs.values())
        lastExposure = exposures[-1]

        # Make the joint mask of all the channels: a pixel flagged in any
        # band is flagged in the joint mask.
        jointMask = exposures[0].mask.array.copy()
        for exposure in exposures[1:]:
            jointMask |= exposure.mask.array

        # Mask plane bit positions are read from the last exposure; they are
        # assumed to be identical across bands.
        maskDict = lastExposure.mask.getMaskPlaneDict()
        badBits = (2 ** maskDict["SAT"]) | (2 ** maskDict["NO_DATA"])
        together = (jointMask & badBits).astype(bool)

        # Use the last exposure's image as it is likely close enough across
        # all bands.
        bright_mask = lastExposure.image.array > self.config.brightnessThresh

        # Dilate the mask a bit; this helps get a bit fainter mask without
        # starting to include pixels in an irregular shape, as only the star
        # cores should be fixed.
        both = together & bright_mask
        struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
        both = binary_dilation(both, struct, iterations=4).astype(bool)

        # The repair region is shared by every band, so decide once whether
        # there is anything to inpaint.
        needsFix = bool(np.any(both))

        # Do the actual fixing of values.
        results = {}
        for band, imageExposure in inputs.items():
            if needsFix:
                inpainted = inpaint_biharmonic(imageExposure.image.array, both, split_into_regions=True)
                imageExposure.image.array[both] = inpainted[both]
            results[band] = imageExposure
        return Struct(results=results)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Key each input coadd by band so run() can fix all bands jointly.
        sortedImages: dict[str, ExposureF] = {
            cast(str, ref.dataId["band"]): butlerQC.get(ref) for ref in inputRefs.inputCoadd
        }

        outputs = self.run(sortedImages).results
        outputRefsByBand = {ref.dataId["band"]: ref for ref in outputRefs.outputCoadd}

        for band, data in outputs.items():
            butlerQC.put(data, outputRefsByBand[band])
1020
1021
class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`.

    Per-patch RGB images and their masks are read lazily (``deferLoad=True``)
    and combined into a single per-tract mosaic.
    """

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture",
        storageClass="ColorImage",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    inputRGBMask = Input(
        doc="Masks for the individual RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic",
        storageClass="ColorImage",
        dimensions=("tract", "skymap"),
    )
1054
1055
class PrettyMosaicConfig(PipelineTaskConfig, pipelineConnections=PrettyMosaicConnections):
    """Configuration for `PrettyMosaicTask`."""

    # Integer factor by which every patch (and the container box) is shrunk
    # before being placed into the mosaic.
    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
    )
    doDCID65Convert = Field[bool](
        doc="Force the output to be converted from display p3 to DCI-D65 colorspace.",
    )
    # When False, the system default temporary directory is used for the
    # memory-mapped mosaic backing files.
    useLocalTemp = Field[bool](
        default=False,
        doc="Use the current directory when creating local temp files.",
    )
1060
1061
1062class PrettyMosaicTask(PipelineTask):
1063 """Combines multiple RGB arrays into one mosaic."""
1064
1065 _DefaultName = "prettyMosaic"
1066 ConfigClass = PrettyMosaicConfig
1067
1068 config: ConfigClass
1069
1070 def run(
1071 self,
1072 inputRGB: Iterable[DeferredDatasetHandle],
1073 skyMap: BaseSkyMap,
1074 inputRGBMask: Iterable[DeferredDatasetHandle],
1075 ) -> Struct:
1076 r"""Assemble individual `numpy.ndarrays` into a mosaic.
1077
1078 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because
1079 they're loaded in one at a time to be placed into the mosaic to save
1080 memory.
1081
1082 Parameters
1083 ----------
1084 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1085 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB
1086 `numpy.ndarrays`.
1087 skyMap : `BaseSkyMap`
1088 The skymap that defines the relative position of each of the input
1089 images.
1090 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1091 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for
1092 each of the corresponding images.
1093
1094 Returns
1095 -------
1096 result : `Struct`
1097 The `Struct` containing the combined mosaic. The `Struct` has
1098 and attribute named ``outputRGBMosaic``.
1099 """
1100 # create the bounding region
1101 newBox = Box2I()
1102 # store the bounds as they are retrieved from the skymap
1103 boxes = []
1104 tractMaps = []
1105 for handle in inputRGB:
1106 dataId = handle.dataId
1107 tractInfo: TractInfo = skyMap[dataId["tract"]]
1108 patchInfo: PatchInfo = tractInfo[dataId["patch"]]
1109 bbox = patchInfo.getOuterBBox()
1110 boxes.append(bbox)
1111 newBox.include(bbox)
1112 tractMaps.append(tractInfo)
1113 # This will be overwritten in the loop, but that is ok, because
1114 # it is the same for each patch.
1115 patch_grow: int = patchInfo.getCellInnerDimensions().getX()
1116
1117 # fixup the boxes to be smaller if needed, and put the origin at zero,
1118 # this must be done after constructing the complete outer box
1119 modifiedBoxes = []
1120 origin = newBox.getBegin()
1121 for iterBox in boxes:
1122 localOrigin = iterBox.getBegin() - origin
1123 localOrigin = Point2I(
1124 x=int(np.floor(localOrigin.x / self.config.binFactor)),
1125 y=int(np.floor(localOrigin.y / self.config.binFactor)),
1126 )
1127 localExtent = Extent2I(
1128 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)),
1129 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)),
1130 )
1131 tmpBox = Box2I(localOrigin, localExtent)
1132 modifiedBoxes.append(tmpBox)
1133 boxes = modifiedBoxes
1134
1135 # scale the container box
1136 newBoxOrigin = Point2I(0, 0)
1137 newBoxExtent = Extent2I(
1138 x=int(np.floor(newBox.getWidth() / self.config.binFactor)),
1139 y=int(np.floor(newBox.getHeight() / self.config.binFactor)),
1140 )
1141 newBox = Box2I(newBoxOrigin, newBoxExtent)
1142
1143 # Allocate storage for the mosaic
1144 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1145 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None)
1146 consolidatedImage = None
1147 consolidatedMask = None
1148
1149 # Setup color space conversion in case they are used.
1150 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3)
1151 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3)
1152 d65.whitepoint = dp3.whitepoint
1153 d65.whitepoint_name = dp3.whitepoint_name
1154
1155 # Actually assemble the mosaic
1156 maskDict = {}
1157 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor)
1158 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps):
1159 rgb = handle.get().array
1160 # convert to the dci-d65 colorspace
1161 if self.config.doDCID65Convert:
1162 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65)
1163 rgbMask = handleMask.get()
1164 maskDict = rgbMask.getMaskPlaneDict()
1165 # allocate the memory for the mosaic
1166 if consolidatedImage is None:
1167 consolidatedImage = np.memmap(
1168 self.imageHandle.name,
1169 mode="w+",
1170 shape=(newBox.getHeight(), newBox.getWidth(), 3),
1171 dtype=rgb.dtype,
1172 )
1173 if consolidatedMask is None:
1174 consolidatedMask = np.memmap(
1175 self.maskHandle.name,
1176 mode="w+",
1177 shape=(newBox.getHeight(), newBox.getWidth()),
1178 dtype=rgbMask.array.dtype,
1179 )
1180
1181 if self.config.binFactor > 1:
1182 # opencv wants things in x, y dimensions
1183 shape = tuple(box.getDimensions())[::-1]
1184 rgb = cv2.resize(
1185 rgb,
1186 dst=None,
1187 dsize=shape,
1188 fx=shape[0] / self.config.binFactor,
1189 fy=shape[1] / self.config.binFactor,
1190 )
1191 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor]
1192 rgbMask = Mask(*(mask_array.shape[::-1]))
1193 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box)
1194
1195 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array)
1196
1197 for plugin in plugins.full():
1198 if consolidatedImage is not None and consolidatedMask is not None:
1199 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict)
1200 # If consolidated image still None, that means there was no work to do.
1201 # Return an empty image instead of letting this task fail.
1202 if consolidatedImage is None:
1203 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8)
1204
1205 return Struct(outputRGBMosaic=ColorImage(consolidatedImage))
1206
1208 self,
1209 butlerQC: QuantumContext,
1210 inputRefs: InputQuantizedConnection,
1211 outputRefs: OutputQuantizedConnection,
1212 ) -> None:
1213 inputs = butlerQC.get(inputRefs)
1214 outputs = self.run(**inputs)
1215 butlerQC.put(outputs, outputRefs)
1216 if hasattr(self, "imageHandle"):
1217 self.imageHandle.close()
1218 if hasattr(self, "maskHandle"):
1219 self.maskHandle.close()
1220
1222 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]]
1223 ) -> Iterable[DeferredDatasetHandle]:
1224 r"""Make valid inputs for the run method from numpy arrays.
1225
1226 Parameters
1227 ----------
1228 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray`
1229 An iterable where each element is a tuple with the first
1230 element is a mapping that corresponds to an arrays dataId,
1231 and the second is an `numpy.ndarray`.
1232
1233 Returns
1234 -------
1235 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle`
1236 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s
1237 containing the input data.
1238 """
1239 structuredInputs = []
1240 for dataId, array in inputs:
1241 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId))
1242
1243 return structuredInputs
Iterable[DeferredDatasetHandle] makeInputsFromArrays(self, Iterable[tuple[Mapping[str, Any], NDArray]] inputs)
Definition _task.py:1223
Struct run(self, Iterable[DeferredDatasetHandle] inputRGB, BaseSkyMap skyMap, Iterable[DeferredDatasetHandle] inputRGBMask)
Definition _task.py:1075
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:1212
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:1005
Struct run(self, Mapping[str, ExposureF] inputs)
Definition _task.py:946
RGBImage lsstRGB(FloatImagePlane rArray, FloatImagePlane gArray, FloatImagePlane bArray, LocalContrastFunction|None|_SentinalDefault local_contrast=DEFAULT_FUNCTION, ScaleLumFunction|None|_SentinalDefault scale_lum=DEFAULT_FUNCTION, ScaleColorFunction|None|_SentinalDefault scale_color=DEFAULT_FUNCTION, RemapBoundsFunction|None|_SentinalDefault remap_bounds=DEFAULT_FUNCTION, BracketingFunction|None|_SentinalDefault bracketing_function=DEFAULT_FUNCTION, GamutRemappingFunction|None|_SentinalDefault gamut_remapping_function=DEFAULT_FUNCTION, FloatImagePlane|None psf=None, tuple[float, float] cieWhitePoint=(0.28, 0.28))
dict[str, Exposure] makeInputsFromRefs(self, Iterable[DatasetRef] refs, Butler|QuantumContext butler)
Definition _task.py:564
Struct run(self, Mapping[str, Exposure] images, Projection[Any]|None image_wcs=None, Box|None image_box=None)
Definition _task.py:385
dict[int, DeferredDatasetHandle] makeInputsFromExposures(self, **kwargs)
Definition _task.py:612
fixBackground(self, image, detection_mask=None)
Definition _task.py:817
dict[str, DeferredDatasetHandle] makeInputsFromArrays(self, **kwargs)
Definition _task.py:587
None runQuantum(self, QuantumContext butlerQC, InputQuantizedConnection inputRefs, OutputQuantizedConnection outputRefs)
Definition _task.py:536
STL namespace.