Coverage for python/lsst/pipe/tasks/skyCorrection.py: 18%

172 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-27 02:40 -0700

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["SkyCorrectionTask", "SkyCorrectionConfig"] 

23 

24import numpy as np 

25 

26import lsst.afw.math as afwMath 

27import lsst.afw.image as afwImage 

28import lsst.pipe.base as pipeBase 

29 

30from lsst.afw.cameraGeom.utils import makeImageFromCamera 

31from lsst.daf.butler import DimensionGraph 

32from lsst.pex.config import Config, Field, ConfigurableField, ConfigField 

33import lsst.pipe.base.connectionTypes as cT 

34 

35from .background import (SkyMeasurementTask, FocalPlaneBackground, 

36 FocalPlaneBackgroundConfig, MaskObjectsTask) 

37 

38 

def reorderAndPadList(inputList, inputKeys, outputKeys, padWith=None):
    """Match the order of one list to another, padding if necessary.

    Parameters
    ----------
    inputList : `list`
        List to be reordered and padded. Elements can be any type.
    inputKeys : iterable
        Iterable of values to be compared with ``outputKeys``.
        Length must match ``inputList``. Any iterable is accepted,
        including one-shot iterators; it is copied into a list once.
    outputKeys : iterable
        Iterable of values to be compared with ``inputKeys``.
    padWith :
        Any value to be inserted where an output key is not found in
        ``inputKeys``.

    Returns
    -------
    outputList : `list`
        Copy of ``inputList`` reordered per ``outputKeys`` and padded with
        ``padWith`` so that the length matches the length of ``outputKeys``.
        If a key occurs more than once in ``inputKeys`` the element paired
        with its first occurrence is used.
    """
    # Materialise the keys: membership is tested and ``index`` is called
    # once per output key, which a generator or iterator cannot support.
    keys = list(inputKeys)
    return [
        inputList[keys.index(key)] if key in keys else padWith
        for key in outputKeys
    ]

67 

68 

def _makeCameraImage(camera, exposures, binning):
    """Assemble an image of an entire focal plane.

    Parameters
    ----------
    camera : `lsst.afw.cameraGeom.Camera`
        Camera description.
    exposures : `dict` mapping detector ID to `lsst.afw.image.Exposure`
        CCD exposures, binned by `binning`.
    binning : `int`
        Binning size that has been applied to images.

    Returns
    -------
    image
        Whatever `makeImageFromCamera` returns when called with an
        `lsst.afw.image.ImageF` image factory.
    """
    class ImageSource:
        """Supplier of per-detector images for makeImageFromCamera"""

        def __init__(self, exposures):
            """Constructor

            Parameters
            ----------
            exposures : `dict` mapping detector ID to `lsst.afw.image.Exposure`
                CCD exposures, already binned.
            """
            self.exposures = exposures
            self.background = np.nan  # fill value for missing/NO_DATA pixels
            self.isTrimmed = True

        def getCcdImage(self, detector, imageFactory, binSize):
            """Provide image of CCD to makeImageFromCamera"""
            detId = detector.getId()
            if detId in self.exposures:
                image = self.exposures[detId]
            else:
                # No data for this detector: build a blank image of the
                # binned detector size filled with the background value.
                dims = detector.getBBox().getDimensions()/binSize
                width, height = (int(xx) for xx in dims)
                image = imageFactory(width, height)
                image.set(self.background)

            # Progressively unwrap Exposure -> MaskedImage -> Image,
            # blanking NO_DATA pixels on a copy while a mask is available.
            if hasattr(image, "getMaskedImage"):
                image = image.getMaskedImage()
            if hasattr(image, "getMask"):
                mask = image.getMask()
                maskArray = mask.getArray()
                noDataBit = mask.getPlaneBitMask("NO_DATA")
                isBad = (maskArray & noDataBit) > 0
                image = image.clone()
                image.getImage().getArray()[isBad] = self.background
            if hasattr(image, "getImage"):
                image = image.getImage()

            nQuarter = detector.getOrientation().getNQuarter()
            rotated = afwMath.rotateImageBy90(image, nQuarter)
            return rotated, detector

    return makeImageFromCamera(
        camera,
        imageSource=ImageSource(exposures),
        imageFactory=afwImage.ImageF,
        binSize=binning,
    )

