Coverage for python/lsst/pipe/tasks/processBrightStars.py: 22%

179 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-05 04:19 -0700

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Extract bright star cutouts; normalize and warp to the same pixel grid.""" 

23 

# Export the configuration and connections classes alongside the task: both
# are part of the public API (pipelines and config overrides refer to them).
__all__ = ["ProcessBrightStarsConnections", "ProcessBrightStarsConfig", "ProcessBrightStarsTask"]

25 

26import astropy.units as u 

27import numpy as np 

28from lsst.afw.cameraGeom import PIXELS, TAN_PIXELS 

29from lsst.afw.detection import FootprintSet, Threshold 

30from lsst.afw.geom.transformFactory import makeIdentityTransform, makeTransform 

31from lsst.afw.image import Exposure, ExposureF, MaskedImageF 

32from lsst.afw.math import ( 

33 StatisticsControl, 

34 WarpingControl, 

35 rotateImageBy90, 

36 stringToStatisticsProperty, 

37 warpImage, 

38) 

39from lsst.geom import AffineTransform, Box2I, Extent2I, Point2D, Point2I, SpherePoint, radians 

40from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader 

41from lsst.meas.algorithms.brightStarStamps import BrightStarStamp, BrightStarStamps 

42from lsst.pex.config import ChoiceField, ConfigField, Field, ListField 

43from lsst.pex.exceptions import InvalidParameterError 

44from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

45from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput 

46from lsst.utils.timer import timeMethod 

47 

48 

class ProcessBrightStarsConnections(PipelineTaskConnections, dimensions=("instrument", "visit", "detector")):
    """Butler connections used by `ProcessBrightStarsTask`.

    The ``skyCorr`` input is removed from the connections when the
    configuration has ``doApplySkyCorr`` disabled.
    """

    # Calibrated exposure the bright star stamps are cut out of.
    inputExposure = Input(
        name="calexp",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
        doc="Input exposure from which to extract bright star stamps",
    )
    # Full focal plane sky correction; only read when doApplySkyCorr is set.
    skyCorr = Input(
        name="skyCorr",
        storageClass="Background",
        dimensions=("instrument", "visit", "detector"),
        doc="Input Sky Correction to be subtracted from the calexp if doApplySkyCorr=True",
    )
    # Reference catalog shards (several per quantum, loaded lazily).
    refCat = PrerequisiteInput(
        name="gaia_dr2_20200414",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        multiple=True,
        deferLoad=True,
        doc="Reference catalog that contains bright star positions",
    )
    # Output collection of processed stamps.
    brightStarStamps = Output(
        name="brightStarStamps",
        storageClass="BrightStarStamps",
        dimensions=("visit", "detector"),
        doc="Set of preprocessed postage stamps, each centered on a single bright star.",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Keep the sky correction input only when it is going to be applied.
        if config.doApplySkyCorr:
            return
        self.inputs.remove("skyCorr")

83 

84 

class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBrightStarsConnections):
    """Configuration parameters for ProcessBrightStarsTask."""

    magLimit = Field(
        dtype=float,
        doc="Magnitude limit, in Gaia G; all stars brighter than this value will be processed.",
        default=18,
    )
    # Exactly two values (x and y sizes) are expected: the task indexes
    # elements 0 and 1 and builds an Extent2I from this list.
    stampSize = ListField(
        dtype=int,
        doc="Size of the stamps to be extracted, in pixels.",
        default=(250, 250),
        length=2,
    )
    modelStampBuffer = Field(
        dtype=float,
        doc=(
            "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be "
            "saved in. This will also be the size of the extended PSF model."
        ),
        default=1.1,
    )
    doRemoveDetected = Field(
        dtype=bool,
        doc="Whether DETECTION footprints, other than that for the central object, should be changed to BAD.",
        default=True,
    )
    doApplyTransform = Field(
        dtype=bool,
        doc="Apply transform to bright star stamps to correct for optical distortions?",
        default=True,
    )
    warpingKernelName = ChoiceField(
        dtype=str,
        doc="Warping kernel",
        default="lanczos5",
        allowed={
            "bilinear": "bilinear interpolation",
            "lanczos3": "Lanczos kernel of order 3",
            "lanczos4": "Lanczos kernel of order 4",
            "lanczos5": "Lanczos kernel of order 5",
        },
    )
    # Exactly two values (inner, outer) are expected: the task unpacks this
    # list into two radii.
    annularFluxRadii = ListField(
        dtype=int,
        doc="Inner and outer radii of the annulus used to compute AnnularFlux for normalization, in pixels.",
        default=(40, 50),
        length=2,
    )
    annularFluxStatistic = ChoiceField(
        dtype=str,
        doc="Type of statistic to use to compute annular flux.",
        default="MEANCLIP",
        allowed={
            "MEAN": "mean",
            "MEDIAN": "median",
            "MEANCLIP": "clipped mean",
        },
    )
    numSigmaClip = Field(
        dtype=float,
        doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.",
        default=4,
    )
    numIter = Field(
        dtype=int,
        doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.",
        default=3,
    )
    badMaskPlanes = ListField(
        dtype=str,
        doc="Mask planes that identify pixels to not include in the computation of the annular flux.",
        default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
    )
    minPixelsWithinFrame = Field(
        dtype=int,
        doc=(
            "Minimum number of pixels that must fall within the stamp boundary for the bright star to be "
            "saved when its center is beyond the exposure boundary."
        ),
        default=50,
    )
    doApplySkyCorr = Field(
        dtype=bool,
        doc="Apply full focal plane sky correction before extracting stars?",
        default=True,
    )
    discardNanFluxStars = Field(
        dtype=bool,
        doc="Should stars with NaN annular flux be discarded?",
        default=False,
    )
    # The loader built from this config is used to find the bright stars
    # to extract (see ProcessBrightStarsTask.extractStamps).
    refObjLoader = ConfigField(
        dtype=LoadReferenceObjectsConfig,
        doc="Reference object loader used to select the bright stars to extract.",
    )

179 

180 

class ProcessBrightStarsTask(PipelineTask):
    """Extract, warp and normalize postage stamps around bright stars.

    The description of the parameters for this Task are detailed in
    :lsst-task:`~lsst.pipe.base.PipelineTask`.

    Parameters
    ----------
    butler : `object`, optional
        Legacy argument. When not `None`, it is passed to ``makeSubtask`` to
        build a ``refObjLoader`` subtask.
    initInputs : `dict`, optional
        Initialization inputs; not used by this task.
    *args
        Additional positional arguments, forwarded to
        `~lsst.pipe.base.PipelineTask`.
    **kwargs
        Additional keyword arguments, forwarded to
        `~lsst.pipe.base.PipelineTask`.

    Notes
    -----
    `ProcessBrightStarsTask` is used to extract, process, and store small
    image cut-outs (or "postage stamps") around bright stars. `run` performs
    three steps in succession:

    `extractStamps`
        Find bright stars within the exposure using a reference catalog and
        extract a stamp centered on each.
    `warpStamps`
        Shift and warp each stamp to remove optical distortions and sample all
        stars on the same pixel grid.
    Normalization
        Compute the flux of each star in an annulus and normalize its stamp by
        it (done by
        `~lsst.meas.algorithms.brightStarStamps.BrightStarStamps.initAndNormalize`).
        This is required because the central pixels of bright stars are
        likely saturated and/or contain ghosts, and cannot be used.
    """

    ConfigClass = ProcessBrightStarsConfig
    _DefaultName = "processBrightStars"

    def __init__(self, butler=None, initInputs=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Size of the (model) stamps the processed stars are saved in: the
        # extraction stamp size scaled by the configured "buffer" factor.
        self.modelStampSize = [
            int(self.config.stampSize[0] * self.config.modelStampBuffer),
            int(self.config.stampSize[1] * self.config.modelStampBuffer),
        ]
        # Force odd sizes so that the model stamps have a central pixel.
        self.modelStampSize = [size if size % 2 else size + 1 for size in self.modelStampSize]
        # Central pixel of the model stamps.
        self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2
        if butler is not None:
            # NOTE(review): ``config.refObjLoader`` is a plain ``ConfigField``
            # rather than a ``ConfigurableField``; confirm ``makeSubtask`` can
            # build a subtask from it on this legacy code path.
            self.makeSubtask("refObjLoader", butler=butler)

    def applySkyCorr(self, calexp, skyCorr):
        """Apply correction to the sky background level.

        Sky corrections can be generated using the ``SkyCorrectionTask``.
        As the sky model generated there extends over the full focal plane,
        this should produce a more optimal sky subtraction solution.

        Parameters
        ----------
        calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage`
            Calibrated exposure.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`
            Full focal plane sky correction from ``SkyCorrectionTask``.

        Notes
        -----
        This method modifies the input ``calexp`` in-place: the image of the
        sky correction is subtracted from its masked image.
        """
        if isinstance(calexp, Exposure):
            calexp = calexp.getMaskedImage()
        calexp -= skyCorr.getImage()

    def extractStamps(self, inputExposure, refObjLoader=None):
        """Read the position of bright stars within an input exposure using a
        refCat and extract them.

        Parameters
        ----------
        inputExposure : `~lsst.afw.image.ExposureF`
            The image from which bright star stamps should be extracted.
        refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
            Loader to find objects within a reference catalog. Defaults to
            ``self.refObjLoader``.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``starIms``
                Postage stamps (`list`).
            ``pixCenters``
                Corresponding coords to each star's center, in pixels (`list`).
            ``GMags``
                Corresponding (Gaia) G magnitudes (`list`).
            ``gaiaIds``
                Corresponding unique Gaia identifiers (`list`).
        """
        if refObjLoader is None:
            # NOTE(review): ``self.refObjLoader`` only exists when the task
            # was constructed with a ``butler``.
            refObjLoader = self.refObjLoader
        starIms = []
        pixCenters = []
        GMags = []
        ids = []
        wcs = inputExposure.getWcs()
        inputIm = inputExposure.maskedImage
        inputExpBBox = inputExposure.getBBox()
        # Select stars within, or close enough to, the input exposure: the
        # search box is the exposure bbox dilated by the stamp size minus
        # ``minPixelsWithinFrame``.
        dilatationExtent = Extent2I(np.array(self.config.stampSize) - self.config.minPixelsWithinFrame)
        # TODO (DM-25894): handle catalog with stars missing from Gaia
        withinCalexp = refObjLoader.loadPixelBox(
            inputExpBBox.dilatedBy(dilatationExtent), wcs, filterName="phot_g_mean"
        )
        refCat = withinCalexp.refCat
        # Keep objects brighter than the magnitude limit, i.e. with a flux
        # (in nJy) larger than the flux at the limit.
        fluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value()
        GFluxes = np.array(refCat["phot_g_mean_flux"])
        bright = GFluxes > fluxLimit
        # Convert the fluxes of the bright stars to AB magnitudes in a single
        # vectorized call.
        allGMags = (GFluxes[bright] * u.nJy).to(u.ABmag).to_value()
        allIds = refCat.columns.extract("id", where=bright)["id"]
        selectedColumns = refCat.columns.extract("coord_ra", "coord_dec", where=bright)
        for j, (ra, dec) in enumerate(zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"])):
            sp = SpherePoint(ra, dec, radians)
            cpix = wcs.skyToPixel(sp)
            try:
                starIm = inputExposure.getCutout(sp, Extent2I(self.config.stampSize))
            except InvalidParameterError:
                # The stamp extends beyond the exposure boundary: compute the
                # bbox the stamp would otherwise have, and its overlap with
                # the exposure.
                bboxCorner = np.array(cpix) - np.array(self.config.stampSize) / 2
                idealBBox = Box2I(Point2I(bboxCorner), Extent2I(self.config.stampSize))
                clippedStarBBox = Box2I(idealBBox)
                clippedStarBBox.clip(inputExpBBox)
                if clippedStarBBox.getArea() <= 0:
                    # No overlap with the exposure at all: skip this star.
                    continue
                # Create a full-sized stamp with all pixels flagged as
                # NO_DATA, then recover the pixels that overlap the exposure.
                starIm = ExposureF(bbox=idealBBox)
                starIm.image[:] = np.nan
                starIm.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA"))
                clippedIm = inputIm.Factory(inputIm, clippedStarBBox)
                starIm.maskedImage[clippedStarBBox] = clippedIm
                # Set detector and WCS, used in warpStamps.
                starIm.setDetector(inputExposure.getDetector())
                starIm.setWcs(inputExposure.getWcs())
            if self.config.doRemoveDetected:
                # Give the DETECTION footprints of all other objects the BAD
                # flag, keeping the footprint(s) containing the star's center.
                detThreshold = Threshold(starIm.mask.getPlaneBitMask("DETECTED"), Threshold.BITMASK)
                omask = FootprintSet(starIm.mask, detThreshold)
                allFootprints = omask.getFootprints()
                centralPixel = Point2I(cpix)
                otherFootprints = [fs for fs in allFootprints if not fs.contains(centralPixel)]
                nbMatchingFootprints = len(allFootprints) - len(otherFootprints)
                if nbMatchingFootprints != 1:
                    self.log.warning(
                        "Failed to uniquely identify central DETECTION footprint for star "
                        "%s; found %d footprints instead.",
                        allIds[j],
                        nbMatchingFootprints,
                    )
                omask.setFootprints(otherFootprints)
                omask.setMask(starIm.mask, "BAD")
            starIms.append(starIm)
            pixCenters.append(cpix)
            GMags.append(allGMags[j])
            ids.append(allIds[j])
        return Struct(starIms=starIms, pixCenters=pixCenters, GMags=GMags, gaiaIds=ids)

    def warpStamps(self, stamps, pixCenters):
        """Warps and shifts all given stamps so they are sampled on the same
        pixel grid and centered on the central pixel. This includes rotating
        the stamp depending on detector orientation.

        Parameters
        ----------
        stamps : `Sequence` [`~lsst.afw.image.ExposureF`]
            Image cutouts centered on a single object. Must not be empty: the
            detector is read from the first stamp.
        pixCenters : `Sequence` [`~lsst.geom.Point2D`]
            Positions of each object's center (from the refCat) in pixels.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``warpedStars``
                Stamps of warped stars.
                (`list` [`~lsst.afw.image.MaskedImage`])
            ``warpTransforms``
                The corresponding Transform from the initial star stamp
                to the common model grid.
                (`list` [`~lsst.afw.geom.TransformPoint2ToPoint2`])
            ``xy0s``
                Coordinates of the bottom-left pixels of each stamp,
                before rotation.
                (`list` [`~lsst.geom.Point2I`])
            ``nb90Rots``
                The number of 90 degrees rotations (in [0, 3]) required to
                compensate for detector orientation.
                (`int`)
        """
        # Warping control; only contains the kernel provided in config.
        warpCont = WarpingControl(self.config.warpingKernelName)
        # Difference between model and star stamp sizes.
        bufferPix = (
            self.modelStampSize[0] - self.config.stampSize[0],
            self.modelStampSize[1] - self.config.stampSize[1],
        )
        # All stars were extracted from an exposure of the same detector.
        det = stamps[0].getDetector()
        # Define correction for optical distortions.
        if self.config.doApplyTransform:
            pixToTan = det.getTransform(PIXELS, TAN_PIXELS)
        else:
            pixToTan = makeIdentityTransform()
        # Number of 90-degree rotations that best compensates for the detector
        # orientation: round the yaw (radians) to the nearest quarter turn and
        # wrap into [0, 3]. Wrapping matters: a yaw of -90 deg is the same
        # orientation as 270 deg (3 rotations), and 350 deg is closest to 0.
        yaw = det.getOrientation().getYaw()
        nb90Rots = int(np.round(float(yaw) / (np.pi / 2))) % 4

        # Apply transformation to each star.
        warpedStars, warpTransforms, xy0s = [], [], []
        for star, cent in zip(stamps, pixCenters):
            # (Re)create empty destination image.
            destImage = MaskedImageF(*self.modelStampSize)
            bottomLeft = Point2D(star.image.getXY0())
            newBottomLeft = pixToTan.applyForward(bottomLeft)
            newBottomLeft.setX(newBottomLeft.getX() - bufferPix[0] / 2)
            newBottomLeft.setY(newBottomLeft.getY() - bufferPix[1] / 2)
            # Convert to int.
            newBottomLeft = Point2I(newBottomLeft)
            # Set origin and save it.
            destImage.setXY0(newBottomLeft)
            xy0s.append(newBottomLeft)

            # Define linear shifting to recenter stamps.
            newCenter = pixToTan.applyForward(cent)  # center of warped star
            shift = (
                self.modelCenter[0] + newBottomLeft[0] - newCenter[0],
                self.modelCenter[1] + newBottomLeft[1] - newCenter[1],
            )
            affineShift = AffineTransform(shift)
            shiftTransform = makeTransform(affineShift)

            # Define full transform (warp and shift), and apply it.
            starWarper = pixToTan.then(shiftTransform)
            goodPix = warpImage(destImage, star.getMaskedImage(), starWarper, warpCont)
            if not goodPix:
                self.log.debug("Warping of a star failed: no good pixel in output")

            # Arbitrarily set origin of shifted star to 0.
            destImage.setXY0(0, 0)

            # Apply rotation if appropriate.
            if nb90Rots:
                destImage = rotateImageBy90(destImage, nb90Rots)
            # destImage is a fresh image on every iteration, so it can be
            # stored directly without copying.
            warpedStars.append(destImage)
            warpTransforms.append(starWarper)
        return Struct(warpedStars=warpedStars, warpTransforms=warpTransforms, xy0s=xy0s, nb90Rots=nb90Rots)

    @timeMethod
    def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None):
        """Identify bright stars within an exposure using a reference catalog,
        extract stamps around each, then preprocess them. The preprocessing
        steps are: shifting, warping and potentially rotating them to the same
        pixel grid; computing their annular flux and normalizing them.

        Parameters
        ----------
        inputExposure : `~lsst.afw.image.ExposureF`
            The image from which bright star stamps should be extracted.
            Modified in-place when ``doApplySkyCorr`` is set.
        refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
            Loader to find objects within a reference catalog.
        dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
            The dataId of the exposure (and detector) bright stars should be
            extracted from; only used in log messages.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
            Full focal plane sky correction obtained by `SkyCorrectionTask`.
            Required when ``doApplySkyCorr`` is set.

        Returns
        -------
        result : `~lsst.pipe.base.Struct` or `None`
            `None` if no suitable bright star was found; otherwise a struct
            with attributes:

            ``brightStarStamps``
                (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)

        Raises
        ------
        ValueError
            Raised if ``doApplySkyCorr`` is set but ``skyCorr`` is `None`.
        """
        if self.config.doApplySkyCorr:
            if skyCorr is None:
                raise ValueError(
                    f"doApplySkyCorr is True but no skyCorr was provided for exposure {dataId}."
                )
            self.log.info(
                "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
            )
            self.applySkyCorr(inputExposure, skyCorr)
        self.log.info("Extracting bright stars from exposure %s", dataId)
        # Extract stamps around bright stars.
        extractedStamps = self.extractStamps(inputExposure, refObjLoader=refObjLoader)
        if not extractedStamps.starIms:
            self.log.info("No suitable bright star found.")
            return None
        # Warp (and shift, and potentially rotate) them.
        self.log.info(
            "Applying warp and/or shift to %i star stamps from exposure %s.",
            len(extractedStamps.starIms),
            dataId,
        )
        warpOutputs = self.warpStamps(extractedStamps.starIms, extractedStamps.pixCenters)
        warpedStars = warpOutputs.warpedStars
        xy0s = warpOutputs.xy0s
        brightStarList = [
            BrightStarStamp(
                stamp_im=warp,
                archive_element=transform,
                position=xy0s[j],
                gaiaGMag=extractedStamps.GMags[j],
                gaiaId=extractedStamps.gaiaIds[j],
            )
            for j, (warp, transform) in enumerate(zip(warpedStars, warpOutputs.warpTransforms))
        ]
        # Compute annularFlux and normalize.
        self.log.info(
            "Computing annular flux and normalizing %i bright stars from exposure %s.",
            len(warpedStars),
            dataId,
        )
        # annularFlux statistic set-up, excluding mask planes.
        statsControl = StatisticsControl()
        statsControl.setNumSigmaClip(self.config.numSigmaClip)
        statsControl.setNumIter(self.config.numIter)
        innerRadius, outerRadius = self.config.annularFluxRadii
        statsFlag = stringToStatisticsProperty(self.config.annularFluxStatistic)
        brightStarStamps = BrightStarStamps.initAndNormalize(
            brightStarList,
            innerRadius=innerRadius,
            outerRadius=outerRadius,
            nb90Rots=warpOutputs.nb90Rots,
            imCenter=self.modelCenter,
            use_archive=True,
            statsControl=statsControl,
            statsFlag=statsFlag,
            badMaskPlanes=self.config.badMaskPlanes,
            discardNanFluxObjects=self.config.discardNanFluxStars,
        )
        return Struct(brightStarStamps=brightStarStamps)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read inputs, build the reference object loader, run the task and
        write its output (nothing is written when `run` returns `None`).
        """
        inputs = butlerQC.get(inputRefs)
        inputs["dataId"] = str(butlerQC.quantum.dataId)
        refObjLoader = ReferenceObjectLoader(
            dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat],
            refCats=inputs.pop("refCat"),
            name=self.config.connections.refCat,
            config=self.config.refObjLoader,
        )
        output = self.run(**inputs, refObjLoader=refObjLoader)
        if output:
            butlerQC.put(output, outputRefs)