Coverage for python/lsst/ip/diffim/getTemplate.py: 21%

176 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-01 03:35 -0700

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import numpy as np 

22 

23import lsst.afw.image as afwImage 

24import lsst.geom as geom 

25import lsst.afw.geom as afwGeom 

26import lsst.afw.table as afwTable 

27import lsst.afw.math as afwMath 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30from lsst.skymap import BaseSkyMap 

31from lsst.ip.diffim.dcrModel import DcrModel 

32from lsst.meas.algorithms import CoaddPsf, CoaddPsfConfig 

33 

# Public API of this module: the standard and DCR template-building tasks
# together with their configuration classes.
__all__ = [
    "GetTemplateTask",
    "GetTemplateConfig",
    "GetDcrTemplateTask",
    "GetDcrTemplateConfig",
]

36 

37 

class GetTemplateConnections(pipeBase.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector", "skymap"),
                             defaultTemplates={"coaddName": "goodSeeing",
                                               "warpTypeSuffix": "",
                                               "fakesType": ""}):
    """Butler connections for `GetTemplateTask`.

    Each quantum covers one detector of one visit. Inputs are the calexp
    bounding box and WCS, the sky map, and the candidate coadd patches
    (loaded lazily). The output is a single template exposure for that
    detector.
    """
    bbox = pipeBase.connectionTypes.Input(
        name="{fakesType}calexp.bbox",
        doc="BBoxes of calexp used determine geometry of output template",
        dimensions=("instrument", "visit", "detector"),
        storageClass="Box2I",
    )
    wcs = pipeBase.connectionTypes.Input(
        name="{fakesType}calexp.wcs",
        doc="WCS of the calexp that we want to fetch the template for",
        dimensions=("instrument", "visit", "detector"),
        storageClass="Wcs",
    )
    skyMap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for template exposures",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    # TODO DM-31292: Add option to use global external wcs from jointcal
    # Needed for DRP HSC
    coaddExposures = pipeBase.connectionTypes.Input(
        name="{fakesType}{coaddName}Coadd{warpTypeSuffix}",
        doc="Input template to match and subtract from the exposure",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        deferLoad=True,
        multiple=True,
    )
    template = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_templateExp{warpTypeSuffix}",
        doc="Warped template used to create `subtractedExposure`.",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )

77 

78 

class GetTemplateConfig(pipeBase.PipelineTaskConfig,
                        pipelineConnections=GetTemplateConnections):
    """Configuration for `GetTemplateTask`."""

    # Grown on every side of the detector bbox before warping (see
    # ``GetTemplateTask.run``).
    templateBorderSize = pexConfig.Field(
        dtype=int,
        default=20,
        doc="Number of pixels to grow the requested template image to account for warping"
    )
    warp = pexConfig.ConfigField(
        dtype=afwMath.Warper.ConfigClass,
        doc="warper configuration",
    )
    coaddPsf = pexConfig.ConfigField(
        doc="Configuration for CoaddPsf",
        dtype=CoaddPsfConfig,
    )

    def setDefaults(self):
        # Chain to the base class so any inherited defaults are applied
        # before the overrides below.
        super().setDefaults()
        # Use the same Lanczos kernel for the image warper and the CoaddPsf.
        self.warp.warpingKernelName = 'lanczos5'
        self.coaddPsf.warpingKernelName = 'lanczos5'

98 

99 