125 

126 

def makeCameraImage(camera, exposures, filename=None, binning=8):
    """Build a focal-plane image and optionally write it to disk.

    Parameters
    ----------
    camera : `lsst.afw.cameraGeom.Camera`
        Camera description.
    exposures : iterable of `tuple` of `int` and `lsst.afw.image.Exposure`
        Pairs of detector ID and CCD exposure (binned by `binning`).
        Entries that are `None` are skipped.
    filename : `str`, optional
        Output filename. Nothing is written when this is `None`.
    binning : `int`
        Binning size that has been applied to images.

    Returns
    -------
    image
        The focal-plane image produced by `_makeCameraImage`.
    """
    detectorImages = dict(pair for pair in exposures if pair is not None)
    image = _makeCameraImage(camera, detectorImages, binning)
    if filename is None:
        return image
    image.writeFits(filename)
    return image

145 

146 

def _skyLookup(datasetType, registry, quantumDataId, collections):
    """Lookup function to identify sky frames.

    Parameters
    ----------
    datasetType : `lsst.daf.butler.DatasetType`
        Dataset to lookup.
    registry : `lsst.daf.butler.Registry`
        Butler registry to query.
    quantumDataId : `lsst.daf.butler.DataCoordinate`
        Data id to transform to find sky frames.
        Only the ``instrument`` and ``visit`` dimensions are kept.
    collections : `lsst.daf.butler.CollectionSearch`
        Collections to search through.

    Returns
    -------
    results : `list` [`lsst.daf.butler.DatasetRef`]
        One `findDataset` result per (visit, detector) data ID of the visit,
        to be used as sky calibration frames.
    """
    visitDimensions = DimensionGraph(registry.dimensions, names=["instrument", "visit"])
    visitDataId = quantumDataId.subset(visitDimensions)
    detectorDataIds = registry.queryDataIds(["visit", "detector"], dataId=visitDataId).expanded()
    return [
        registry.findDataset(datasetType, dataId, collections=collections, timespan=dataId.timespan)
        for dataId in detectorDataIds
    ]

175 

176 

class SkyCorrectionConnections(pipeBase.PipelineTaskConnections, dimensions=("instrument", "visit")):
    """Butler connections used by the sky correction task (one quantum per visit)."""

    # --- Regular inputs -----------------------------------------------------
    rawLinker = cT.Input(
        name="raw",
        doc="Raw data to provide exp-visit linkage to connect calExp inputs to camera/sky calibs.",
        storageClass="Exposure",
        dimensions=["instrument", "exposure", "detector"],
        multiple=True,
        deferLoad=True,
    )
    calExpArray = cT.Input(
        name="calexp",
        doc="Input exposures to process",
        storageClass="ExposureF",
        dimensions=["instrument", "visit", "detector"],
        multiple=True,
    )
    calBkgArray = cT.Input(
        name="calexpBackground",
        doc="Input background files to use",
        storageClass="Background",
        dimensions=["instrument", "visit", "detector"],
        multiple=True,
    )

    # --- Calibration (prerequisite) inputs ----------------------------------
    camera = cT.PrerequisiteInput(
        name="camera",
        doc="Input camera to use.",
        storageClass="Camera",
        dimensions=["instrument"],
        isCalibration=True,
    )
    skyCalibs = cT.PrerequisiteInput(
        name="sky",
        doc="Input sky calibrations to use.",
        storageClass="ExposureF",
        dimensions=["instrument", "physical_filter", "detector"],
        multiple=True,
        isCalibration=True,
        lookupFunction=_skyLookup,
    )

    # --- Outputs ------------------------------------------------------------
    calExpCamera = cT.Output(
        name="calexp_camera",
        doc="Output camera image.",
        storageClass="ImageF",
        dimensions=["instrument", "visit"],
    )
    skyCorr = cT.Output(
        name="skyCorr",
        doc="Output sky corrected images.",
        storageClass="Background",
        dimensions=["instrument", "visit", "detector"],
        multiple=True,
    )

