Coverage for python/lsst/pipe/tasks/subtractBrightStars.py: 27%

106 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-03 03:39 -0700

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

"""Retrieve extended PSF model and subtract bright stars at calexp (i.e.
single visit) level.
"""

# Public API of this module: the connections, config and task classes.
__all__ = ["SubtractBrightStarsConnections", "SubtractBrightStarsConfig", "SubtractBrightStarsTask"]

27 

28from functools import reduce 

29from operator import ior 

30 

31import numpy as np 

32from lsst.afw.image import Exposure, ExposureF, MaskedImageF 

33from lsst.afw.math import ( 

34 StatisticsControl, 

35 WarpingControl, 

36 makeStatistics, 

37 rotateImageBy90, 

38 stringToStatisticsProperty, 

39 warpImage, 

40) 

41from lsst.geom import Box2I, Point2D, Point2I 

42from lsst.pex.config import ChoiceField, Field, ListField 

43from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

44from lsst.pipe.base import connectionTypes as cT 

45 

46 

class SubtractBrightStarsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "visit", "detector"),
    defaultTemplates={
        "outputExposureName": "brightStar_subtracted",
        "outputBackgroundName": "brightStars",
    },
):
    """Input and output dataset connections for `SubtractBrightStarsTask`.

    The ``skyCorr`` input is dropped at construction time when the task
    configuration has ``doApplySkyCorr`` disabled.
    """

    # ---- Inputs ----
    inputExposure = cT.Input(
        name="calexp",
        doc="Input exposure from which to subtract bright star stamps.",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
    )
    inputBrightStarStamps = cT.Input(
        name="brightStarStamps",
        doc="Set of preprocessed postage stamps, each centered on a single bright star.",
        storageClass="BrightStarStamps",
        dimensions=("visit", "detector"),
    )
    inputExtendedPsf = cT.Input(
        name="extended_psf",
        doc="Extended PSF model.",
        storageClass="ExtendedPsf",
        dimensions=("band",),
    )
    skyCorr = cT.Input(
        name="skyCorr",
        doc="Input Sky Correction to be subtracted from the calexp if ``doApplySkyCorr``=True.",
        storageClass="Background",
        dimensions=("instrument", "visit", "detector"),
    )

    # ---- Outputs ----
    outputExposure = cT.Output(
        name="{outputExposureName}_calexp",
        doc="Exposure with bright stars subtracted.",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
    )
    outputBackgroundExposure = cT.Output(
        name="{outputBackgroundName}_calexp_background",
        doc="Exposure containing only the modelled bright stars.",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        if config.doApplySkyCorr:
            # Sky correction is requested: keep every declared input.
            return
        # Sky correction is disabled, so the skyCorr dataset is not needed.
        self.inputs.remove("skyCorr")

109 

110 

class SubtractBrightStarsConfig(PipelineTaskConfig, pipelineConnections=SubtractBrightStarsConnections):
    """Configuration parameters for SubtractBrightStarsTask"""

    # ---- Output switches ----
    doWriteSubtractor = Field[bool](
        doc="Should an exposure containing all bright star models be written to disk?",
        dtype=bool,
        default=True,
    )
    doWriteSubtractedExposure = Field[bool](
        doc="Should an exposure with bright stars subtracted be written to disk?",
        dtype=bool,
        default=True,
    )

    # ---- Star selection ----
    magLimit = Field[float](
        doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted",
        dtype=float,
        default=18,
    )

    # ---- Model warping and scaling ----
    warpingKernelName = ChoiceField[str](
        doc="Warping kernel",
        dtype=str,
        default="lanczos5",
        allowed={
            "bilinear": "bilinear interpolation",
            # Lanczos kernels of order 3 through 7, in increasing order.
            **{f"lanczos{order}": f"Lanczos kernel of order {order}" for order in range(3, 8)},
        },
    )
    scalingType = ChoiceField[str](
        doc=(
            "How the model should be scaled to each bright star; implemented options are "
            "`annularFlux` to reuse the annular flux of each stamp, "
            "or `leastSquare` to perform least square fitting on each pixel "
            "with no bad mask plane set."
        ),
        dtype=str,
        default="leastSquare",
        allowed={
            "annularFlux": "reuse BrightStarStamp annular flux measurement",
            "leastSquare": "find least square scaling factor",
        },
    )
    # `BAD` should always be listed here: secondary detected sources (i.e.,
    # detected sources other than the primary source of interest) also get
    # set to `BAD`.
    badMaskPlanes = ListField[str](
        doc=(
            "Mask planes that, if set, lead to associated pixels not being included "
            "in the computation of the scaling factor (`BAD` should always be included). "
            "Ignored if scalingType is `annularFlux`, "
            "as the stamps are expected to already be normalized."
        ),
        dtype=str,
        default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
    )

    # ---- Sky correction ----
    doApplySkyCorr = Field[bool](
        doc="Apply full focal plane sky correction before extracting stars?",
        dtype=bool,
        default=True,
    )

168 

169 

class SubtractBrightStarsTask(PipelineTask):
    """Use an extended PSF model to subtract bright stars from a calibrated
    exposure (i.e. at single-visit level).

    This task uses both a set of bright star stamps produced by
    `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`
    and an extended PSF model produced by
    `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
    """

    ConfigClass = SubtractBrightStarsConfig
    _DefaultName = "subtractBrightStars"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Placeholders to set up Statistics if scalingType is leastSquare;
        # they are filled in lazily by ``_setUpStatistics``.
        self.statsControl, self.statsFlag = None, None

    def _setUpStatistics(self, exampleMask):
        """Configure statistics control and flag, for use if ``scalingType`` is
        `leastSquare`.

        Parameters
        ----------
        exampleMask : `~lsst.afw.image.Mask`
            A mask whose plane dictionary is used to translate the names in
            ``config.badMaskPlanes`` into a bit mask.
        """
        if self.config.scalingType == "leastSquare":
            self.statsControl = StatisticsControl()
            # Set the mask planes which will be ignored. The initial value 0
            # keeps the reduction well defined when ``badMaskPlanes`` is
            # empty (no pixel is then excluded).
            andMask = reduce(
                ior,
                (exampleMask.getPlaneBitMask(bm) for bm in self.config.badMaskPlanes),
                0,
            )
            self.statsControl.setAndMask(andMask)
            self.statsFlag = stringToStatisticsProperty("SUM")

    def applySkyCorr(self, calexp, skyCorr):
        """Apply correction to the sky background level.

        Sky corrections can be generated via the SkyCorrectionTask (see
        `~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`). Because the
        sky model used by that code extends over the entire focal plane, this
        can produce better sky subtraction.
        The calexp is updated in-place.

        Parameters
        ----------
        calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage`
            Calibrated exposure.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`
            Full focal plane sky correction, obtained by running
            `~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`.
        """
        if isinstance(calexp, Exposure):
            # Operate on the underlying masked image so that the in-place
            # subtraction below modifies the exposure's pixels.
            calexp = calexp.getMaskedImage()
        calexp -= skyCorr.getImage()

    def scaleModel(self, model, star, inPlace=True, nb90Rots=0):
        """Compute scaling factor to be applied to the extended PSF so that its
        amplitude matches that of an individual star.

        Parameters
        ----------
        model : `~lsst.afw.image.MaskedImageF`
            The extended PSF model, shifted (and potentially warped) to match
            the bright star's positioning.
        star : `~lsst.meas.algorithms.brightStarStamps.BrightStarStamp`
            A stamp centered on the bright star to be subtracted.
        inPlace : `bool`
            Whether the model should be scaled in place. Default is `True`.
        nb90Rots : `int`
            The number of 90-degrees rotations to apply to the star stamp.

        Returns
        -------
        scalingFactor : `float`
            The factor by which the model image should be multiplied for it
            to be scaled to the input bright star.
        """
        if self.config.scalingType == "annularFlux":
            scalingFactor = star.annularFlux
        elif self.config.scalingType == "leastSquare":
            if self.statsControl is None:
                self._setUpStatistics(star.stamp_im.mask)
            # Work on a copy so the caller's stamp is never modified.
            starIm = star.stamp_im.clone()
            # Rotate the star postage stamp.
            starIm = rotateImageBy90(starIm, nb90Rots)
            # Reverse the prior star flux normalization ("unnormalize").
            starIm *= star.annularFlux
            # The estimator of the scalingFactor (f) that minimizes (Y-fX)^2
            # is E[XY]/E[XX]. Both products are built on clones of the star
            # image so that they carry the star's mask into the statistics.
            xy = starIm.clone()
            xy.image.array *= model.image.array
            xx = starIm.clone()
            xx.image.array = model.image.array**2
            # Compute the least squares scaling factor.
            xySum = makeStatistics(xy, self.statsFlag, self.statsControl).getValue()
            xxSum = makeStatistics(xx, self.statsFlag, self.statsControl).getValue()
            # Fall back to a unit scaling when the denominator is zero.
            scalingFactor = xySum / xxSum if xxSum else 1
        if inPlace:
            model.image *= scalingFactor
        return scalingFactor

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docstring inherited.
        inputs = butlerQC.get(inputRefs)
        dataId = butlerQC.quantum.dataId
        subtractor, _ = self.run(**inputs, dataId=dataId)
        if self.config.doWriteSubtractedExposure:
            # Subtract the model from a copy of the input exposure.
            outputExposure = inputs["inputExposure"].clone()
            outputExposure.image -= subtractor.image
        else:
            outputExposure = None
        outputBackgroundExposure = subtractor if self.config.doWriteSubtractor else None
        output = Struct(outputExposure=outputExposure, outputBackgroundExposure=outputBackgroundExposure)
        butlerQC.put(output, outputRefs)

    def run(self, inputExposure, inputBrightStarStamps, inputExtendedPsf, dataId, skyCorr=None):
        """Iterate over all bright stars in an exposure to scale the extended
        PSF model before subtracting bright stars.

        Parameters
        ----------
        inputExposure : `~lsst.afw.image.exposure.exposure.ExposureF`
            The image from which bright stars should be subtracted.
        inputBrightStarStamps :
                `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`
            Set of stamps centered on each bright star to be subtracted,
            produced by running
            `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`.
        inputExtendedPsf : `~lsst.pipe.tasks.extended_psf.ExtendedPsf`
            Extended PSF model, produced by
            `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
        dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
            The dataId of the exposure (and detector) bright stars should be
            subtracted from.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
            Full focal plane sky correction, obtained by running
            `~lsst.pipe.drivers.skyCorrection.SkyCorrectionTask`. If
            `doApplySkyCorr` is set to `True`, `skyCorr` should not be
            `None`; if it is, a warning is logged and no sky correction is
            applied.

        Returns
        -------
        subtractorExp : `~lsst.afw.image.ExposureF`
            An Exposure containing a scaled bright star model fit to every
            bright star profile; its image can then be subtracted from the
            input exposure.
        invImages : `list` [`~lsst.afw.image.MaskedImageF`]
            A list of small images ("stamps") containing the model, each scaled
            to its corresponding input bright star.
        """
        inputExpBBox = inputExposure.getBBox()
        if self.config.doApplySkyCorr:
            if skyCorr is not None:
                self.log.info(
                    "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
                )
                self.applySkyCorr(inputExposure, skyCorr)
            else:
                # Make the skipped correction visible instead of silently
                # proceeding on an uncorrected exposure.
                self.log.warning(
                    "doApplySkyCorr is True but no sky correction was provided for %s; "
                    "proceeding without applying a sky correction.",
                    dataId,
                )
        # Create an empty image the size of the exposure.
        # TODO: DM-31085 (set mask planes).
        subtractorExp = ExposureF(bbox=inputExpBBox)
        subtractor = subtractorExp.maskedImage
        # Make a copy of the input model.
        model = inputExtendedPsf(dataId["detector"]).clone()
        modelStampSize = model.getDimensions()
        # Number of quarter turns that undoes the rotation applied to the
        # stamps.
        inv90Rots = 4 - inputBrightStarStamps.nb90Rots % 4
        model = rotateImageBy90(model, inv90Rots)
        warpCont = WarpingControl(self.config.warpingKernelName)
        invImages = []
        # Loop over bright stars, computing the inverse transformed and scaled
        # postage stamp for each.
        for star in inputBrightStarStamps:
            if not star.gaiaGMag < self.config.magLimit:
                # Only stars brighter than the magnitude limit are subtracted.
                continue
            # Set the origin.
            model.setXY0(star.position)
            # Create an empty destination image.
            invTransform = star.archive_element.inverted()
            invOrigin = Point2I(invTransform.applyForward(Point2D(star.position)))
            bbox = Box2I(corner=invOrigin, dimensions=modelStampSize)
            invImage = MaskedImageF(bbox)
            # Apply inverse transform.
            goodPix = warpImage(invImage, model, invTransform, warpCont)
            if not goodPix:
                # Lazy %-style arguments: the message is only formatted when
                # debug logging is enabled.
                self.log.debug(
                    "Warping of a model failed for star %s: no good pixel in output", star.gaiaId
                )
            # Scale the model.
            self.scaleModel(invImage, star, inPlace=True, nb90Rots=inv90Rots)
            # Replace NaNs before subtraction (note all NaN pixels have
            # the NO_DATA flag).
            invImage.image.array[np.isnan(invImage.image.array)] = 0
            # Restrict the stamp to the part overlapping the exposure.
            bbox.clip(inputExpBBox)
            if bbox.getArea() > 0:
                subtractor[bbox] += invImage[bbox]
            invImages.append(invImage)
        return subtractorExp, invImages