Coverage for python / lsst / pipe / tasks / prettyPictureMaker / _task.py: 21%

421 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:17 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public API of this module. Every task is exported together with its config
# and connections classes.
__all__ = (
    "ChannelRGBConfig",
    "PrettyPictureTask",
    "PrettyPictureConnections",
    "PrettyPictureConfig",
    "PrettyMosaicTask",
    "PrettyMosaicConnections",
    "PrettyMosaicConfig",
    "PrettyPictureBackgroundFixerConnections",
    "PrettyPictureBackgroundFixerConfig",
    "PrettyPictureBackgroundFixerTask",
    "PrettyPictureStarFixerConnections",
    "PrettyPictureStarFixerConfig",
    "PrettyPictureStarFixerTask",
)

37 

38import colour 

39import copy 

40from collections.abc import Iterable, Mapping 

41from lsst.afw.image import ExposureF 

42import numpy as np 

43from typing import TYPE_CHECKING, cast, Any 

44from lsst.skymap import BaseSkyMap 

45 

46from scipy.stats import halfnorm, mode 

47from scipy.ndimage import binary_dilation 

48from scipy.interpolate import RBFInterpolator 

49from skimage.restoration import inpaint_biharmonic 

50 

51from lsst.daf.butler import Butler, DeferredDatasetHandle 

52from lsst.daf.butler import DatasetRef 

53from lsst.pex.config import Field, Config, ConfigDictField, ListField, ChoiceField 

54from lsst.pex.config.configurableActions import ConfigurableActionField 

55from lsst.pipe.base import ( 

56 PipelineTask, 

57 PipelineTaskConfig, 

58 PipelineTaskConnections, 

59 Struct, 

60 InMemoryDatasetHandle, 

61 NoWorkFound, 

62) 

63from lsst.rubinoxide import rbf_interpolator 

64import cv2 

65 

66from lsst.pipe.base.connectionTypes import Input, Output 

67from lsst.geom import Box2I, Point2I, Extent2I 

68from lsst.afw.image import Exposure, Mask 

69 

70from ._plugins import plugins 

71from ._colorMapper import lsstRGB 

72from ._utils import FeatheredMosaicCreator 

73from ._functors import ( 

74 BoundsRemapper, 

75 ColorScaler, 

76 LumCompressor, 

77 ExposureBracketer, 

78 GamutFixer, 

79 LocalContrastEnhancer, 

80) 

81 

82import tempfile 

83 

84 

85if TYPE_CHECKING: 

86 from numpy.typing import NDArray 

87 from lsst.pipe.base import QuantumContext, InputQuantizedConnection, OutputQuantizedConnection 

88 from lsst.skymap import TractInfo, PatchInfo 

89 

90 