class GetTemplateTask(pipeBase.PipelineTask):
    """Warp coadd patches overlapping a science detector onto that
    detector's WCS and bounding box, and mosaic them into a template.
    """
    ConfigClass = GetTemplateConfig
    _DefaultName = "getTemplate"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.warper = afwMath.Warper.fromConfig(self.config.warp)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Read in all inputs.
        inputs = butlerQC.get(inputRefs)
        # Replace the deferred coadd handles with the loaded coadds that
        # actually overlap the detector.
        results = self.getOverlappingExposures(inputs)
        inputs["coaddExposures"] = results.coaddExposures
        inputs["dataIds"] = results.dataIds
        # Coadds carry no physical filter, so take it from the quantum.
        inputs["physical_filter"] = butlerQC.quantum.dataId["physical_filter"]
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def getOverlappingExposures(self, inputs):
        """Return lists of coadds and their corresponding dataIds that overlap
        the detector.

        The spatial index in the registry has generous padding and often
        supplies patches near, but not directly overlapping the detector.
        Filters inputs so that we don't have to read in all input coadds.

        Parameters
        ----------
        inputs : `dict` of task Inputs, containing:
            - coaddExposures : `list`
                [`lsst.daf.butler.DeferredDatasetHandle` of
                `lsst.afw.image.Exposure`]
                Data references to exposures that might overlap the detector.
            - bbox : `lsst.geom.Box2I`
                Template Bounding box of the detector geometry onto which to
                resample the coaddExposures.
            - skyMap : `lsst.skymap.SkyMap`
                Input definition of geometry/bbox and projection/wcs for
                template exposures.
            - wcs : `lsst.afw.geom.SkyWcs` or `None`
                Template WCS onto which to resample the coaddExposures.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``coaddExposures``
                List of Coadd exposures that overlap the detector (`list`
                [`lsst.afw.image.Exposure`]).
            ``dataIds``
                List of data IDs of the coadd exposures that overlap the
                detector (`list` [`lsst.daf.butler.DataCoordinate`]).

        Raises
        ------
        NoWorkFound
            Raised if the exposure has no WCS, or if no patches overlap the
            input detector bbox.
        """
        # Check that the patches actually overlap the detector.
        # Exposure's validPolygon would be more accurate.
        detectorPolygon = geom.Box2D(inputs['bbox'])
        skyMap = inputs['skyMap']
        inputsWcs = inputs['wcs']
        if inputsWcs is None:
            # Without a WCS no patch can be mapped onto the detector.
            self.log.info("Exposure has no WCS, so cannot create associated template.")
            raise pipeBase.NoWorkFound('No patches overlap detector')

        overlappingArea = 0
        coaddExposureList = []
        dataIds = []
        for coaddRef in inputs['coaddExposures']:
            dataId = coaddRef.dataId
            tractInfo = skyMap[dataId['tract']]
            patchWcs = tractInfo.getWcs()
            patchBBox = tractInfo[dataId['patch']].getOuterBBox()
            # Project the patch corners into detector pixel coordinates.
            patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners())
            patchPolygon = afwGeom.Polygon(inputsWcs.skyToPixel(patchCorners))
            if not patchPolygon.intersection(detectorPolygon):
                continue
            overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea()
            self.log.info("Using template input tract=%s, patch=%s",
                          dataId['tract'], dataId['patch'])
            coaddExposureList.append(coaddRef.get())
            dataIds.append(dataId)

        if not overlappingArea:
            raise pipeBase.NoWorkFound('No patches overlap detector')

        return pipeBase.Struct(coaddExposures=coaddExposureList,
                               dataIds=dataIds)

    def run(self, coaddExposures, bbox, wcs, dataIds, physical_filter=None, **kwargs):
        """Warp coadds from multiple tracts to form a template for image diff.

        Where the tracts overlap, the resulting template image is averaged.
        The PSF on the template is created by combining the CoaddPsf on each
        template image into a meta-CoaddPsf.

        Parameters
        ----------
        coaddExposures : `list` [`lsst.afw.image.Exposure`]
            Coadds to be mosaicked.
        bbox : `lsst.geom.Box2I`
            Template Bounding box of the detector geometry onto which to
            resample the ``coaddExposures``. Not modified; a grown copy is
            used internally.
        wcs : `lsst.afw.geom.SkyWcs`
            Template WCS onto which to resample the ``coaddExposures``.
        dataIds : `list` [`lsst.daf.butler.DataCoordinate`]
            Record of the tract and patch of each coaddExposure, in the same
            order as ``coaddExposures``.
        physical_filter : `str`, optional
            The physical filter of the science image.
        **kwargs
            Any additional keyword parameters.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``template``
                A template coadd exposure assembled out of patches
                (`lsst.afw.image.ExposureF`).

        Raises
        ------
        NoWorkFound
            If no coadds are found with sufficient un-masked pixels.
        RuntimeError
            If the PSF of the template can't be calculated.
        """
        # Table for CoaddPSF: one record per contributing patch.
        tractsSchema = afwTable.ExposureTable.makeMinimalSchema()
        tractKey = tractsSchema.addField('tract', type=np.int32, doc='Which tract')
        patchKey = tractsSchema.addField('patch', type=np.int32, doc='Which patch')
        weightKey = tractsSchema.addField('weight', type=float, doc='Weight for each tract, should be 1')
        tractsCatalog = afwTable.ExposureCatalog(tractsSchema)

        finalWcs = wcs
        # Grow a copy so the caller's bbox is not modified in place.
        finalBBox = geom.Box2I(bbox)
        finalBBox.grow(self.config.templateBorderSize)

        maskedImageList = []
        weightList = []
        # Most recent coadd (and its dataId) that actually contributed to the
        # template; supplies the filter label and photometric calibration.
        usedCoadd = None
        usedDataId = None

        for coaddExposure, dataId in zip(coaddExposures, dataIds):
            # Warp to detector WCS.
            warped = self.warper.warpExposure(finalWcs, coaddExposure, maxBBox=finalBBox)

            # Check if warped image is viable.
            if not np.any(np.isfinite(warped.image.array)):
                self.log.info("No overlap for warped %s. Skipping", dataId)
                continue

            # Paste the warp into a full-size canvas pre-filled with
            # NaN / NO_DATA so every stacked image covers finalBBox.
            exp = afwImage.ExposureF(finalBBox, finalWcs)
            exp.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan)
            exp.maskedImage.assign(warped.maskedImage, warped.getBBox())

            maskedImageList.append(exp.maskedImage)
            weightList.append(1)
            record = tractsCatalog.addNew()
            record.setPsf(coaddExposure.getPsf())
            record.setWcs(coaddExposure.getWcs())
            record.setPhotoCalib(coaddExposure.getPhotoCalib())
            record.setBBox(coaddExposure.getBBox())
            record.setValidPolygon(afwGeom.Polygon(geom.Box2D(coaddExposure.getBBox()).getCorners()))
            record.set(tractKey, dataId['tract'])
            record.set(patchKey, dataId['patch'])
            record.set(weightKey, 1.)
            usedCoadd = coaddExposure
            usedDataId = dataId

        if not maskedImageList:
            raise pipeBase.NoWorkFound("No patches found to overlap detector")

        # Combine images from individual patches together.
        statsFlags = afwMath.stringToStatisticsProperty('MEAN')
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setNanSafe(True)
        statsCtrl.setWeighted(True)
        statsCtrl.setCalcErrorMosaicMode(True)

        templateExposure = afwImage.ExposureF(finalBBox, finalWcs)
        xy0 = templateExposure.getXY0()
        # Do not mask any values. The stacked image replaces the exposure's
        # masked image wholesale, so its origin is restored afterwards.
        templateExposure.maskedImage = afwMath.statisticsStack(maskedImageList, statsFlags, statsCtrl,
                                                               weightList, clipped=0, maskMap=[])
        templateExposure.maskedImage.setXY0(xy0)

        # CoaddPsf centroid not only must overlap image, but must overlap the
        # part of image with data. Use centroid of region with data.
        noDataBit = templateExposure.mask.getPlaneBitMask('NO_DATA')
        hasData = (templateExposure.mask.array & noDataBit) == 0
        maskx = afwImage.makeMaskFromArray(hasData.astype(afwImage.MaskPixel))
        centerCoord = afwGeom.SpanSet.fromMask(maskx, 1).computeCentroid()

        ctrl = self.config.coaddPsf.makeControl()
        coaddPsf = CoaddPsf(tractsCatalog, finalWcs, centerCoord, ctrl.warpingKernelName, ctrl.cacheSize)
        if coaddPsf is None:
            raise RuntimeError("CoaddPsf could not be constructed")

        templateExposure.setPsf(coaddPsf)
        # Coadds do not have a physical filter, so fetch it from the butler to prevent downstream warnings.
        if physical_filter is None:
            filterLabel = usedCoadd.getFilter()
        else:
            filterLabel = afwImage.FilterLabel(usedDataId['band'], physical_filter)
        templateExposure.setFilter(filterLabel)
        templateExposure.setPhotoCalib(usedCoadd.getPhotoCalib())
        return pipeBase.Struct(template=templateExposure)