229 

230 

class SkyCorrectionConfig(pipeBase.PipelineTaskConfig, pipelineConnections=SkyCorrectionConnections):
    """Configuration for SkyCorrectionTask"""

    # Sub-configurations and subtasks.
    bgModel = ConfigField(
        dtype=FocalPlaneBackgroundConfig,
        doc="Background model",
    )
    bgModel2 = ConfigField(
        dtype=FocalPlaneBackgroundConfig,
        doc="2nd Background model",
    )
    sky = ConfigurableField(
        target=SkyMeasurementTask,
        doc="Sky measurement",
    )
    maskObjects = ConfigurableField(
        target=MaskObjectsTask,
        doc="Mask Objects",
    )

    # Switches for the individual processing stages.
    doMaskObjects = Field(
        dtype=bool,
        default=True,
        doc="Mask objects to find good sky?",
    )
    doBgModel = Field(
        dtype=bool,
        default=True,
        doc="Do background model subtraction?",
    )
    doBgModel2 = Field(
        dtype=bool,
        default=True,
        doc="Do cleanup background model subtraction?",
    )
    doSky = Field(
        dtype=bool,
        default=True,
        doc="Do sky frame subtraction?",
    )

    # Miscellaneous settings.
    binning = Field(
        dtype=int,
        default=8,
        doc="Binning factor for constructing focal-plane images",
    )
    calexpType = Field(
        dtype=str,
        default="calexp",
        doc="Should be set to fakes_calexp if you want to process calexps with fakes in.",
    )

    def setDefaults(self):
        """Apply base defaults, then override the second background model."""
        Config.setDefaults(self)
        self.bgModel2.xSize = 256
        self.bgModel2.ySize = 256
        self.bgModel2.minFrac = 0.5
        self.bgModel2.doSmooth = True
        self.bgModel2.smoothScale = 1.0

252 

253 