class PrettyPictureConnections(
    PipelineTaskConnections,
    dimensions={"tract", "patch", "skymap"},
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Butler connections for `PrettyPictureTask`.

    All per-band coadds of a tract/patch are read. One RGB array and one fused
    mask are written for that tract/patch.
    """

    inputCoadds = Input(
        doc=(
            "Per-band coadds that are mixed into the RGB image according to channelConfig. "
            "One dataset is read for each band available for the tract/patch."
        ),
        name="pretty_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

    outputRGB = Output(
        doc="A RGB image created from the input data stored as a 3d array",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
    )

    outputRGBMask = Output(
        doc="A Mask corresponding to the fused masks of the input channels",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
    )

120 

121 

class ChannelRGBConfig(Config):
    """Weights giving how much one input band contributes to each RGB output.

    A band displayed as pure red is configured with ``r=1, g=0, b=0``. A band
    displayed as cyan is configured with ``r=0, g=1, b=1``.
    """

    r = Field[float](
        doc="The amount of red contained in this channel",
    )
    g = Field[float](
        doc="The amount of green contained in this channel",
    )
    b = Field[float](
        doc="The amount of blue contained in this channel",
    )

133 

134 

class PrettyPictureConfig(PipelineTaskConfig, pipelineConnections=PrettyPictureConnections):
    """Configuration for `PrettyPictureTask`.

    Some fields are deprecated aliases of newer settings. `freeze` calls
    `_handle_deprecated`, which copies explicitly-set deprecated values to
    their replacements before the configuration is frozen.
    """

    channelConfig = ConfigDictField(
        doc="A dictionary that maps band names to their rgb channel configurations",
        keytype=str,
        itemtype=ChannelRGBConfig,
        default={},
    )
    cieWhitePoint = ListField[float](
        doc="The white point of the input arrays in CIE xy chromaticity coordinates",
        maxLength=2,
        default=[0.28, 0.28],
    )
    arrayType = ChoiceField[str](
        doc="The dataset type for the output image array",
        default="uint8",
        allowed={
            "uint8": "Use 8 bit arrays, 255 max",
            "uint16": "Use 16 bit arrays, 65535 max",
            "half": "Use 16 bit float arrays, 1 max",
            "float": "Use 32 bit float arrays, 1 max",
        },
    )
    recenterNoise = Field[float](
        doc="Recenter the noise away from zero. Supplied value is in units of sigma",
        optional=True,
        default=None,
    )
    noiseSearchThreshold = Field[float](
        doc=(
            "Flux threshold below which most flux will be considered noise, used to estimate noise properties"
        ),
        default=10,
    )
    doPsfDeconvolve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.", default=False
    )
    # Misspelled alias of doPsfDeconvolve, kept only for backward compatibility.
    doPSFDeconcovlve = Field[bool](
        doc="Use the PSF in a Richardson-Lucy deconvolution on the luminance channel.",
        default=False,
        deprecated="This field will be removed in v32. Use doPsfDeconvolve instead.",
        optional=True,
    )
    doRemapGamut = Field[bool](
        doc="Apply a color correction to unrepresentable colors; if False, clip them.", default=True
    )
    doExposureBrackets = Field[bool](
        doc="Apply exposure bracketing to aid in dynamic range compression", default=True
    )
    doLocalContrast = Field[bool](doc="Apply local contrast optimizations to luminance.", default=True)

    imageRemappingConfig = ConfigurableActionField[BoundsRemapper](
        doc="Action controlling normalization process"
    )
    luminanceConfig = ConfigurableActionField[LumCompressor](
        doc="Action controlling luminance scaling when making an RGB image"
    )
    localContrastConfig = ConfigurableActionField[LocalContrastEnhancer](
        doc="Action controlling the local contrast correction in RGB image production"
    )
    colorConfig = ConfigurableActionField[ColorScaler](
        doc="Action to control the color scaling process in RGB image production"
    )
    exposureBracketerConfig = ConfigurableActionField[ExposureBracketer](
        doc=(
            "Exposure scaling action used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
    )
    gamutMapperConfig = ConfigurableActionField[GamutFixer](
        doc="Action to fix pixels which lay outside RGB color gamut"
    )

    exposureBrackets = ListField[float](
        doc=(
            "Exposure scaling factors used in creating multiple exposures with different scalings which will "
            "then be fused into a final image"
        ),
        optional=True,
        default=[1.25, 1, 0.75],
        deprecated=(
            "This field will stop working in v31 and be removed in v32, "
            "please set exposureBracketerConfig.exposureBrackets"
        ),
    )
    gamutMethod = ChoiceField[str](
        doc="If doRemapGamut is True this determines the method",
        default="inpaint",
        allowed={
            "mapping": "Use a mapping function",
            "inpaint": "Use surrounding pixels to determine likely value",
        },
        deprecated="This field will stop working in v31 and be removed in v32, please set gamutMapperConfig",
    )

    def setDefaults(self):
        """Map i, r and g bands to the red, green and blue channels."""
        self.channelConfig["i"] = ChannelRGBConfig(r=1, g=0, b=0)
        self.channelConfig["r"] = ChannelRGBConfig(r=0, g=1, b=0)
        self.channelConfig["g"] = ChannelRGBConfig(r=0, g=0, b=1)
        return super().setDefaults()

    def _handle_deprecated(self):
        """Copy explicitly-set deprecated fields into their replacements.

        A field counts as "explicitly set" when its ``_history`` has more than
        one entry. This assumes the default assignment accounts for the
        first entry.

        Notes
        -----
        The following migrations are performed:

        - ``gamutMethod`` -> ``gamutMapperConfig.gamutMethod``
        - ``exposureBrackets`` -> ``exposureBracketerConfig.exposureBrackets``.
          Setting it to `None` turns bracketing off through
          ``doExposureBrackets``.
        - ``localContrastConfig.doLocalContrast`` -> ``doLocalContrast``
        - ``doPSFDeconcovlve`` -> ``doPsfDeconvolve``
        """
        if len(self._history["gamutMethod"]) > 1:
            self.gamutMapperConfig.gamutMethod = self.gamutMethod

        if len(self._history["exposureBrackets"]) > 1:
            if self.exposureBrackets is None:
                # No brackets requested. Disable bracketing rather than pushing
                # None into the bracketer action, which would then be unused anyway.
                self.doExposureBrackets = False
            else:
                self.exposureBracketerConfig.exposureBrackets = self.exposureBrackets

        # The top-level doLocalContrast is what the task reads, so a value set
        # on the action is copied up to it.
        if len(self.localContrastConfig._history["doLocalContrast"]) > 1:
            self.doLocalContrast = self.localContrastConfig.doLocalContrast

        # Misspelled field name kept for backward compatibility.
        if len(self._history["doPSFDeconcovlve"]) > 1:
            self.doPsfDeconvolve = self.doPSFDeconcovlve

    def freeze(self):
        """Migrate deprecated fields (once) and then freeze the config."""
        # Only migrate when not already frozen; a frozen config cannot be modified.
        if self._frozen is not True:
            self._handle_deprecated()
        super().freeze()

271 

272 

273class PrettyPictureTask(PipelineTask): 

274 """Turns inputs into an RGB image.""" 

275 

276 _DefaultName = "prettyPicture" 

277 ConfigClass = PrettyPictureConfig 

278 

279 config: ConfigClass 

280 

281 def _find_normal_stats(self, array): 

282 """Calculate standard deviation from negative values using half-normal distribution. 

283 

284 Raises 

285 ------ 

286 ValueError 

287 Array dimension validation fails. 

288 

289 Parameters 

290 ---------- 

291 array : `numpy.array` 

292 Input array of numerical values. 

293 

294 Returns 

295 ------- 

296 mean : `float` 

297 The central moment of the distribution 

298 sigma : `float` 

299 Estimated standard deviation from negative values. Returns np.inf if: 

300 - No negative values exist in the array 

301 - Half-normal fitting fails 

302 """ 

303 # Extract negative values efficiently 

304 values_noise = array[array < self.config.noiseSearchThreshold] 

305 

306 # find the mode 

307 center = mode(np.round(values_noise, 2)).mode 

308 

309 # extract the negative values 

310 values_neg = array[array < center] 

311 

312 # Return infinity if no negative values found 

313 if values_neg.size == 0: 

314 return 0, np.inf 

315 

316 try: 

317 # Fit half-normal distribution to absolute negative values 

318 mu, sigma = halfnorm.fit(np.abs(values_neg)) 

319 except (ValueError, RuntimeError): 

320 # Handle fitting failures (e.g., constant data, optimization issues) 

321 return 0, np.inf 

322 

323 return center, sigma 

324 

325 def _match_sigmas_and_recenter(self, *arrays, factor=1): 

326 """Scale array values to match minimum standard deviation across arrays 

327 and recenter noise. 

328 

329 Adjusts values below each array's sigma by scaling and shifting them to 

330 align with the minimum sigma value across all input arrays. This operates 

331 in-place for efficiency. 

332 

333 Parameters 

334 ---------- 

335 *arrays : any number of `numpy.array` 

336 Variable number of input arrays to process. 

337 factor : float, optional 

338 Scaling factor for adjustments (default: 1). 

339 

340 """ 

341 # Calculate standard deviations for all arrays 

342 sigmas = [] 

343 mus = [] 

344 for arr in arrays: 

345 m, s = self._find_normal_stats(arr) 

346 mus.append(m) 

347 sigmas.append(s) 

348 mus = np.array(mus) 

349 sigmas = np.array(sigmas) 

350 

351 # If no sigmas could be determined, return the original 

352 # arrays. 

353 if not np.any(np.isfinite(sigmas)): 

354 return 

355 

356 min_sig = np.min(sigmas) 

357 

358 for mu, sigma, array in zip(mus, sigmas, arrays): 

359 # Identify values below the array's sigma threshold 

360 lower_pos = (array - mu) < sigma 

361 

362 # Skip processing if sigma is invalid 

363 if not np.isfinite(sigma): 

364 continue 

365 

366 # Calculate scaling ratio relative to minimum sigma 

367 sigma_ratio = min_sig / sigma 

368 

369 # Apply adjustment to qualifying values 

370 array[lower_pos] = (array[lower_pos] - mu) * sigma_ratio + min_sig * factor 

371 

372 def run(self, images: Mapping[str, Exposure]) -> Struct: 

373 """Turns the input arguments in arguments into an RGB array. 

374 

375 Parameters 

376 ---------- 

377 images : `Mapping` of `str` to `Exposure` 

378 A mapping of input images and the band they correspond to. 

379 

380 Returns 

381 ------- 

382 result : `Struct` 

383 A struct with the corresponding RGB image, and mask used in 

384 RGB image construction. The struct will have the attributes 

385 outputRGBImage and outputRGBMask. Each of the outputs will 

386 be a `NDarray` object. 

387 

388 Notes 

389 ----- 

390 Construction of input images are made easier by use of the 

391 makeInputsFrom* methods. 

392 """ 

393 channels = {} 

394 shape = (0, 0) 

395 jointMask: None | NDArray = None 

396 maskDict: Mapping[str, int] = {} 

397 doJointMaskInit = False 

398 if jointMask is None: 

399 doJointMask = True 

400 doJointMaskInit = True 

401 for channel, imageExposure in images.items(): 

402 imageArray = imageExposure.image.array 

403 # run all the plugins designed for array based interaction 

404 for plug in plugins.channel(): 

405 imageArray = plug( 

406 imageArray, imageExposure.mask.array, imageExposure.mask.getMaskPlaneDict(), self.config 

407 ).astype(np.float32) 

408 channels[channel] = imageArray 

409 # These operations are trivial look-ups and don't matter if they 

410 # happen in each loop. 

411 shape = imageArray.shape 

412 maskDict = imageExposure.mask.getMaskPlaneDict() 

413 if doJointMaskInit: 

414 jointMask = np.zeros(shape, dtype=imageExposure.mask.dtype) 

415 doJointMaskInit = False 

416 if doJointMask: 

417 jointMask |= imageExposure.mask.array 

418 

419 # mix the images to RGB 

420 imageRArray = np.zeros(shape, dtype=np.float32) 

421 imageGArray = np.zeros(shape, dtype=np.float32) 

422 imageBArray = np.zeros(shape, dtype=np.float32) 

423 

424 for band, image in channels.items(): 

425 if band not in self.config.channelConfig: 

426 self.log.info(f"{band} image found but not requested in RGB image, skipping") 

427 continue 

428 mix = self.config.channelConfig[band] 

429 if mix.r: 

430 imageRArray += mix.r * image 

431 if mix.g: 

432 imageGArray += mix.g * image 

433 if mix.b: 

434 imageBArray += mix.b * image 

435 

436 exposure = next(iter(images.values())) 

437 box: Box2I = exposure.getBBox() 

438 boxCenter = box.getCenter() 

439 try: 

440 psf = exposure.psf.computeImage(boxCenter).array 

441 except Exception: 

442 psf = None 

443 

444 if self.config.recenterNoise: 

445 self._match_sigmas_and_recenter( 

446 imageRArray, imageGArray, imageBArray, factor=self.config.recenterNoise 

447 ) 

448 

449 # assert for typing reasons 

450 assert jointMask is not None 

451 # Run any image level correction plugins 

452 colorImage = np.zeros((*imageRArray.shape, 3)) 

453 colorImage[:, :, 0] = imageRArray 

454 colorImage[:, :, 1] = imageGArray 

455 colorImage[:, :, 2] = imageBArray 

456 for plug in plugins.partial(): 

457 colorImage = plug(colorImage, jointMask, maskDict, self.config) 

458 

459 # Filter the local contrast parameters for diffusion that are None 

460 # This is so we only apply key word overrides that are specifically set. 

461 local_contrast_config = self.config.localContrastConfig.toDict() 

462 to_remove = [] 

463 for k, v in local_contrast_config["diffusionFunction"].items(): 

464 if v is None: 

465 to_remove.append(k) 

466 for item in to_remove: 

467 local_contrast_config["diffusionControl"].pop(item) 

468 

469 colorImage = lsstRGB( 

470 colorImage[:, :, 0], 

471 colorImage[:, :, 1], 

472 colorImage[:, :, 2], 

473 local_contrast=self.config.localContrastConfig if self.config.doLocalContrast else None, 

474 scale_lum=self.config.luminanceConfig, 

475 scale_color=self.config.colorConfig, 

476 remap_bounds=self.config.imageRemappingConfig, 

477 bracketing_function=( 

478 self.config.exposureBracketerConfig if self.config.doExposureBrackets else None 

479 ), 

480 gamut_remapping_function=self.config.gamutMapperConfig if self.config.doRemapGamut else None, 

481 cieWhitePoint=tuple(self.config.cieWhitePoint), # type: ignore 

482 psf=psf if self.config.doPsfDeconvolve else None, 

483 ) 

484 

485 # Find the dataset type and thus the maximum values as well 

486 maxVal: int | float 

487 match self.config.arrayType: 

488 case "uint8": 

489 dtype = np.uint8 

490 maxVal = 255 

491 case "uint16": 

492 dtype = np.uint16 

493 maxVal = 65535 

494 case "half": 

495 dtype = np.half 

496 maxVal = 1.0 

497 case "float": 

498 dtype = np.float32 

499 maxVal = 1.0 

500 case _: 

501 assert True, "This code path should be unreachable" 

502 

503 # lsstRGB returns an image in 0-1 scale it to the maximum value 

504 colorImage *= maxVal # type: ignore 

505 

506 # pack the joint mask back into a mask object 

507 lsstMask = Mask(width=jointMask.shape[1], height=jointMask.shape[0], planeDefs=maskDict) 

508 lsstMask.array = jointMask # type: ignore 

509 return Struct(outputRGB=colorImage.astype(dtype), outputRGBMask=lsstMask) # type: ignore 

510 

511 def runQuantum( 

512 self, 

513 butlerQC: QuantumContext, 

514 inputRefs: InputQuantizedConnection, 

515 outputRefs: OutputQuantizedConnection, 

516 ) -> None: 

517 imageRefs: list[DatasetRef] = inputRefs.inputCoadds 

518 sortedImages = self.makeInputsFromRefs(imageRefs, butlerQC) 

519 if not sortedImages: 

520 requested = ", ".join(self.config.channelConfig.keys()) 

521 raise NoWorkFound(f"No input images of band(s) {requested}") 

522 outputs = self.run(sortedImages) 

523 butlerQC.put(outputs, outputRefs) 

524 

525 def makeInputsFromRefs( 

526 self, refs: Iterable[DatasetRef], butler: Butler | QuantumContext 

527 ) -> dict[str, Exposure]: 

528 r"""Make valid inputs for the run method from butler references. 

529 

530 Parameters 

531 ---------- 

532 refs : `Iterable` of `DatasetRef` 

533 Some `Iterable` container of `Butler` `DatasetRef`\ s 

534 butler : `Butler` or `QuantumContext` 

535 This is the object that fetches the input data. 

536 

537 Returns 

538 ------- 

539 sortedImages : `dict` of `str` to `Exposure` 

540 A dictionary of `Exposure`\ s keyed by the band they 

541 correspond to. 

542 """ 

543 sortedImages: dict[str, Exposure] = {} 

544 for ref in refs: 

545 key: str = cast(str, ref.dataId["band"]) 

546 image = butler.get(ref) 

547 sortedImages[key] = image 

548 return sortedImages 

549 

550 def makeInputsFromArrays(self, **kwargs) -> dict[str, DeferredDatasetHandle]: 

551 r"""Make valid inputs for the run method from numpy arrays. 

552 

553 Parameters 

554 ---------- 

555 kwargs : `numpy.ndarray` 

556 This is standard python kwargs where the left side of the equals 

557 is the data band, and the right side is the corresponding `numpy.ndarray` 

558 array. 

559 

560 Returns 

561 ------- 

562 sortedImages : `dict` of `str` to \ 

563 `~lsst.daf.butler.DeferredDatasetHandle` 

564 A dictionary of `~lsst.daf.butlger.DeferredDatasetHandle`\ s keyed 

565 by the band they correspond to. 

566 """ 

567 # ignore type because there aren't proper stubs for afw 

568 temp = {} 

569 for key, array in kwargs.items(): 

570 temp[key] = Exposure(Box2I(Point2I(0, 0), Extent2I(*array.shape)), dtype=array.dtype) 

571 temp[key].image.array[:] = array 

572 

573 return self.makeInputsFromExposures(**temp) 

574 

575 def makeInputsFromExposures(self, **kwargs) -> dict[int, DeferredDatasetHandle]: 

576 r"""Make valid inputs for the run method from `Exposure` objects. 

577 

578 Parameters 

579 ---------- 

580 kwargs : `Exposure` 

581 This is standard python kwargs where the left side of the equals 

582 is the data band, and the right side is the corresponding 

583 `Exposure`. 

584 

585 Returns 

586 ------- 

587 sortedImages : `dict` of `int` to \ 

588 `~lsst.daf.butler.DeferredDatasetHandle` 

589 A dictionary of `~lsst.daf.butler.DeferredDatasetHandle`\ s keyed 

590 by the band they correspond to. 

591 """ 

592 sortedImages = {} 

593 for key, value in kwargs.items(): 

594 sortedImages[key] = value 

595 return sortedImages 

596 

597 

class PrettyPictureBackgroundFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"coaddTypeName": "deep"},
):
    """Connections for `PrettyPictureBackgroundFixerTask`.

    One coadd is read and one background-subtracted coadd is written per band.
    """

    inputCoadd = Input(
        name="{coaddTypeName}CoaddPsfMatched",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        doc="Input coadd for which the background is to be removed",
    )
    outputCoadd = Output(
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        doc="The coadd with the background fixed and subtracted",
    )

615 

616 

class PrettyPictureBackgroundFixerConfig(
    PipelineTaskConfig, pipelineConnections=PrettyPictureBackgroundFixerConnections
):
    """Configuration for `PrettyPictureBackgroundFixerTask`."""

    # Selects between the DETECTED mask plane and an empirical estimate of
    # which pixels are background.
    use_detection_mask = Field[bool](
        default=False,
        doc="Use the detection mask to determine background instead of empirically finding it in this task",
    )
    # The image is tiled into num_background_bins x num_background_bins bins.
    num_background_bins = Field[int](
        default=5,
        doc="The number of bins along each axis when determining background",
    )
    min_bin_fraction = Field[float](
        default=0.1,
        doc="Bins with fewer pixels than this fraction of the total will be ignored",
    )

    pos_sigma_multiplier = Field[float](
        default=2,
        doc="How many sigma to consider as background in the positive direction",
    )

630 

631 pos_sigma_multiplier = Field[float]( 

632 doc="How many sigma to consider as background in the positive direction", default=2 

633 ) 

634 

635 

class PrettyPictureBackgroundFixerTask(PipelineTask):
    """Empirically flatten an image's background.

    Many astrophysical images have backgrounds with imperfections in them.
    This Task determines control points that are considered background
    values and fits a radial basis function model to those points. The model
    is then subtracted from the image.
    """

    _DefaultName = "prettyPictureBackgroundFixer"
    ConfigClass = PrettyPictureBackgroundFixerConfig

    config: ConfigClass

    def _tile_slices(self, arr, R, C):
        """Generate slices for tiling an array.

        The array is divided into an ``R`` x ``C`` grid of tiles. When a
        dimension is not evenly divisible, the remainder is spread over the
        first tiles along that dimension.

        Parameters
        ----------
        arr : `numpy.ndarray`
            The input array to be tiled. Only its shape is used.
        R : `int`
            The number of tiles in the row dimension.
        C : `int`
            The number of tiles in the column dimension.

        Returns
        -------
        slices : `list` of `tuple`
            A list of ``(row_slice, col_slice)`` tuples, one per tile, in
            row-major order.
        """
        M = arr.shape[0]
        N = arr.shape[1]

        def get_slices(total_size: int, num_divisions: int) -> list[tuple[int, int]]:
            """Split ``total_size`` into ``num_divisions`` contiguous (start, end) ranges.

            If the division is not exact, the first ``total_size %
            num_divisions`` ranges are one element longer.
            """
            base = total_size // num_divisions
            remainder = total_size % num_divisions
            slices = []
            start = 0
            for i in range(num_divisions):
                end = start + base + (1 if i < remainder else 0)
                slices.append((start, end))
                start = end
            return slices

        row_slices = get_slices(M, R)
        col_slices = get_slices(N, C)

        # Every combination of a row range and a column range is one tile.
        return [
            (slice(r_start, r_end), slice(c_start, c_end))
            for r_start, r_end in row_slices
            for c_start, c_end in col_slices
        ]

    @staticmethod
    def findBackgroundPixels(image, pos_sigma_mult=1):
        """Find pixels that are likely to be background based on image statistics.

        The median of the finite pixels is taken as the background level. A
        half-normal distribution is fit to the distances of the pixels below
        the median to estimate the background sigma. Pixels within
        ``(median - 5 sigma, median + pos_sigma_mult * sigma)`` are classified
        as background.

        Parameters
        ----------
        image : `numpy.ndarray`
            Input image array for which to find background pixels.
        pos_sigma_mult : `float`
            How many sigma to consider as background in the positive direction.

        Returns
        -------
        result : `numpy.ndarray`
            Boolean mask array where True indicates background pixels.
            Non-finite pixels are never marked as background.

        Notes
        -----
        This method works best for images with relatively uniform background.
        It may not perform well in fields with high density or diffuse flux.
        """
        # The median is a good estimate of the background level except in
        # crowded fields or fields with a lot of diffuse flux. nanmedian keeps
        # NO_DATA (NaN) pixels from turning the whole estimate into NaN.
        maxLikely = np.nanmedian(image)

        below = image < maxLikely
        if np.any(below):
            # Fit a half-normal to the distances below the median. Only the
            # width is used; the background level stays at the median, in
            # image units.
            _, sigma_hat = halfnorm.fit(maxLikely - image[below])
        else:
            # Nothing below the median (e.g. a constant image), so there is no
            # noise information. With zero width no pixel passes both cuts.
            sigma_hat = 0.0

        threshold = maxLikely + pos_sigma_mult * sigma_hat
        return (image < threshold) & (image > (maxLikely - 5 * sigma_hat))

    def fixBackground(self, image, detection_mask=None):
        """Estimate the background of an image.

        The image is tiled into ``num_background_bins`` x
        ``num_background_bins`` bins. The median of the background pixels in
        each sufficiently filled bin becomes a control point. A thin-plate
        spline RBF is fit through these control points and evaluated over the
        image.

        Parameters
        ----------
        image : `numpy.ndarray`
            The input image as a 2-d NumPy array.
        detection_mask : `numpy.ndarray`, optional
            Array that is True where a pixel may be used as background. If
            `None`, background pixels are found with `findBackgroundPixels`.

        Returns
        -------
        backgrounds : `numpy.ndarray`
            The estimated background level across the image. This is all zeros
            when fewer than 15 bins contain enough background pixels.
        """
        if detection_mask is None:
            image_mask = self.findBackgroundPixels(image, self.config.pos_sigma_multiplier)
        else:
            image_mask = np.asarray(detection_mask, dtype=bool)
        # Non-finite pixels must never contribute to a bin median.
        image_mask = image_mask & np.isfinite(image)

        tiles = self._tile_slices(image, self.config.num_background_bins, self.config.num_background_bins)

        yloc = []
        xloc = []
        values = []

        # For each tile, record its center (row, col) and the median of its
        # background pixels.
        for row_slice, col_slice in tiles:
            n_rows = row_slice.stop - row_slice.start
            n_cols = col_slice.stop - col_slice.start
            ypos = n_rows / 2 + row_slice.start
            xpos = n_cols / 2 + col_slice.start
            window = image[row_slice, col_slice][image_mask[row_slice, col_slice]]
            # Skip bins whose background pixels are at most min_bin_fraction
            # of the tile area.
            min_fill = int(n_rows * n_cols * self.config.min_bin_fraction)
            if window.size <= min_fill:
                continue
            values.append(np.median(window))
            yloc.append(ypos)
            xloc.append(xpos)

        # A 2-d polynomial of degree 4 has 15 terms, so the RBF fit below needs
        # at least 15 control points.
        if len(yloc) < 15:
            return np.zeros(image.shape)

        # Control points are given as (row, col) pairs.
        inter = RBFInterpolator(
            np.vstack((yloc, xloc)).T,
            values,
            kernel="thin_plate_spline",
            degree=4,
            smoothing=0.05,
            neighbors=None,
        )

        # NOTE(review): assumes the grid evaluation uses (row, col) order
        # matching image.shape — confirm against rubinoxide.
        backgrounds = rbf_interpolator.fast_rbf_interpolation_on_grid(inter, image.shape)

        return backgrounds

    def run(self, inputCoadd: Exposure):
        """Estimate a background for an input Exposure and remove it.

        Parameters
        ----------
        inputCoadd : `Exposure`
            The exposure the background will be removed from. It is not
            modified.

        Returns
        -------
        result : `Struct`
            A `Struct` with an ``outputCoadd`` attribute holding a deep copy
            of the input with the background subtracted.
        """
        if self.config.use_detection_mask:
            detected_bit = 2 ** inputCoadd.mask.getMaskPlaneDict()["DETECTED"]
            # Boolean mask: True where the DETECTED bit is not set, i.e.
            # pixels that may be used as background.
            detection_mask = (inputCoadd.mask.array & detected_bit) == 0
        else:
            detection_mask = None
        background = self.fixBackground(inputCoadd.image.array, detection_mask=detection_mask)
        # Create a copy to mutate so the input is left untouched.
        output = ExposureF(inputCoadd, deep=True)
        output.image.array -= background
        return Struct(outputCoadd=output)

869 

870 

class PrettyPictureStarFixerConnections(
    PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
):
    """Connections for `PrettyPictureStarFixerTask`.

    All bands of a tract/patch are read together, so that regions are fixed
    consistently across bands. One fixed coadd is written per band.
    """

    inputCoadd = Input(
        doc="Background-subtracted coadds, one per band, in which regions with no or bad data will be fixed",
        name="pretty_picture_coadd_bg_subtracted",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )
    outputCoadd = Output(
        doc="The coadds with regions of no or bad data (mostly bright star cores) fixed",
        name="pretty_picture_coadd_fixed_stars",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
    )

889 

890 

class PrettyPictureStarFixerConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyPictureStarFixerConnections,
):
    """Configuration for `PrettyPictureStarFixerTask`."""

    # Pixels flagged SAT or NO_DATA are only repaired when their flux exceeds
    # this value. No default is given here.
    # NOTE(review): presumably this must be set in the pipeline config — confirm.
    brightnessThresh = Field[float](
        doc="The flux value below which pixels with SAT or NO_DATA bits will be ignored",
    )

895 

896 

class PrettyPictureStarFixerTask(PipelineTask):
    """This class fixes up regions in an image where there is no, or bad data.

    The fixes done by this task are overwhelmingly comprised of the cores of
    bright stars for which there is no data.
    """

    _DefaultName = "prettyPictureStarFixer"
    ConfigClass = PrettyPictureStarFixerConfig

    config: ConfigClass

    def run(self, inputs: Mapping[str, ExposureF]) -> Struct:
        """Fix areas in an image where there is no data, most likely to be
        the cores of bright stars.

        Because we want to have consistent fixes across bands, this method
        relies on supplying all bands and fixing pixels that are marked
        as having a defect in any band even if within one band there is
        no issue.

        Parameters
        ----------
        inputs : `Mapping` of `str` to `ExposureF`
            This mapping has keys of band as a `str` and the corresponding
            ExposureF as a value. The image arrays of these exposures are
            modified in place.

        Returns
        -------
        results : `Struct` of `Mapping` of `str` to `ExposureF`
            A `Struct` that has a mapping of band to `ExposureF`. The `Struct`
            has an attribute named ``results``. The values are the input
            exposures themselves, after fixing. An empty ``inputs`` gives an
            empty mapping.
        """
        # Nothing to fix; avoids referencing an unset joint mask below.
        if not inputs:
            return Struct(results={})

        exposures = list(inputs.values())

        # Make the joint mask of all the channels: a bit set in any band is
        # set in the joint mask.
        jointMask = np.zeros_like(exposures[0].mask.array)
        for imageExposure in exposures:
            jointMask |= imageExposure.mask.array

        # Mask plane bit numbers are read from the last exposure; they are
        # assumed to be the same in every band.
        maskDict = exposures[-1].mask.getMaskPlaneDict()
        sat_bit = maskDict["SAT"]
        no_data_bit = maskDict["NO_DATA"]
        together = (jointMask & 2**sat_bit).astype(bool) | (jointMask & 2**no_data_bit).astype(bool)

        # use the last exposure as it is likely close enough across all bands
        bright_mask = exposures[-1].image.array > self.config.brightnessThresh

        # Only flagged pixels that are also bright (star cores) are fixed.
        both = together & bright_mask
        # dilate the mask a bit, this helps get a bit fainter mask without
        # starting to include pixels in an irregular shape, as only the star
        # cores should be fixed.
        struct = np.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
        both = binary_dilation(both, struct, iterations=4).astype(bool)

        # The region to fix is shared by all bands, so decide once whether
        # there is anything to inpaint.
        doInpaint = bool(np.any(both))

        # do the actual fixing of values
        results = {}
        for band, imageExposure in inputs.items():
            if doInpaint:
                inpainted = inpaint_biharmonic(imageExposure.image.array, both, split_into_regions=True)
                imageExposure.image.array[both] = inpainted[both]
            results[band] = imageExposure
        return Struct(results=results)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Read every band's coadd, keyed by band name.
        sortedImages: dict[str, Exposure] = {
            cast(str, ref.dataId["band"]): butlerQC.get(ref) for ref in inputRefs.inputCoadd
        }

        outputs = self.run(sortedImages).results

        # Route each fixed exposure to the output ref for the same band.
        sortedOutputs = {ref.dataId["band"]: ref for ref in outputRefs.outputCoadd}
        for band, data in outputs.items():
            butlerQC.put(data, sortedOutputs[band])

983 

984 

class PrettyMosaicConnections(PipelineTaskConnections, dimensions=("tract", "skymap")):
    """Connections for `PrettyMosaicTask`.

    Per-patch RGB arrays and their masks are loaded lazily
    (``deferLoad=True``). They are combined into one mosaic per tract.
    """

    inputRGB = Input(
        doc="Individual RGB images that are to go into the mosaic",
        name="rgb_picture_array",
        storageClass="NumpyArray",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    skyMap = Input(
        doc="The skymap which the data has been mapped onto",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    inputRGBMask = Input(
        doc="Masks for the individual RGB images that are to go into the mosaic",
        name="rgb_picture_mask",
        storageClass="Mask",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
        deferLoad=True,
    )

    outputRGBMosaic = Output(
        doc="A RGB mosaic created from the input data stored as a 3d array",
        name="rgb_mosaic_array",
        storageClass="NumpyArray",
        dimensions=("tract", "skymap"),
    )

1017 

1018 

class PrettyMosaicConfig(
    PipelineTaskConfig,
    pipelineConnections=PrettyMosaicConnections,
):
    """Configuration for `PrettyMosaicTask`."""

    # Down-sampling factor applied to every patch before it is placed in the
    # mosaic; binning only happens when the value is greater than 1.
    binFactor = Field[int](
        doc="The factor to bin by when producing the mosaic",
    )
    # NOTE(review): neither this flag nor binFactor has a default, so
    # presumably both must be set explicitly in the pipeline config — confirm.
    doDCID65Convert = Field[bool](
        doc="Force the output to be converted from display p3 to DCI-D65 colorspace.",
    )
    # Selects the directory used for the temporary files that back the mosaic.
    useLocalTemp = Field[bool](
        doc="Use the current directory when creating local temp files.",
        default=False,
    )

1023 

1024 

1025class PrettyMosaicTask(PipelineTask): 

1026 """Combines multiple RGB arrays into one mosaic.""" 

1027 

1028 _DefaultName = "prettyMosaic" 

1029 ConfigClass = PrettyMosaicConfig 

1030 

1031 config: ConfigClass 

1032 

1033 def run( 

1034 self, 

1035 inputRGB: Iterable[DeferredDatasetHandle], 

1036 skyMap: BaseSkyMap, 

1037 inputRGBMask: Iterable[DeferredDatasetHandle], 

1038 ) -> Struct: 

1039 r"""Assemble individual `numpy.ndarrays` into a mosaic. 

1040 

1041 Each input is a `~lsst.daf.butler.DeferredDatasetHandle` because 

1042 they're loaded in one at a time to be placed into the mosaic to save 

1043 memory. 

1044 

1045 Parameters 

1046 ---------- 

1047 inputRGB : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle` 

1048 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to RGB 

1049 `numpy.ndarrays`. 

1050 skyMap : `BaseSkyMap` 

1051 The skymap that defines the relative position of each of the input 

1052 images. 

1053 inputRGBMask : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle` 

1054 `~lsst.daf.butler.DeferredDatasetHandle`\ s pointing to masks for 

1055 each of the corresponding images. 

1056 

1057 Returns 

1058 ------- 

1059 result : `Struct` 

1060 The `Struct` containing the combined mosaic. The `Struct` has 

1061 and attribute named ``outputRGBMosaic``. 

1062 """ 

1063 # create the bounding region 

1064 newBox = Box2I() 

1065 # store the bounds as they are retrieved from the skymap 

1066 boxes = [] 

1067 tractMaps = [] 

1068 for handle in inputRGB: 

1069 dataId = handle.dataId 

1070 tractInfo: TractInfo = skyMap[dataId["tract"]] 

1071 patchInfo: PatchInfo = tractInfo[dataId["patch"]] 

1072 bbox = patchInfo.getOuterBBox() 

1073 boxes.append(bbox) 

1074 newBox.include(bbox) 

1075 tractMaps.append(tractInfo) 

1076 # This will be overwritten in the loop, but that is ok, because 

1077 # it is the same for each patch. 

1078 patch_grow: int = patchInfo.getCellInnerDimensions().getX() 

1079 

1080 # fixup the boxes to be smaller if needed, and put the origin at zero, 

1081 # this must be done after constructing the complete outer box 

1082 modifiedBoxes = [] 

1083 origin = newBox.getBegin() 

1084 for iterBox in boxes: 

1085 localOrigin = iterBox.getBegin() - origin 

1086 localOrigin = Point2I( 

1087 x=int(np.floor(localOrigin.x / self.config.binFactor)), 

1088 y=int(np.floor(localOrigin.y / self.config.binFactor)), 

1089 ) 

1090 localExtent = Extent2I( 

1091 x=int(np.floor(iterBox.getWidth() / self.config.binFactor)), 

1092 y=int(np.floor(iterBox.getHeight() / self.config.binFactor)), 

1093 ) 

1094 tmpBox = Box2I(localOrigin, localExtent) 

1095 modifiedBoxes.append(tmpBox) 

1096 boxes = modifiedBoxes 

1097 

1098 # scale the container box 

1099 newBoxOrigin = Point2I(0, 0) 

1100 newBoxExtent = Extent2I( 

1101 x=int(np.floor(newBox.getWidth() / self.config.binFactor)), 

1102 y=int(np.floor(newBox.getHeight() / self.config.binFactor)), 

1103 ) 

1104 newBox = Box2I(newBoxOrigin, newBoxExtent) 

1105 

1106 # Allocate storage for the mosaic 

1107 self.imageHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None) 

1108 self.maskHandle = tempfile.NamedTemporaryFile(dir="." if self.config.useLocalTemp else None) 

1109 consolidatedImage = None 

1110 consolidatedMask = None 

1111 

1112 # Setup color space conversion in case they are used. 

1113 d65 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DCI_P3) 

1114 dp3 = copy.deepcopy(colour.models.RGB_COLOURSPACE_DISPLAY_P3) 

1115 d65.whitepoint = dp3.whitepoint 

1116 d65.whitepoint_name = dp3.whitepoint_name 

1117 

1118 # Actually assemble the mosaic 

1119 maskDict = {} 

1120 mosaic_maker = FeatheredMosaicCreator(patch_grow, self.config.binFactor) 

1121 for box, handle, handleMask, tractInfo in zip(boxes, inputRGB, inputRGBMask, tractMaps): 

1122 rgb = handle.get() 

1123 # convert to the dci-d65 colorspace 

1124 if self.config.doDCID65Convert: 

1125 rgb = colour.RGB_to_RGB(np.clip(rgb, 0, 1), dp3, d65) 

1126 rgbMask = handleMask.get() 

1127 maskDict = rgbMask.getMaskPlaneDict() 

1128 # allocate the memory for the mosaic 

1129 if consolidatedImage is None: 

1130 consolidatedImage = np.memmap( 

1131 self.imageHandle.name, 

1132 mode="w+", 

1133 shape=(newBox.getHeight(), newBox.getWidth(), 3), 

1134 dtype=rgb.dtype, 

1135 ) 

1136 if consolidatedMask is None: 

1137 consolidatedMask = np.memmap( 

1138 self.maskHandle.name, 

1139 mode="w+", 

1140 shape=(newBox.getHeight(), newBox.getWidth()), 

1141 dtype=rgbMask.array.dtype, 

1142 ) 

1143 

1144 if self.config.binFactor > 1: 

1145 # opencv wants things in x, y dimensions 

1146 shape = tuple(box.getDimensions())[::-1] 

1147 rgb = cv2.resize( 

1148 rgb, 

1149 dst=None, 

1150 dsize=shape, 

1151 fx=shape[0] / self.config.binFactor, 

1152 fy=shape[1] / self.config.binFactor, 

1153 ) 

1154 mask_array = rgbMask.array[:: self.config.binFactor, :: self.config.binFactor] 

1155 rgbMask = Mask(*(mask_array.shape[::-1])) 

1156 mosaic_maker.add_to_image(consolidatedImage, rgb, newBox, box) 

1157 

1158 consolidatedMask[*box.slices] = np.bitwise_or(consolidatedMask[*box.slices], rgbMask.array) 

1159 

1160 for plugin in plugins.full(): 

1161 if consolidatedImage is not None and consolidatedMask is not None: 

1162 consolidatedImage = plugin(consolidatedImage, consolidatedMask, maskDict) 

1163 # If consolidated image still None, that means there was no work to do. 

1164 # Return an empty image instead of letting this task fail. 

1165 if consolidatedImage is None: 

1166 consolidatedImage = np.zeros((0, 0, 0), dtype=np.uint8) 

1167 

1168 return Struct(outputRGBMosaic=consolidatedImage) 

1169 

1170 def runQuantum( 

1171 self, 

1172 butlerQC: QuantumContext, 

1173 inputRefs: InputQuantizedConnection, 

1174 outputRefs: OutputQuantizedConnection, 

1175 ) -> None: 

1176 inputs = butlerQC.get(inputRefs) 

1177 outputs = self.run(**inputs) 

1178 butlerQC.put(outputs, outputRefs) 

1179 if hasattr(self, "imageHandle"): 

1180 self.imageHandle.close() 

1181 if hasattr(self, "maskHandle"): 

1182 self.maskHandle.close() 

1183 

1184 def makeInputsFromArrays( 

1185 self, inputs: Iterable[tuple[Mapping[str, Any], NDArray]] 

1186 ) -> Iterable[DeferredDatasetHandle]: 

1187 r"""Make valid inputs for the run method from numpy arrays. 

1188 

1189 Parameters 

1190 ---------- 

1191 inputs : `Iterable` of `tuple` of `Mapping` and `numpy.ndarray` 

1192 An iterable where each element is a tuple with the first 

1193 element is a mapping that corresponds to an arrays dataId, 

1194 and the second is an `numpy.ndarray`. 

1195 

1196 Returns 

1197 ------- 

1198 sortedImages : `Iterable` of `~lsst.daf.butler.DeferredDatasetHandle` 

1199 An iterable of `~lsst.daf.butler.DeferredDatasetHandle`\ s 

1200 containing the input data. 

1201 """ 

1202 structuredInputs = [] 

1203 for dataId, array in inputs: 

1204 structuredInputs.append(InMemoryDatasetHandle(inMemoryDataset=array, **dataId)) 

1205 

1206 return structuredInputs