Coverage for python/lsst/pipe/tasks/subtractBrightStars.py: 27%

106 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-12 02:46 -0700

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Retrieve extended PSF model and subtract bright stars at visit level.

This module provides the connections, configuration and task used to
subtract scaled extended PSF models of bright stars from calibrated
exposures.
"""

# Names exported by ``from ... import *``.
__all__ = [
    "SubtractBrightStarsConnections",
    "SubtractBrightStarsConfig",
    "SubtractBrightStarsTask",
]

26from functools import reduce 

27from operator import ior 

28 

29import numpy as np 

30from lsst.afw.image import Exposure, ExposureF, MaskedImageF 

31from lsst.afw.math import ( 

32 StatisticsControl, 

33 WarpingControl, 

34 makeStatistics, 

35 rotateImageBy90, 

36 stringToStatisticsProperty, 

37 warpImage, 

38) 

39from lsst.geom import Box2I, Point2D, Point2I 

40from lsst.pex.config import ChoiceField, Field, ListField 

41from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

42from lsst.pipe.base.connectionTypes import Input, Output 

43 

44 

class SubtractBrightStarsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "visit", "detector"),
    defaultTemplates={"outputExposureName": "brightStar_subtracted", "outputBackgroundName": "brightStars"},
):
    """Butler connections for `SubtractBrightStarsTask`.

    The ``skyCorr`` input is only kept when ``config.doApplySkyCorr`` is
    `True`.
    """

    # Calibrated exposure the bright star models are subtracted from.
    inputExposure = Input(
        doc="Input exposure from which to subtract bright star stamps.",
        name="calexp",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
    )
    # Per-star postage stamps for this visit/detector.
    inputBrightStarStamps = Input(
        doc="Set of preprocessed postage stamps, each centered on a single bright star.",
        name="brightStarStamps",
        storageClass="BrightStarStamps",
        dimensions=("visit", "detector"),
    )
    # Extended PSF model, looked up per band.
    inputExtendedPsf = Input(
        doc="Extended PSF model.",
        name="extended_psf",
        storageClass="ExtendedPsf",
        dimensions=("band",),
    )
    # Optional full focal plane sky correction (see ``__init__``).
    skyCorr = Input(
        doc="Input Sky Correction to be subtracted from the calexp if ``doApplySkyCorr``=True.",
        name="skyCorr",
        storageClass="Background",
        dimensions=("instrument", "visit", "detector"),
    )
    # Input exposure with every modelled bright star removed.
    outputExposure = Output(
        doc="Exposure with bright stars subtracted.",
        name="{outputExposureName}_calexp",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
    )
    # Image made only of the scaled bright star models.
    outputBackgroundExposure = Output(
        doc="Exposure containing only the modelled bright stars.",
        name="{outputBackgroundName}_calexp_background",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Keep the sky correction input only when it will be applied.
        if config.doApplySkyCorr:
            return
        self.inputs.remove("skyCorr")

class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=SubtractBrightStarsConnections):
    """Configuration for `SubtractBrightStarsTask`.

    Selects which outputs are written, which stars are modelled, how the
    extended PSF model is warped onto each star and how it is scaled, and
    whether a sky correction is applied first.
    """

    # --- Output switches -------------------------------------------------
    doWriteSubtractor = Field[bool](
        dtype=bool,
        doc="Should an exposure containing all bright star models be written to disk?",
        default=True,
    )
    doWriteSubtractedExposure = Field[bool](
        dtype=bool,
        doc="Should an exposure with bright stars subtracted be written to disk?",
        default=True,
    )

    # --- Star selection --------------------------------------------------
    magLimit = Field[float](
        dtype=float,
        doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted",
        default=18,
    )

    # --- Model warping ---------------------------------------------------
    warpingKernelName = ChoiceField[str](
        dtype=str,
        doc="Warping kernel",
        default="lanczos5",
        allowed={
            "bilinear": "bilinear interpolation",
            "lanczos3": "Lanczos kernel of order 3",
            "lanczos4": "Lanczos kernel of order 4",
            "lanczos5": "Lanczos kernel of order 5",
            "lanczos6": "Lanczos kernel of order 6",
            "lanczos7": "Lanczos kernel of order 7",
        },
    )

    # --- Model scaling ---------------------------------------------------
    scalingType = ChoiceField[str](
        dtype=str,
        doc="How the model should be scaled to each bright star; implemented options are "
        "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform "
        "least square fitting on each pixel with no bad mask plane set.",
        default="leastSquare",
        allowed={
            "annularFlux": "reuse BrightStarStamp annular flux measurement",
            "leastSquare": "find least square scaling factor",
        },
    )
    # `BAD` must stay in this list: secondary detected sources (anything
    # detected other than the primary star of a stamp) are flagged `BAD`.
    badMaskPlanes = ListField[str](
        dtype=str,
        doc="Mask planes that, if set, lead to associated pixels not being included in the computation of "
        "the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, "
        "as the stamps are expected to already be normalized.",
        default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
    )

    # --- Sky correction --------------------------------------------------
    doApplySkyCorr = Field[bool](
        dtype=bool,
        doc="Apply full focal plane sky correction before extracting stars?",
        default=True,
    )

class SubtractBrightStarsTask(PipelineTask):
    """Use an extended PSF model to subtract bright stars from a calibrated
    exposure (i.e. at single-visit level).

    This task uses both a set of bright star stamps produced by
    `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`
    and an extended PSF model produced by
    `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
    """

    ConfigClass = SubtractBrightStarsConfig
    _DefaultName = "subtractBrightStars"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Placeholders for the statistics objects used when scalingType is
        # leastSquare; `scaleModel` builds them lazily (via
        # `_setUpStatistics`) from the mask of the first stamp it scales.
        self.statsControl, self.statsFlag = None, None

    def _setUpStatistics(self, exampleMask):
        """Configure statistics control and flag, for use if ``scalingType`` is
        `leastSquare`.

        Does nothing for any other ``scalingType``.

        Parameters
        ----------
        exampleMask : `~lsst.afw.image.Mask`
            Mask used to translate the plane names listed in
            ``config.badMaskPlanes`` into a single bit mask (``scaleModel``
            passes the mask of a star stamp).
        """
        if self.config.scalingType == "leastSquare":
            self.statsControl = StatisticsControl()
            # Pixels with any of these mask planes set are ignored by the
            # statistics. The initial value of 0 makes an empty
            # ``badMaskPlanes`` list mean "ignore nothing" instead of making
            # `functools.reduce` raise TypeError on an empty iterable.
            andMask = reduce(ior, (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes), 0)
            self.statsControl.setAndMask(andMask)
            self.statsFlag = stringToStatisticsProperty("SUM")

    def applySkyCorr(self, calexp, skyCorr):
        """Apply correction to the sky background level.

        Sky corrections can be generated with
        `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. Because the sky
        model used by that code extends over the entire focal plane, this can
        produce better sky subtraction. The calexp is updated in-place.

        Parameters
        ----------
        calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage`
            Calibrated exposure; modified in place.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`
            Full focal plane sky correction, obtained by running
            `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`.
        """
        if isinstance(calexp, Exposure):
            calexp = calexp.getMaskedImage()
        # In-place subtraction on the (shared) masked image, so the caller's
        # exposure sees the correction.
        calexp -= skyCorr.getImage()

    def scaleModel(self, model, star, inPlace=True, nb90Rots=0):
        """Compute scaling factor to be applied to the extended PSF so that its
        amplitude matches that of an individual star.

        Parameters
        ----------
        model : `~lsst.afw.image.MaskedImageF`
            The extended PSF model, shifted (and potentially warped) to match
            the bright star's positioning.
        star : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp`
            A stamp centered on the bright star to be subtracted.
        inPlace : `bool`
            Whether the model image should be multiplied by the scaling
            factor in place. Default is `True`.
        nb90Rots : `int`
            The number of 90-degrees rotations to apply to the star stamp
            (only used for the `leastSquare` scaling).

        Returns
        -------
        scalingFactor : `float`
            The factor by which the model image should be multiplied for it
            to be scaled to the input bright star. For `annularFlux` scaling
            this is the stamp's ``annularFlux``; for `leastSquare` it is the
            least-squares amplitude, or 1 if the model has no usable signal.
        """
        if self.config.scalingType == "annularFlux":
            scalingFactor = star.annularFlux
        elif self.config.scalingType == "leastSquare":
            if self.statsControl is None:
                self._setUpStatistics(star.stamp_im.mask)
            starIm = star.stamp_im.clone()
            # Rotate the star postage stamp.
            starIm = rotateImageBy90(starIm, nb90Rots)
            # Reverse the prior star flux normalization ("unnormalize").
            starIm *= star.annularFlux
            # The estimator of the scalingFactor (f) that minimizes (Y-fX)^2
            # is E[XY]/E[XX]. Both products are stored in copies of the star
            # image so that the star's mask selects the pixels summed below.
            xy = starIm.clone()
            xy.image.array *= model.image.array
            xx = starIm.clone()
            xx.image.array = model.image.array**2
            # Compute the least squares scaling factor.
            xySum = makeStatistics(xy, self.statsFlag, self.statsControl).getValue()
            xxSum = makeStatistics(xx, self.statsFlag, self.statsControl).getValue()
            # Fall back to 1 rather than dividing by zero.
            scalingFactor = xySum / xxSum if xxSum else 1
        if inPlace:
            model.image *= scalingFactor
        return scalingFactor

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docstring inherited.
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        subtractor, _ = self.run(**inputs, dataId=dataId)
        if self.config.doWriteSubtractedExposure:
            outputExposure = inputs["inputExposure"].clone()
            outputExposure.image -= subtractor.image
        else:
            outputExposure = None
        outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None
        output = Struct(outputExposure=outputExposure, outputBackgroundExposure=outputBackgroundExposure)
        butlerQC.put(output, outputRefs)

    def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None):
        """Iterate over all bright stars in an exposure to scale the extended
        PSF model before subtracting bright stars.

        Parameters
        ----------
        inputExposure : `~lsst.afw.image.ExposureF`
            The image from which bright stars should be subtracted. If a sky
            correction is applied, this exposure is modified in place.
        inputBrightStarStamps :
                `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`
            Set of stamps centered on each bright star to be subtracted,
            produced by running
            `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`.
        inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf`
            Extended PSF model, produced by
            `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
        dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
            The dataId of the exposure (and detector) bright stars should be
            subtracted from. Its ``"detector"`` entry selects the model.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
            Full focal plane sky correction, obtained by running
            `~lsst.pipe.tasks.skyCorrection.SkyCorrectionTask`. Applied only
            if ``doApplySkyCorr`` is `True`; if ``doApplySkyCorr`` is `True`
            and ``skyCorr`` is `None`, a warning is logged and no sky
            correction is applied.

        Returns
        -------
        subtractorExp : `~lsst.afw.image.ExposureF`
            An Exposure containing a scaled bright star model fit to every
            bright star profile; its image can then be subtracted from the
            input exposure.
        invImages : `list` [`~lsst.afw.image.MaskedImageF`]
            A list of small images ("stamps") containing the model, each scaled
            to its corresponding input bright star.
        """
        inputExpBBox = inputExposure.getBBox()
        if self.config.doApplySkyCorr:
            if skyCorr is None:
                # Previously skipped silently; make the missing input visible.
                self.log.warning(
                    "doApplySkyCorr is True but no sky correction was provided for exposure %s; "
                    "no sky correction will be applied.",
                    dataId,
                )
            else:
                self.log.info(
                    "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
                )
                self.applySkyCorr(inputExposure, skyCorr)
        # Create an empty image the size of the exposure.
        # TODO: DM-31085 (set mask planes).
        subtractorExp = ExposureF(bbox=inputExpBBox)
        subtractor = subtractorExp.maskedImage
        # Make a copy of the input model.
        model = inputExtendedPsf(dataId["detector"]).clone()
        modelStampSize = model.getDimensions()
        # Rotate the model by the inverse of the rotation applied to the
        # stamps; the same count is used to rotate each star stamp back in
        # scaleModel.
        inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4
        model = rotateImageBy90(model, inv90Rots)
        warpCont = WarpingControl(self.config.warpingKernelName)
        invImages = []
        # Loop over bright stars, computing the inverse transformed and scaled
        # postage stamp for each.
        for star in inputBrightStarStamps:
            if star.gaiaGMag < self.config.magLimit:
                # Set the origin.
                model.setXY0(star.position)
                # Create an empty destination image.
                invTransform = star.archive_element.inverted()
                invOrigin = Point2I(invTransform.applyForward(Point2D(star.position)))
                bbox = Box2I(corner=invOrigin, dimensions=modelStampSize)
                invImage = MaskedImageF(bbox)
                # Apply inverse transform.
                goodPix = warpImage(invImage, model, invTransform, warpCont)
                if not goodPix:
                    self.log.debug(
                        "Warping of a model failed for star %s: no good pixel in output", star.gaiaId
                    )
                # Scale the model.
                self.scaleModel(invImage, star, inPlace=True, nb90Rots=inv90Rots)
                # Replace NaNs before subtraction (note all NaN pixels have
                # the NO_DATA flag).
                invImage.image.array[np.isnan(invImage.image.array)] = 0
                # Only the part of the stamp overlapping the exposure is added.
                bbox.clip(inputExpBBox)
                if bbox.getArea() > 0:
                    subtractor[bbox] += invImage[bbox]
                invImages.append(invImage)
        return subtractorExp, invImages