class SkyCorrectionTask(pipeBase.PipelineTask):
    """Correct sky over entire focal plane"""
    ConfigClass = SkyCorrectionConfig
    _DefaultName = "skyCorr"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch the inputs for one quantum, run the task, store the outputs.

        Parameters
        ----------
        butlerQC
            Quantum context used to ``get`` the inputs and ``put`` the outputs.
        inputRefs
            Input dataset references; ``calExpArray``, ``skyCalibs`` and
            ``calBkgArray`` are replaced with detector-sorted lists.
        outputRefs
            Output dataset references; ``skyCorr`` is replaced with a
            detector-sorted list.
        """
        # Reorder the skyCalibs, calBkgArray, and calExpArray inputRefs and the
        # skyCorr outputRef sorted by detector id to ensure reproducibility.
        detectorOrder = sorted(ref.dataId['detector'] for ref in inputRefs.calExpArray)

        def sortByDetector(refs):
            # Entries with no ref for a detector in detectorOrder become None.
            return reorderAndPadList(refs, [ref.dataId['detector'] for ref in refs], detectorOrder)

        inputRefs.calExpArray = sortByDetector(inputRefs.calExpArray)
        inputRefs.skyCalibs = sortByDetector(inputRefs.skyCalibs)
        inputRefs.calBkgArray = sortByDetector(inputRefs.calBkgArray)
        outputRefs.skyCorr = sortByDetector(outputRefs.skyCorr)

        inputs = butlerQC.get(inputRefs)
        # rawLinker is only a connection-graph input; run() does not accept it.
        inputs.pop("rawLinker", None)
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def __init__(self, *args, **kwargs):
        # NOTE(review): positional arguments are accepted but not forwarded to
        # the base class; kept as-is to preserve existing behaviour.
        super().__init__(**kwargs)

        self.makeSubtask("sky")
        self.makeSubtask("maskObjects")

    def focalPlaneBackgroundRun(self, camera, cacheExposures, idList, config):
        """Perform full focal-plane background subtraction

        A per-detector copy of an empty focal-plane model is filled from each
        exposure, the copies are merged into a single model, and that model is
        then subtracted (in place) from every exposure.

        Parameters
        ----------
        camera : `lsst.afw.cameraGeom.Camera`
            Camera description.
        cacheExposures : `list` of `lsst.afw.image.Exposures`
            List of loaded and processed input calExp.
        idList : `list` of `int`
            List of detector ids to iterate over. Only its length matters
            here: exposures beyond ``len(idList)`` do not contribute to the
            model.
        config : `FocalPlaneBackgroundConfig`
            Configuration to use for background subtraction.

        Returns
        -------
        exposures : `list`
            Binned masked images (binned by ``config.binning`` of this task),
            for creating the focal plane image.
        newCacheBgList : `list`
            For each exposure, the first element of the detector-level
            background list generated from the model.
        cacheBgModel : `list`
            For each exposure, the detector-level background list realized
            from the full focal plane background model.
        """
        bgModel = FocalPlaneBackground.fromCamera(config, camera)

        # Accumulate each detector into its own copy of the (still empty)
        # model; zip() keeps the pairing with idList and its truncation.
        bgModelList = []
        for _detId, cacheExp in zip(idList, cacheExposures):
            ccdModel = bgModel.clone()
            ccdModel.addCcd(cacheExp)
            bgModelList.append(ccdModel)

        for ii, bg in enumerate(bgModelList):
            self.log.info("Background %d: %d pixels", ii, bg._numbers.getArray().sum())
            bgModel.merge(bg)

        exposures = []
        newCacheBgList = []
        cacheBgModel = []
        for cacheExp in cacheExposures:
            nodeExp, nodeBgModel, nodeBgList = self.subtractModelRun(cacheExp, bgModel)
            exposures.append(afwMath.binImage(nodeExp.getMaskedImage(), self.config.binning))
            cacheBgModel.append(nodeBgModel)
            newCacheBgList.append(nodeBgList)

        return exposures, newCacheBgList, cacheBgModel

    def run(self, calExpArray, calBkgArray, skyCalibs, camera):
        """Perform sky correction on an exposure.

        Parameters
        ----------
        calExpArray : `list` of `lsst.afw.image.Exposure`
            Array of detector input calExp images for the exposure to
            process. These are modified in place.
        calBkgArray : `list` of `lsst.afw.math.BackgroundList`
            Array of detector input background lists matching the
            calExps to process.
        skyCalibs : `list` of `lsst.afw.image.Exposure`
            Array of SKY calibrations for the input detectors to be
            processed.
        camera : `lsst.afw.cameraGeom.Camera`
            Camera matching the input data to process.

        Returns
        -------
        results : `pipeBase.Struct` containing
            calExpCamera
                Full camera image of the sky-corrected data, as returned by
                `makeCameraImage`.
            skyCorr : `list` of `lsst.afw.math.BackgroundList`
                Detector-level sky-corrected background lists.
        """
        # This is a serial implementation of processing that was originally
        # written around runDataRef() and an MPI pool (pipetask processing
        # does not support MPI processing as of 2019-05-03). Variable names
        # follow that implementation: names prefixed "cache" correspond to
        # data that was stored in the pool cache, and names prefixed "node"
        # to data that would be local to an MPI processing client.
        idList = [exp.getDetector().getId() for exp in calExpArray]

        # Restore the original backgrounds and build the per-detector lists.
        cacheExposures = []
        cacheBgList = []
        exposures = []
        for calExp, calBgModel in zip(calExpArray, calBkgArray):
            nodeExp, nodeBgList = self.loadImageRun(calExp, calBgModel)
            cacheExposures.append(nodeExp)
            cacheBgList.append(nodeBgList)
            exposures.append(afwMath.binImage(nodeExp.getMaskedImage(), self.config.binning))

        if self.config.doBgModel:
            # Generate focal plane background, updating backgrounds in the "cache".
            exposures, newCacheBgList, _ = self.focalPlaneBackgroundRun(
                camera, cacheExposures, idList, self.config.bgModel
            )
            for cacheBg, newBg in zip(cacheBgList, newCacheBgList):
                cacheBg.append(newBg)

        if self.config.doSky:
            # Measure the sky frame scale on all inputs, then solve for a
            # single scale.
            cacheSky = []
            measScales = []
            for cacheExp, skyCalib in zip(cacheExposures, skyCalibs):
                skyExp = self.sky.exposureToBackground(skyCalib)
                cacheSky.append(skyExp)
                measScales.append(self.sky.measureScale(cacheExp.getMaskedImage(), skyExp))

            scale = self.sky.solveScales(measScales)
            self.log.info("Sky frame scale: %s", scale)

            # Subtract sky frame with the solved scale from the "cache".
            exposures = []
            for cacheExp, nodeSky, nodeBgList in zip(cacheExposures, cacheSky, cacheBgList):
                self.sky.subtractSkyFrame(cacheExp.getMaskedImage(), nodeSky, scale, nodeBgList)
                exposures.append(afwMath.binImage(cacheExp.getMaskedImage(), self.config.binning))

        if self.config.doBgModel2:
            # As above, generate a focal plane background model and
            # update the cache models.
            exposures, newBgList, _ = self.focalPlaneBackgroundRun(
                camera, cacheExposures, idList, self.config.bgModel2
            )
            for cacheBg, newBg in zip(cacheBgList, newBgList):
                cacheBg.append(newBg)

        # Generate camera-level image of calexp and return it along
        # with the list of sky corrected background models.
        image = makeCameraImage(camera, zip(idList, exposures))

        return pipeBase.Struct(
            calExpCamera=image,
            skyCorr=cacheBgList,
        )

    def loadImageRun(self, calExp, calExpBkg):
        """Serial implementation of self.loadImage() for Gen3.

        Load and restore background to calExp and calExpBkg.

        Parameters
        ----------
        calExp : `lsst.afw.image.Exposure`
            Detector level calExp image to process. Modified in place.
        calExpBkg : `lsst.afw.math.BackgroundList`
            Detector level background list associated with the calExp.
            The stats image of each entry is negated in place.

        Returns
        -------
        calExp : `lsst.afw.image.Exposure`
            Background restored calExp.
        bgList : `lsst.afw.math.BackgroundList`
            New background list containing the restoration background.
        """
        image = calExp.getMaskedImage()

        # Flip the sign of every stored background so that subtracting the
        # combined background image below adds the original background back.
        for bgOld in calExpBkg:
            statsImage = bgOld[0].getStatsImage()
            statsImage *= -1

        image -= calExpBkg.getImage()
        bgList = afwMath.BackgroundList()
        for bgData in calExpBkg:
            bgList.append(bgData)

        if self.config.doMaskObjects:
            self.maskObjects.findObjects(calExp)

        return (calExp, bgList)

    def subtractModelRun(self, exposure, bgModel):
        """Serial implementation of self.subtractModel() for Gen3.

        Realize the focal-plane background model for the exposure's detector
        and subtract it from the exposure in place.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to subtract the background model from.
        bgModel : `FocalPlaneBackground`
            Full camera level background model.

        Returns
        -------
        exposure : `lsst.afw.image.Exposure`
            Background subtracted input exposure.
        bgModelCcd : `lsst.afw.math.BackgroundList`
            Detector level realization of the full background model.
        bgModelEntry
            First element of ``bgModelCcd``.
        """
        image = exposure.getMaskedImage()
        detector = exposure.getDetector()
        bbox = image.getBBox()
        bgModelCcd = bgModel.toCcdBackground(detector, bbox)
        image -= bgModelCcd.getImage()

        return (exposure, bgModelCcd, bgModelCcd[0])