Coverage for python/lsst/pipe/tasks/processBrightStars.py: 22%

179 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-23 10:58 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Extract bright star cutouts; normalize and warp to the same pixel grid.""" 

23 

24__all__ = ["ProcessBrightStarsTask"] 

25 

26import astropy.units as u 

27import numpy as np 

28from lsst.afw.cameraGeom import PIXELS, TAN_PIXELS 

29from lsst.afw.detection import FootprintSet, Threshold 

30from lsst.afw.geom.transformFactory import makeIdentityTransform, makeTransform 

31from lsst.afw.image import Exposure, ExposureF, MaskedImageF 

32from lsst.afw.math import ( 

33 StatisticsControl, 

34 WarpingControl, 

35 rotateImageBy90, 

36 stringToStatisticsProperty, 

37 warpImage, 

38) 

39from lsst.geom import AffineTransform, Box2I, Extent2I, Point2D, Point2I, SpherePoint, radians 

40from lsst.meas.algorithms import LoadReferenceObjectsConfig, ReferenceObjectLoader 

41from lsst.meas.algorithms.brightStarStamps import BrightStarStamp, BrightStarStamps 

42from lsst.pex.config import ChoiceField, ConfigField, Field, ListField 

43from lsst.pex.exceptions import InvalidParameterError 

44from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

45from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput 

46from lsst.utils.timer import timeMethod 

47 

48 

class ProcessBrightStarsConnections(PipelineTaskConnections, dimensions=("instrument", "visit", "detector")):
    """Butler connections for `ProcessBrightStarsTask`.

    The ``skyCorr`` input is declared by default but removed at construction
    time when the configuration has ``doApplySkyCorr=False``.
    """

    inputExposure = Input(
        name="calexp",
        storageClass="ExposureF",
        dimensions=("visit", "detector"),
        doc="Input exposure from which to extract bright star stamps",
    )
    skyCorr = Input(
        name="skyCorr",
        storageClass="Background",
        dimensions=("instrument", "visit", "detector"),
        doc="Input Sky Correction to be subtracted from the calexp if doApplySkyCorr=True",
    )
    refCat = PrerequisiteInput(
        name="gaia_dr2_20200414",
        storageClass="SimpleCatalog",
        dimensions=("skypix",),
        multiple=True,
        deferLoad=True,
        doc="Reference catalog that contains bright star positions",
    )
    brightStarStamps = Output(
        name="brightStarStamps",
        storageClass="BrightStarStamps",
        dimensions=("visit", "detector"),
        doc="Set of preprocessed postage stamps, each centered on a single bright star.",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # Keep the sky-correction input only when it will actually be applied.
        if config.doApplySkyCorr:
            return
        self.inputs.remove("skyCorr")

83 

84 

class ProcessBrightStarsConfig(PipelineTaskConfig, pipelineConnections=ProcessBrightStarsConnections):
    """Configuration parameters for ProcessBrightStarsTask."""

    magLimit = Field(
        dtype=float,
        doc="Magnitude limit, in Gaia G; all stars brighter than this value will be processed.",
        default=18,
    )
    # Always interpreted as an (x, y) pair by the task, hence length=2.
    stampSize = ListField(
        dtype=int,
        doc="Size of the stamps to be extracted, in pixels.",
        default=(250, 250),
        length=2,
    )
    modelStampBuffer = Field(
        dtype=float,
        doc=(
            "'Buffer' factor to be applied to determine the size of the stamp the processed stars will be "
            "saved in. This will also be the size of the extended PSF model."
        ),
        default=1.1,
    )
    doRemoveDetected = Field(
        dtype=bool,
        doc="Whether DETECTION footprints, other than that for the central object, should be changed to BAD.",
        default=True,
    )
    doApplyTransform = Field(
        dtype=bool,
        doc="Apply transform to bright star stamps to correct for optical distortions?",
        default=True,
    )
    warpingKernelName = ChoiceField(
        dtype=str,
        doc="Warping kernel",
        default="lanczos5",
        allowed={
            "bilinear": "bilinear interpolation",
            "lanczos3": "Lanczos kernel of order 3",
            "lanczos4": "Lanczos kernel of order 4",
            "lanczos5": "Lanczos kernel of order 5",
        },
    )
    # Always unpacked as (inner, outer) by the task, hence length=2.
    annularFluxRadii = ListField(
        dtype=int,
        doc="Inner and outer radii of the annulus used to compute AnnularFlux for normalization, in pixels.",
        default=(70, 80),
        length=2,
    )
    annularFluxStatistic = ChoiceField(
        dtype=str,
        doc="Type of statistic to use to compute annular flux.",
        default="MEANCLIP",
        allowed={
            "MEAN": "mean",
            "MEDIAN": "median",
            "MEANCLIP": "clipped mean",
        },
    )
    numSigmaClip = Field(
        dtype=float,
        doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.",
        default=4,
    )
    numIter = Field(
        dtype=int,
        doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.",
        default=3,
    )
    badMaskPlanes = ListField(
        dtype=str,
        doc="Mask planes that identify pixels to not include in the computation of the annular flux.",
        default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
    )
    minValidAnnulusFraction = Field(
        dtype=float,
        doc="Minimum fraction of valid pixels that must fall within the annulus for the bright star to be "
        "saved for subsequent generation of a PSF.",
        default=0.0,
    )
    doApplySkyCorr = Field(
        dtype=bool,
        doc="Apply full focal plane sky correction before extracting stars?",
        default=True,
    )
    discardNanFluxStars = Field(
        dtype=bool,
        doc="Should stars with NaN annular flux be discarded?",
        default=False,
    )
    refObjLoader = ConfigField(
        dtype=LoadReferenceObjectsConfig,
        doc="Reference object loader configuration, used to select bright stars from the reference catalog.",
    )

    def validate(self):
        """Validate cross-field and range constraints.

        Raises
        ------
        ValueError
            Raised if the annulus radii are not ``0 <= inner < outer``, or if
            ``minValidAnnulusFraction`` is outside ``[0, 1]``.
        """
        # Field-level checks (including the length=2 constraints) run first.
        super().validate()
        innerRadius, outerRadius = self.annularFluxRadii
        if not 0 <= innerRadius < outerRadius:
            raise ValueError(
                f"annularFluxRadii must satisfy 0 <= inner < outer; got {list(self.annularFluxRadii)}."
            )
        if not 0.0 <= self.minValidAnnulusFraction <= 1.0:
            raise ValueError(
                f"minValidAnnulusFraction must be within [0, 1]; got {self.minValidAnnulusFraction}."
            )

177 

178 

class ProcessBrightStarsTask(PipelineTask):
    """Extract, warp and normalize postage stamps around bright stars.

    The description of the parameters for this Task are detailed in
    :lsst-task:`~lsst.pipe.base.PipelineTask`.

    Parameters
    ----------
    butler : `object`, optional
        Ignored; accepted only for backward compatibility. ``refObjLoader``
        is a plain config field rather than a subtask, so a reference object
        loader must be passed explicitly to `run` / `extractStamps`.
    initInputs : `dict`, optional
        Ignored; accepted for compatibility with the `PipelineTask` API.
    *args
        Additional positional arguments forwarded to `PipelineTask`.
    **kwargs
        Additional keyword arguments forwarded to `PipelineTask`.

    Notes
    -----
    `ProcessBrightStarsTask` is used to extract, process, and store small
    image cut-outs (or "postage stamps") around bright stars. It relies on
    three methods, called in succession:

    `extractStamps`
        Find bright stars within the exposure using a reference catalog and
        extract a stamp centered on each.
    `warpStamps`
        Shift and warp each stamp to remove optical distortions and sample all
        stars on the same pixel grid.
    `measureAndNormalize`
        Compute the flux of an object in an annulus and normalize it. This is
        required to normalize each bright star stamp as their central pixels
        are likely saturated and/or contain ghosts, and cannot be used.
    """

    ConfigClass = ProcessBrightStarsConfig
    _DefaultName = "processBrightStars"

    def __init__(self, butler=None, initInputs=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Model stamps are the extracted stamp size scaled by the "buffer"
        # factor, forced to odd dimensions so that they have a central pixel.
        scaledSizes = [int(self.config.stampSize[i] * self.config.modelStampBuffer) for i in (0, 1)]
        self.modelStampSize = [size if size % 2 else size + 1 for size in scaledSizes]
        # Central pixel of the model stamp.
        self.modelCenter = self.modelStampSize[0] // 2, self.modelStampSize[1] // 2
        # Note: ``butler`` is intentionally not used to build a reference
        # loader subtask; ``config.refObjLoader`` is a ``ConfigField``, not a
        # ``ConfigurableField``, so ``makeSubtask`` cannot be applied to it.

    def applySkyCorr(self, calexp, skyCorr):
        """Apply correction to the sky background level.

        Sky corrections can be generated using the ``SkyCorrectionTask``.
        As the sky model generated there extends over the full focal plane,
        this should produce a more optimal sky subtraction solution.

        Parameters
        ----------
        calexp : `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage`
            Calibrated exposure.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`
            Full focal plane sky correction from ``SkyCorrectionTask``.

        Raises
        ------
        ValueError
            Raised if ``skyCorr`` is `None`.

        Notes
        -----
        This method modifies the input ``calexp`` in-place.
        """
        if skyCorr is None:
            raise ValueError("No sky correction (skyCorr) was provided to subtract from the exposure.")
        maskedImage = calexp.getMaskedImage() if isinstance(calexp, Exposure) else calexp
        # In-place subtraction: modifies the pixels of ``calexp``.
        maskedImage -= skyCorr.getImage()

    def extractStamps(self, inputExposure, refObjLoader=None):
        """Read the position of bright stars within an input exposure using a
        refCat and extract them.

        Parameters
        ----------
        inputExposure : `~lsst.afw.image.ExposureF`
            The image from which bright star stamps should be extracted.
        refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
            Loader to find objects within a reference catalog. If `None`, a
            ``refObjLoader`` attribute on the task is used if one exists.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``starIms``
                Postage stamps (`list`).
            ``pixCenters``
                Corresponding coords to each star's center, in pixels (`list`).
            ``GMags``
                Corresponding (Gaia) G magnitudes (`list`).
            ``gaiaIds``
                Corresponding unique Gaia identifiers (`list`).

        Raises
        ------
        RuntimeError
            Raised if no reference object loader is available.
        """
        if refObjLoader is None:
            refObjLoader = getattr(self, "refObjLoader", None)
        if refObjLoader is None:
            raise RuntimeError("A reference object loader (refObjLoader) is required to extract bright stars.")
        wcs = inputExposure.getWcs()
        inputExpBBox = inputExposure.getBBox()
        stampExtent = Extent2I(self.config.stampSize)
        # Dilate the search box by half a stamp so that stars outside the
        # exposure whose stamps still overlap it are also selected.
        dilatationExtent = Extent2I(np.array(self.config.stampSize) // 2)
        # TODO (DM-25894): handle catalog with stars missing from Gaia
        withinCalexp = refObjLoader.loadPixelBox(
            inputExpBBox.dilatedBy(dilatationExtent),
            wcs,
            filterName="phot_g_mean",
        )
        refCat = withinCalexp.refCat
        # Keep bright objects: compare in flux (nJy) space, as stored in the refcat.
        fluxLimit = ((self.config.magLimit * u.ABmag).to(u.nJy)).to_value()
        GFluxes = np.array(refCat["phot_g_mean_flux"])
        bright = GFluxes > fluxLimit
        # Convert all selected fluxes to AB magnitudes in one vectorized call.
        allGMags = ((GFluxes[bright] * u.nJy).to(u.ABmag)).to_value()
        allIds = refCat.columns.extract("id", where=bright)["id"]
        selectedColumns = refCat.columns.extract("coord_ra", "coord_dec", where=bright)

        starIms, pixCenters, GMags, ids = [], [], [], []
        for j, (ra, dec) in enumerate(zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"])):
            sp = SpherePoint(ra, dec, radians)
            cpix = wcs.skyToPixel(sp)
            try:
                starIm = inputExposure.getCutout(sp, stampExtent)
            except InvalidParameterError:
                # The stamp extends beyond the exposure boundary.
                starIm = self._makePartialStamp(inputExposure, cpix, inputExpBBox)
                if starIm is None:
                    continue
            if self.config.doRemoveDetected:
                self._maskOtherDetections(starIm, cpix, allIds[j])
            starIms.append(starIm)
            pixCenters.append(cpix)
            GMags.append(allGMags[j])
            ids.append(allIds[j])
        return Struct(starIms=starIms, pixCenters=pixCenters, GMags=GMags, gaiaIds=ids)

    def _makePartialStamp(self, inputExposure, cpix, inputExpBBox):
        """Build a full-size stamp for a star whose stamp crosses the exposure edge.

        Pixels outside the exposure are set to NaN and flagged NO_DATA; the
        overlapping pixels are copied from ``inputExposure``. Returns `None`
        if the ideal stamp does not overlap the exposure at all.
        """
        # Compute the bbox the stamp would have if it were fully contained.
        bboxCorner = np.array(cpix) - np.array(self.config.stampSize) / 2
        idealBBox = Box2I(Point2I(bboxCorner), Extent2I(self.config.stampSize))
        clippedStarBBox = Box2I(idealBBox)
        clippedStarBBox.clip(inputExpBBox)
        if clippedStarBBox.getArea() <= 0:
            return None
        # Full-sized stamp with every pixel flagged as NO_DATA.
        starIm = ExposureF(bbox=idealBBox)
        starIm.image[:] = np.nan
        starIm.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA"))
        # Recover pixels from the intersection with the exposure.
        inputIm = inputExposure.maskedImage
        starIm.maskedImage[clippedStarBBox] = inputIm.Factory(inputIm, clippedStarBBox)
        # Detector and WCS are needed later by warpStamps.
        starIm.setDetector(inputExposure.getDetector())
        starIm.setWcs(inputExposure.getWcs())
        return starIm

    def _maskOtherDetections(self, starIm, cpix, starId):
        """Flag as BAD the DETECTED footprints that do not contain the star center."""
        detThreshold = Threshold(starIm.mask.getPlaneBitMask("DETECTED"), Threshold.BITMASK)
        omask = FootprintSet(starIm.mask, detThreshold)
        allFootprints = omask.getFootprints()
        centerPix = Point2I(cpix)
        otherFootprints = [fs for fs in allFootprints if not fs.contains(centerPix)]
        nbMatchingFootprints = len(allFootprints) - len(otherFootprints)
        if nbMatchingFootprints != 1:
            self.log.warning(
                "Failed to uniquely identify central DETECTION footprint for star "
                "%s; found %d footprints instead.",
                starId,
                nbMatchingFootprints,
            )
        omask.setFootprints(otherFootprints)
        omask.setMask(starIm.mask, "BAD")

    def warpStamps(self, stamps, pixCenters):
        """Warps and shifts all given stamps so they are sampled on the same
        pixel grid and centered on the central pixel. This includes rotating
        the stamp depending on detector orientation.

        Parameters
        ----------
        stamps : `Sequence` [`~lsst.afw.image.ExposureF`]
            Image cutouts centered on a single object. All are assumed to come
            from the same detector (the first stamp's detector is used).
        pixCenters : `Sequence` [`~lsst.geom.Point2D`]
            Positions of each object's center (from the refCat) in pixels.

        Returns
        -------
        result : `~lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``warpedStars``
                Stamps of warped stars.
                (`list` [`~lsst.afw.image.MaskedImage`])
            ``warpTransforms``
                The corresponding Transform from the initial star stamp
                to the common model grid.
                (`list` [`~lsst.afw.geom.TransformPoint2ToPoint2`])
            ``xy0s``
                Coordinates of the bottom-left pixels of each stamp,
                before rotation.
                (`list` [`~lsst.geom.Point2I`])
            ``nb90Rots``
                The number of 90 degrees rotations required to compensate for
                detector orientation (0 if ``stamps`` is empty).
                (`int`)
        """
        if not stamps:
            return Struct(warpedStars=[], warpTransforms=[], xy0s=[], nb90Rots=0)
        # Warping control; only contains the kernel provided in config.
        warpCont = WarpingControl(self.config.warpingKernelName)
        # Extra pixels of the model stamp relative to the extracted stamp.
        bufferPix = (
            self.modelStampSize[0] - self.config.stampSize[0],
            self.modelStampSize[1] - self.config.stampSize[1],
        )
        det = stamps[0].getDetector()
        # Correction for optical distortions.
        if self.config.doApplyTransform:
            pixToTan = det.getTransform(PIXELS, TAN_PIXELS)
        else:
            pixToTan = makeIdentityTransform()
        # Number of 90-degree rotations closest to the detector yaw. The yaw
        # is wrapped into [0, 2*pi) first, so a yaw of -90 deg gives 3
        # rotations and a yaw just below 360 deg gives 0.
        yaw = np.mod(float(det.getOrientation().getYaw()), 2 * np.pi)
        nb90Rots = int(round(yaw / (np.pi / 2))) % 4

        warpedStars, warpTransforms, xy0s = [], [], []
        for star, cent in zip(stamps, pixCenters):
            destImage = MaskedImageF(*self.modelStampSize)
            # Destination origin: warped bottom-left corner of the stamp,
            # offset by half the buffer on each axis, converted to int.
            newBottomLeft = pixToTan.applyForward(Point2D(star.image.getXY0()))
            newBottomLeft.setX(newBottomLeft.getX() - bufferPix[0] / 2)
            newBottomLeft.setY(newBottomLeft.getY() - bufferPix[1] / 2)
            newBottomLeft = Point2I(newBottomLeft)
            destImage.setXY0(newBottomLeft)
            xy0s.append(newBottomLeft)

            # Linear shift that moves the warped star center onto the model center.
            newCenter = pixToTan.applyForward(cent)
            shift = (
                self.modelCenter[0] + newBottomLeft[0] - newCenter[0],
                self.modelCenter[1] + newBottomLeft[1] - newCenter[1],
            )
            shiftTransform = makeTransform(AffineTransform(shift))

            # Full transform: distortion correction followed by the shift.
            starWarper = pixToTan.then(shiftTransform)
            goodPix = warpImage(destImage, star.getMaskedImage(), starWarper, warpCont)
            if not goodPix:
                self.log.debug("Warping of a star failed: no good pixel in output")

            # Arbitrarily set origin of shifted star to 0.
            destImage.setXY0(0, 0)
            if nb90Rots:
                destImage = rotateImageBy90(destImage, nb90Rots)
            # destImage is freshly allocated each iteration, so no copy is needed.
            warpedStars.append(destImage)
            warpTransforms.append(starWarper)
        return Struct(warpedStars=warpedStars, warpTransforms=warpTransforms, xy0s=xy0s, nb90Rots=nb90Rots)

    @timeMethod
    def run(self, inputExposure, refObjLoader=None, dataId=None, skyCorr=None):
        """Identify bright stars within an exposure using a reference catalog,
        extract stamps around each, then preprocess them. The preprocessing
        steps are: shifting, warping and potentially rotating them to the same
        pixel grid; computing their annular flux and normalizing them.

        Parameters
        ----------
        inputExposure : `~lsst.afw.image.ExposureF`
            The image from which bright star stamps should be extracted.
            Modified in place if ``doApplySkyCorr`` is set.
        refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
            Loader to find objects within a reference catalog.
        dataId : `str`, `dict` or `~lsst.daf.butler.DataCoordinate`, optional
            The dataId of the exposure (and detector) bright stars should be
            extracted from; used only in log messages.
        skyCorr : `~lsst.afw.math.backgroundList.BackgroundList`, optional
            Full focal plane sky correction obtained by `SkyCorrectionTask`.
            Required if ``doApplySkyCorr`` is set.

        Returns
        -------
        result : `~lsst.pipe.base.Struct` or `None`
            `None` if no suitable bright star was found; otherwise a struct
            with attributes:

            ``brightStarStamps``
                (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
        """
        if self.config.doApplySkyCorr:
            self.log.info(
                "Applying sky correction to exposure %s (exposure will be modified in-place).", dataId
            )
            self.applySkyCorr(inputExposure, skyCorr)
        self.log.info("Extracting bright stars from exposure %s", dataId)
        extractedStamps = self.extractStamps(inputExposure, refObjLoader=refObjLoader)
        if not extractedStamps.starIms:
            self.log.info("No suitable bright star found.")
            return None
        # Warp (and shift, and potentially rotate) them.
        self.log.info(
            "Applying warp and/or shift to %i star stamps from exposure %s.",
            len(extractedStamps.starIms),
            dataId,
        )
        warpOutputs = self.warpStamps(extractedStamps.starIms, extractedStamps.pixCenters)
        warpedStars = warpOutputs.warpedStars
        brightStarList = [
            BrightStarStamp(
                stamp_im=warp,
                archive_element=transform,
                position=xy0,
                gaiaGMag=gMag,
                gaiaId=gaiaId,
                minValidAnnulusFraction=self.config.minValidAnnulusFraction,
            )
            for warp, transform, xy0, gMag, gaiaId in zip(
                warpedStars,
                warpOutputs.warpTransforms,
                warpOutputs.xy0s,
                extractedStamps.GMags,
                extractedStamps.gaiaIds,
            )
        ]
        # Compute annularFlux and normalize.
        self.log.info(
            "Computing annular flux and normalizing %i bright stars from exposure %s.",
            len(warpedStars),
            dataId,
        )
        # annularFlux statistic set-up, excluding mask planes.
        statsControl = StatisticsControl()
        statsControl.setNumSigmaClip(self.config.numSigmaClip)
        statsControl.setNumIter(self.config.numIter)
        innerRadius, outerRadius = self.config.annularFluxRadii
        statsFlag = stringToStatisticsProperty(self.config.annularFluxStatistic)
        brightStarStamps = BrightStarStamps.initAndNormalize(
            brightStarList,
            innerRadius=innerRadius,
            outerRadius=outerRadius,
            nb90Rots=warpOutputs.nb90Rots,
            imCenter=self.modelCenter,
            use_archive=True,
            statsControl=statsControl,
            statsFlag=statsFlag,
            badMaskPlanes=self.config.badMaskPlanes,
            discardNanFluxObjects=self.config.discardNanFluxStars,
        )
        return Struct(brightStarStamps=brightStarStamps)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read inputs, build the reference loader, run, and write outputs.

        Nothing is written when `run` finds no suitable bright star.
        """
        inputs = butlerQC.get(inputRefs)
        inputs["dataId"] = str(butlerQC.quantum.dataId)
        refObjLoader = ReferenceObjectLoader(
            dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat],
            refCats=inputs.pop("refCat"),
            name=self.config.connections.refCat,
            config=self.config.refObjLoader,
        )
        output = self.run(**inputs, refObjLoader=refObjLoader)
        if output is not None:
            butlerQC.put(output, outputRefs)