307 

308 

class GetDcrTemplateConnections(GetTemplateConnections,
                                dimensions=("instrument", "visit", "detector", "skymap"),
                                defaultTemplates={"coaddName": "dcr",
                                                  "warpTypeSuffix": "",
                                                  "fakesType": ""}):
    """Butler connections for `GetDcrTemplateTask`.

    Inherits the connections of `GetTemplateConnections`. It adds the science
    image's VisitInfo and the per-subfilter DCR coadds, and removes the
    inherited ``coaddExposures`` input.
    """
    visitInfo = pipeBase.connectionTypes.Input(
        name="{fakesType}calexp.visitInfo",
        doc="VisitInfo of calexp used to determine observing conditions.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="VisitInfo",
    )
    dcrCoadds = pipeBase.connectionTypes.Input(
        name="{fakesType}dcrCoadd{warpTypeSuffix}",
        doc="Input DCR template to match and subtract from the exposure",
        dimensions=("tract", "patch", "skymap", "band", "subfilter"),
        storageClass="ExposureF",
        deferLoad=True,
        multiple=True,
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # DCR templates are built from ``dcrCoadds``, so the ordinary coadd
        # input inherited from the parent is not needed.
        self.inputs.remove("coaddExposures")

332 

333 

class GetDcrTemplateConfig(GetTemplateConfig,
                           pipelineConnections=GetDcrTemplateConnections):
    """Configuration for `GetDcrTemplateTask`."""

    numSubfilters = pexConfig.Field(
        doc="Number of subfilters in the DcrCoadd.",
        dtype=int,
        default=3,
    )
    effectiveWavelength = pexConfig.Field(
        doc="Effective wavelength of the filter.",
        optional=False,
        dtype=float,
    )
    bandwidth = pexConfig.Field(
        doc="Bandwidth of the physical filter.",
        optional=False,
        dtype=float,
    )

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        ValueError
            If ``effectiveWavelength`` or ``bandwidth`` is unset.
        """
        # Check the DCR-specific fields first so the user sees this
        # explanatory message, then run the inherited validation, which was
        # previously skipped entirely.
        if self.effectiveWavelength is None or self.bandwidth is None:
            raise ValueError("The effective wavelength and bandwidth of the physical filter "
                             "must be set in the getTemplate config for DCR coadds. "
                             "Required until transmission curves are used in DM-13668.")
        super().validate()

357 

358 

class GetDcrTemplateTask(GetTemplateTask):
    """Build a template from DCR coadds matched to the science image's
    VisitInfo, then warp and mosaic it as in `GetTemplateTask`.
    """
    ConfigClass = GetDcrTemplateConfig
    _DefaultName = "getDcrTemplate"

    def getOverlappingExposures(self, inputs):
        """Return lists of coadds and their corresponding dataIds that overlap
        the detector.

        The spatial index in the registry has generous padding and often
        supplies patches near, but not directly overlapping the detector.
        Filters inputs so that we don't have to read in all input coadds.

        Parameters
        ----------
        inputs : `dict` of task Inputs, containing:
            - dcrCoadds : `list`
                [`lsst.daf.butler.DeferredDatasetHandle` of
                `lsst.afw.image.Exposure`]
                Data references to per-subfilter DCR coadds that might
                overlap the detector.
            - bbox : `lsst.geom.Box2I`
                Template Bounding box of the detector geometry onto which to
                resample the coaddExposures.
            - skyMap : `lsst.skymap.SkyMap`
                Input definition of geometry/bbox and projection/wcs for
                template exposures.
            - wcs : `lsst.afw.geom.SkyWcs` or `None`
                Template WCS onto which to resample the coaddExposures.
            - visitInfo : `lsst.afw.image.VisitInfo`
                Metadata for the science image.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``coaddExposures``
                One DCR-matched coadd exposure per overlapping (tract, patch)
                (`list` [`lsst.afw.image.Exposure`]).
            ``dataIds``
                One data ID per entry of ``coaddExposures``, in the same
                order (`list` [`lsst.daf.butler.DataCoordinate`]).

        Raises
        ------
        NoWorkFound
            Raised if the exposure has no WCS, or if no patches overlap the
            input detector bbox.
        RuntimeError
            Raised if an overlapping patch does not have exactly
            ``config.numSubfilters`` subfilters.
        """
        # Check that the patches actually overlap the detector.
        # Exposure's validPolygon would be more accurate.
        detectorPolygon = geom.Box2D(inputs["bbox"])
        skyMap = inputs["skyMap"]
        inputsWcs = inputs["wcs"]
        if inputsWcs is None:
            # Match the base task: without a WCS nothing can overlap.
            self.log.info("Exposure has no WCS, so cannot create associated template.")
            raise pipeBase.NoWorkFound('No patches overlap detector')

        overlappingArea = 0
        # tract -> list of patches, one entry per overlapping subfilter ref.
        patchList = dict()
        # First data ID seen for each overlapping (tract, patch); used to
        # label the single DCR-matched exposure built for that patch.
        patchDataIds = dict()
        for coaddRef in inputs["dcrCoadds"]:
            dataId = coaddRef.dataId
            tractInfo = skyMap[dataId['tract']]
            patchWcs = tractInfo.getWcs()
            patchBBox = tractInfo[dataId['patch']].getOuterBBox()
            patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners())
            patchPolygon = afwGeom.Polygon(inputsWcs.skyToPixel(patchCorners))
            if not patchPolygon.intersection(detectorPolygon):
                continue
            overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea()
            self.log.info("Using template input tract=%s, patch=%s, subfilter=%s",
                          dataId['tract'], dataId['patch'], dataId["subfilter"])
            patchList.setdefault(dataId['tract'], []).append(dataId['patch'])
            patchDataIds.setdefault((dataId['tract'], dataId['patch']), dataId)

        if not overlappingArea:
            raise pipeBase.NoWorkFound('No patches overlap detector')

        self.checkPatchList(patchList)

        coaddExposures = self.getDcrModel(patchList, inputs['dcrCoadds'], inputs['visitInfo'])
        # Build dataIds in exactly the order getDcrModel produced exposures,
        # so that ``run`` pairs each exposure with its own tract/patch.
        dataIds = [patchDataIds[tractPatch]
                   for tractPatch in self._iterUniquePatches(patchList)]
        return pipeBase.Struct(coaddExposures=coaddExposures,
                               dataIds=dataIds)

    @staticmethod
    def _iterUniquePatches(patchList):
        """Yield each distinct ``(tract, patch)`` in ``patchList`` once, in
        first-seen order.
        """
        for tract, patches in patchList.items():
            for patch in dict.fromkeys(patches):
                yield tract, patch

    def checkPatchList(self, patchList):
        """Check that all of the DcrModel subfilters are present for each
        patch.

        Parameters
        ----------
        patchList : `dict`
            Dict of the patches containing valid data for each tract.

        Raises
        ------
        RuntimeError
            If the number of exposures found for a patch does not match the
            number of subfilters.
        """
        for tract, patches in patchList.items():
            for patch in set(patches):
                nFound = patches.count(patch)
                if nFound != self.config.numSubfilters:
                    raise RuntimeError(
                        "Invalid number of DcrModel subfilters found: %d vs %d expected "
                        "(tract=%s, patch=%s)"
                        % (nFound, self.config.numSubfilters, tract, patch))

    def getDcrModel(self, patchList, coaddRefs, visitInfo):
        """Build DCR-matched coadds from a list of exposure references.

        Parameters
        ----------
        patchList : `dict`
            Dict of the patches containing valid data for each tract.
        coaddRefs : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
            Data references to `~lsst.afw.image.Exposure` representing
            DcrModels that overlap the detector.
        visitInfo : `lsst.afw.image.VisitInfo`
            Metadata for the science image.

        Returns
        -------
        coaddExposureList : `list` [`lsst.afw.image.Exposure`]
            Coadd exposures that overlap the detector, one per distinct
            (tract, patch) in first-seen order.
        """
        coaddExposureList = []
        for tract, patch in self._iterUniquePatches(patchList):
            coaddRefList = [coaddRef for coaddRef in coaddRefs
                            if _selectDataRef(coaddRef, tract, patch)]

            dcrModel = DcrModel.fromQuantum(coaddRefList,
                                            self.config.effectiveWavelength,
                                            self.config.bandwidth,
                                            self.config.numSubfilters)
            coaddExposureList.append(dcrModel.buildMatchedExposure(visitInfo=visitInfo))
        return coaddExposureList

490 

491 

492def _selectDataRef(coaddRef, tract, patch): 

493 condition = (coaddRef.dataId['tract'] == tract) & (coaddRef.dataId['patch'] == patch) 

494 return condition