Coverage for python/lsst/ip/diffim/getTemplate.py: 23%

178 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 14:03 +0000

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import numpy as np 

22 

23import lsst.afw.image as afwImage 

24import lsst.geom as geom 

25import lsst.afw.geom as afwGeom 

26import lsst.afw.table as afwTable 

27import lsst.afw.math as afwMath 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30from lsst.skymap import BaseSkyMap 

31from lsst.ip.diffim.dcrModel import DcrModel 

32from lsst.meas.algorithms import CoaddPsf, CoaddPsfConfig 

33from lsst.utils.timer import timeMethod 

34 

# Public API of this module: the plain and DCR template tasks and configs.
__all__ = [
    "GetTemplateTask",
    "GetTemplateConfig",
    "GetDcrTemplateTask",
    "GetDcrTemplateConfig",
]

37 

38 

class GetTemplateConnections(pipeBase.PipelineTaskConnections,
                             dimensions=("instrument", "visit", "detector", "skymap"),
                             defaultTemplates={"coaddName": "goodSeeing",
                                               "warpTypeSuffix": "",
                                               "fakesType": ""}):
    """Butler connections for `GetTemplateTask`.

    Inputs are the science exposure's bounding box and WCS, the skymap, and
    the deferred-load coadd patches; the output is the warped template.
    """
    bbox = pipeBase.connectionTypes.Input(
        name="{fakesType}calexp.bbox",
        storageClass="Box2I",
        dimensions=("instrument", "visit", "detector"),
        doc="BBoxes of calexp used determine geometry of output template",
    )
    wcs = pipeBase.connectionTypes.Input(
        name="{fakesType}calexp.wcs",
        storageClass="Wcs",
        dimensions=("instrument", "visit", "detector"),
        doc="WCS of the calexp that we want to fetch the template for",
    )
    skyMap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap", ),
        doc="Input definition of geometry/bbox and projection/wcs for template exposures",
    )
    # TODO DM-31292: Add option to use global external wcs from jointcal
    # Needed for DRP HSC
    # Loaded lazily (deferLoad) so non-overlapping patches are never read.
    coaddExposures = pipeBase.connectionTypes.Input(
        name="{fakesType}{coaddName}Coadd{warpTypeSuffix}",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band"),
        multiple=True,
        deferLoad=True,
        doc="Input template to match and subtract from the exposure",
    )
    template = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_templateExp{warpTypeSuffix}",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        doc="Warped template used to create `subtractedExposure`.",
    )

78 

79 

class GetTemplateConfig(pipeBase.PipelineTaskConfig,
                        pipelineConnections=GetTemplateConnections):
    """Configuration for `GetTemplateTask`."""
    templateBorderSize = pexConfig.Field(
        dtype=int,
        default=20,
        doc="Number of pixels to grow the requested template image to account for warping"
    )
    warp = pexConfig.ConfigField(
        dtype=afwMath.Warper.ConfigClass,
        doc="warper configuration",
    )
    coaddPsf = pexConfig.ConfigField(
        doc="Configuration for CoaddPsf",
        dtype=CoaddPsfConfig,
    )

    def setDefaults(self):
        # pex_config requires setDefaults overrides to chain to the base
        # class so that inherited defaults are applied as well.
        super().setDefaults()
        # Image warper and CoaddPsf both default to the lanczos5 kernel.
        self.warp.warpingKernelName = 'lanczos5'
        self.coaddPsf.warpingKernelName = 'lanczos5'

99 

100 

class GetTemplateTask(pipeBase.PipelineTask):
    """Assemble a template image for a detector from overlapping coadds.

    Each overlapping coadd patch is warped onto the science exposure's WCS
    and border-grown bounding box. The warps are mean-combined, and a
    `~lsst.meas.algorithms.CoaddPsf` is built from the contributing coadds.
    """
    ConfigClass = GetTemplateConfig
    _DefaultName = "getTemplate"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.warper = afwMath.Warper.fromConfig(self.config.warp)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Read in all inputs.
        inputs = butlerQC.get(inputRefs)
        # Swap the deferred coadd handles for only those coadds that
        # actually overlap the detector, loaded into memory.
        results = self.getOverlappingExposures(inputs)
        inputs["coaddExposures"] = results.coaddExposures
        inputs["dataIds"] = results.dataIds
        # Coadds do not have a physical filter; take it from the quantum.
        inputs["physical_filter"] = butlerQC.quantum.dataId["physical_filter"]
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def getOverlappingExposures(self, inputs):
        """Return lists of coadds and their corresponding dataIds that overlap
        the detector.

        The spatial index in the registry has generous padding and often
        supplies patches near, but not directly overlapping the detector.
        Filters inputs so that we don't have to read in all input coadds.

        Parameters
        ----------
        inputs : `dict` of task Inputs, containing:
            - coaddExposures : `list`
              [`lsst.daf.butler.DeferredDatasetHandle` of
              `lsst.afw.image.Exposure`]
                Data references to exposures that might overlap the detector.
            - bbox : `lsst.geom.Box2I`
                Template Bounding box of the detector geometry onto which to
                resample the coaddExposures.
            - skyMap : `lsst.skymap.SkyMap`
                Input definition of geometry/bbox and projection/wcs for
                template exposures.
            - wcs : `lsst.afw.geom.SkyWcs` or `None`
                Template WCS onto which to resample the coaddExposures.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``coaddExposures``
                List of Coadd exposures that overlap the detector (`list`
                [`lsst.afw.image.Exposure`]).
            ``dataIds``
                List of data IDs of the coadd exposures that overlap the
                detector (`list` [`lsst.daf.butler.DataCoordinate`]).

        Raises
        ------
        NoWorkFound
            Raised if no patches overlap the input detector bbox, or if the
            input ``wcs`` is `None` (overlap cannot be computed without it).
        """
        # Check that the patches actually overlap the detector
        # Exposure's validPolygon would be more accurate
        detectorPolygon = geom.Box2D(inputs['bbox'])
        inputsWcs = inputs['wcs']
        if inputsWcs is None:
            # Patch corners cannot be projected onto the detector, so no
            # coadd can be shown to overlap.
            self.log.info("Exposure has no WCS, so cannot create associated template.")
            raise pipeBase.NoWorkFound('No patches overlap detector')

        overlappingArea = 0
        coaddExposureList = []
        dataIds = []
        for coaddRef in inputs['coaddExposures']:
            dataId = coaddRef.dataId
            tractInfo = inputs['skyMap'][dataId['tract']]
            patchWcs = tractInfo.getWcs()
            patchBBox = tractInfo[dataId['patch']].getOuterBBox()
            # Project the patch outline into detector pixel coordinates.
            patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners())
            patchPolygon = afwGeom.Polygon(inputsWcs.skyToPixel(patchCorners))
            if not patchPolygon.intersection(detectorPolygon):
                continue
            overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea()
            self.log.info("Using template input tract=%s, patch=%s",
                          dataId['tract'], dataId['patch'])
            # Only read the pixels of coadds that overlap.
            coaddExposureList.append(coaddRef.get())
            dataIds.append(dataId)

        if not overlappingArea:
            raise pipeBase.NoWorkFound('No patches overlap detector')

        return pipeBase.Struct(coaddExposures=coaddExposureList,
                               dataIds=dataIds)

    @timeMethod
    def run(self, coaddExposures, bbox, wcs, dataIds, physical_filter=None, **kwargs):
        """Warp coadds from multiple tracts to form a template for image diff.

        Where the tracts overlap, the resulting template image is averaged.
        The PSF on the template is created by combining the CoaddPsf on each
        template image into a meta-CoaddPsf.

        Parameters
        ----------
        coaddExposures : `list` [`lsst.afw.image.Exposure`]
            Coadds to be mosaicked.
        bbox : `lsst.geom.Box2I`
            Template Bounding box of the detector geometry onto which to
            resample the ``coaddExposures``. Not modified; a grown copy is
            used internally.
        wcs : `lsst.afw.geom.SkyWcs`
            Template WCS onto which to resample the ``coaddExposures``.
        dataIds : `list` [`lsst.daf.butler.DataCoordinate`]
            Record of the tract and patch of each coaddExposure.
        physical_filter : `str`, optional
            The physical filter of the science image.
        **kwargs
            Any additional keyword parameters.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``template``
                A template coadd exposure assembled out of patches
                (`lsst.afw.image.ExposureF`). Its filter and photoCalib are
                taken from the last coadd that contributed pixels.

        Raises
        ------
        NoWorkFound
            If no coadds are found with sufficient un-masked pixels.
        RuntimeError
            If the PSF of the template can't be calculated.
        """
        # Table for CoaddPSF
        tractsSchema = afwTable.ExposureTable.makeMinimalSchema()
        tractKey = tractsSchema.addField('tract', type=np.int32, doc='Which tract')
        patchKey = tractsSchema.addField('patch', type=np.int32, doc='Which patch')
        weightKey = tractsSchema.addField('weight', type=float, doc='Weight for each tract, should be 1')
        tractsCatalog = afwTable.ExposureCatalog(tractsSchema)

        finalWcs = wcs
        # Grow a copy so the caller's bbox is not mutated (repeated calls
        # would otherwise keep growing it).
        finalBBox = geom.Box2I(bbox)
        finalBBox.grow(self.config.templateBorderSize)

        maskedImageList = []
        weightList = []
        # Last coadd (and its dataId) that actually contributed pixels;
        # used below for the template's filter and photoCalib.
        usedCoadd = None
        usedDataId = None

        for coaddExposure, dataId in zip(coaddExposures, dataIds):
            # warp to detector WCS
            warped = self.warper.warpExposure(finalWcs, coaddExposure, maxBBox=finalBBox)

            # Check if warped image is viable
            if not np.any(np.isfinite(warped.image.array)):
                self.log.info("No overlap for warped %s. Skipping", dataId)
                continue

            # Paste the warp into a full-size, NO_DATA-initialized image so
            # every entry in maskedImageList shares finalBBox.
            exp = afwImage.ExposureF(finalBBox, finalWcs)
            exp.maskedImage.set(np.nan, afwImage.Mask.getPlaneBitMask("NO_DATA"), np.nan)
            exp.maskedImage.assign(warped.maskedImage, warped.getBBox())

            maskedImageList.append(exp.maskedImage)
            weightList.append(1)
            record = tractsCatalog.addNew()
            record.setPsf(coaddExposure.getPsf())
            record.setWcs(coaddExposure.getWcs())
            record.setPhotoCalib(coaddExposure.getPhotoCalib())
            record.setBBox(coaddExposure.getBBox())
            record.setValidPolygon(afwGeom.Polygon(geom.Box2D(coaddExposure.getBBox()).getCorners()))
            record.set(tractKey, dataId['tract'])
            record.set(patchKey, dataId['patch'])
            record.set(weightKey, 1.)
            usedCoadd = coaddExposure
            usedDataId = dataId

        if not maskedImageList:
            raise pipeBase.NoWorkFound("No patches found to overlap detector")

        # Combine images from individual patches together
        statsFlags = afwMath.stringToStatisticsProperty('MEAN')
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setNanSafe(True)
        statsCtrl.setWeighted(True)
        statsCtrl.setCalcErrorMosaicMode(True)

        templateExposure = afwImage.ExposureF(finalBBox, finalWcs)
        xy0 = templateExposure.getXY0()
        # Do not mask any values. The stacked image replaces the exposure's
        # maskedImage outright, so it is not pre-filled here.
        templateExposure.maskedImage = afwMath.statisticsStack(maskedImageList, statsFlags, statsCtrl,
                                                               weightList, clipped=0, maskMap=[])
        # Restore the template's origin on the stacked image.
        templateExposure.maskedImage.setXY0(xy0)

        # CoaddPsf centroid not only must overlap image, but must overlap the
        # part of image with data. Use centroid of region with data.
        noDataBit = templateExposure.mask.getPlaneBitMask('NO_DATA')
        boolmask = (templateExposure.mask.array & noDataBit) == 0
        maskx = afwImage.makeMaskFromArray(boolmask.astype(afwImage.MaskPixel))
        centerCoord = afwGeom.SpanSet.fromMask(maskx, 1).computeCentroid()

        ctrl = self.config.coaddPsf.makeControl()
        coaddPsf = CoaddPsf(tractsCatalog, finalWcs, centerCoord, ctrl.warpingKernelName, ctrl.cacheSize)
        if coaddPsf is None:
            raise RuntimeError("CoaddPsf could not be constructed")

        templateExposure.setPsf(coaddPsf)
        # Coadds do not have a physical filter, so fetch it from the butler to prevent downstream warnings.
        if physical_filter is None:
            filterLabel = usedCoadd.getFilter()
        else:
            filterLabel = afwImage.FilterLabel(usedDataId['band'], physical_filter)
        templateExposure.setFilter(filterLabel)
        templateExposure.setPhotoCalib(usedCoadd.getPhotoCalib())
        return pipeBase.Struct(template=templateExposure)

309 

310 

class GetDcrTemplateConnections(GetTemplateConnections,
                                dimensions=("instrument", "visit", "detector", "skymap"),
                                defaultTemplates={"coaddName": "dcr",
                                                  "warpTypeSuffix": "",
                                                  "fakesType": ""}):
    """Butler connections for `GetDcrTemplateTask`.

    Adds the science ``visitInfo`` input and per-subfilter ``dcrCoadds``.
    The constructor removes the parent's plain ``coaddExposures`` input.
    """
    visitInfo = pipeBase.connectionTypes.Input(
        name="{fakesType}calexp.visitInfo",
        storageClass="VisitInfo",
        dimensions=("instrument", "visit", "detector"),
        doc="VisitInfo of calexp used to determine observing conditions.",
    )
    dcrCoadds = pipeBase.connectionTypes.Input(
        name="{fakesType}dcrCoadd{warpTypeSuffix}",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "skymap", "band", "subfilter"),
        multiple=True,
        deferLoad=True,
        doc="Input DCR template to match and subtract from the exposure",
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # The inherited coadd input is superseded by ``dcrCoadds``.
        self.inputs.remove("coaddExposures")

334 

335 

class GetDcrTemplateConfig(GetTemplateConfig,
                           pipelineConnections=GetDcrTemplateConnections):
    """Configuration for `GetDcrTemplateTask`."""
    numSubfilters = pexConfig.Field(
        doc="Number of subfilters in the DcrCoadd.",
        dtype=int,
        default=3,
    )
    effectiveWavelength = pexConfig.Field(
        doc="Effective wavelength of the filter.",
        optional=False,
        dtype=float,
    )
    bandwidth = pexConfig.Field(
        doc="Bandwidth of the physical filter.",
        optional=False,
        dtype=float,
    )

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        ValueError
            If ``effectiveWavelength`` or ``bandwidth`` is not set.
        """
        # Check these first so users get this actionable message rather
        # than a generic error from base-class field validation.
        if self.effectiveWavelength is None or self.bandwidth is None:
            raise ValueError("The effective wavelength and bandwidth of the physical filter "
                             "must be set in the getTemplate config for DCR coadds. "
                             "Required until transmission curves are used in DM-13668.")
        # Chain to the base class so inherited fields, subconfigs and
        # connections are validated as well.
        super().validate()

359 

360 

class GetDcrTemplateTask(GetTemplateTask):
    """Build a template from DCR coadds matched to the science observation.

    For every overlapping patch, the per-subfilter DCR coadds are combined
    into a `DcrModel`. That model is used to build an exposure matched to the
    science ``visitInfo``. Those exposures are then mosaicked by the parent
    `GetTemplateTask.run`.
    """
    ConfigClass = GetDcrTemplateConfig
    _DefaultName = "getDcrTemplate"

    def getOverlappingExposures(self, inputs):
        """Return lists of coadds and their corresponding dataIds that overlap
        the detector.

        The spatial index in the registry has generous padding and often
        supplies patches near, but not directly overlapping the detector.
        Filters inputs so that we don't have to read in all input coadds.

        Parameters
        ----------
        inputs : `dict` of task Inputs, containing:
            - dcrCoadds : `list`
              [`lsst.daf.butler.DeferredDatasetHandle` of
              `lsst.afw.image.Exposure`]
                Data references to DCR coadds (one per subfilter and patch)
                that might overlap the detector.
            - bbox : `lsst.geom.Box2I`
                Template Bounding box of the detector geometry onto which to
                resample the coaddExposures.
            - skyMap : `lsst.skymap.SkyMap`
                Input definition of geometry/bbox and projection/wcs for
                template exposures.
            - wcs : `lsst.afw.geom.SkyWcs` or `None`
                Template WCS onto which to resample the coaddExposures.
            - visitInfo : `lsst.afw.image.VisitInfo`
                Metadata for the science image.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``coaddExposures``
                Coadd exposures that overlap the detector, one per patch
                (`list` [`lsst.afw.image.Exposure`]).
            ``dataIds``
                Data IDs of the coadd exposures that overlap the detector,
                one per patch and in the same order as ``coaddExposures``
                (`list` [`lsst.daf.butler.DataCoordinate`]).

        Raises
        ------
        NoWorkFound
            Raised if no patches overlap the input detector bbox, or if the
            input ``wcs`` is `None`.
        RuntimeError
            Raised if a patch does not have exactly ``numSubfilters`` inputs.
        """
        # Check that the patches actually overlap the detector
        # Exposure's validPolygon would be more accurate
        detectorPolygon = geom.Box2D(inputs["bbox"])
        inputsWcs = inputs["wcs"]
        if inputsWcs is None:
            # Consistent with GetTemplateTask: overlap cannot be computed.
            self.log.info("Exposure has no WCS, so cannot create associated template.")
            raise pipeBase.NoWorkFound('No patches overlap detector')

        overlappingArea = 0
        coaddExposureRefList = []
        # tract -> list of patch ids, one entry per overlapping subfilter ref.
        patchList = dict()
        # (tract, patch) -> first overlapping dataId seen for that patch.
        firstDataIds = dict()
        for coaddRef in inputs["dcrCoadds"]:
            dataId = coaddRef.dataId
            tract = dataId['tract']
            patch = dataId['patch']
            tractInfo = inputs["skyMap"][tract]
            patchWcs = tractInfo.getWcs()
            patchBBox = tractInfo[patch].getOuterBBox()
            patchCorners = patchWcs.pixelToSky(geom.Box2D(patchBBox).getCorners())
            patchPolygon = afwGeom.Polygon(inputsWcs.skyToPixel(patchCorners))
            if not patchPolygon.intersection(detectorPolygon):
                continue
            overlappingArea += patchPolygon.intersectionSingle(detectorPolygon).calculateArea()
            self.log.info("Using template input tract=%s, patch=%s, subfilter=%s",
                          tract, patch, dataId["subfilter"])
            coaddExposureRefList.append(coaddRef)
            patchList.setdefault(tract, []).append(patch)
            firstDataIds.setdefault((tract, patch), dataId)

        if not overlappingArea:
            raise pipeBase.NoWorkFound('No patches overlap detector')

        self.checkPatchList(patchList)

        coaddExposures = self.getDcrModel(patchList, coaddExposureRefList, inputs['visitInfo'])
        # getDcrModel yields one exposure per (tract, patch) in
        # _iterPatches order; build dataIds in that same order so that
        # run() pairs each exposure with its own tract/patch.
        dataIds = [firstDataIds[(tract, patch)] for tract, patch in self._iterPatches(patchList)]
        return pipeBase.Struct(coaddExposures=coaddExposures,
                               dataIds=dataIds)

    @staticmethod
    def _iterPatches(patchList):
        """Yield each distinct ``(tract, patch)`` pair in ``patchList``.

        Tracts are visited in dict order and patches in first-seen order, so
        repeated calls on the same ``patchList`` yield the same sequence.
        """
        for tract, patches in patchList.items():
            for patch in dict.fromkeys(patches):
                yield tract, patch

    def checkPatchList(self, patchList):
        """Check that all of the DcrModel subfilters are present for each
        patch.

        Parameters
        ----------
        patchList : `dict`
            Dict of the patches containing valid data for each tract.

        Raises
        ------
        RuntimeError
            If the number of exposures found for a patch does not match the
            number of subfilters.
        """
        for patches in patchList.values():
            for patch in set(patches):
                nFound = patches.count(patch)
                if nFound != self.config.numSubfilters:
                    raise RuntimeError("Invalid number of DcrModel subfilters found: %d vs %d expected"
                                       % (nFound, self.config.numSubfilters))

    def getDcrModel(self, patchList, coaddRefs, visitInfo):
        """Build DCR-matched coadds from a list of exposure references.

        Parameters
        ----------
        patchList : `dict`
            Dict of the patches containing valid data for each tract.
        coaddRefs : `list` [`lsst.daf.butler.DeferredDatasetHandle`]
            Data references to `~lsst.afw.image.Exposure` representing
            DcrModels that overlap the detector.
        visitInfo : `lsst.afw.image.VisitInfo`
            Metadata for the science image.

        Returns
        -------
        coaddExposureList : `list` [`lsst.afw.image.Exposure`]
            Coadd exposures that overlap the detector, one per distinct
            ``(tract, patch)`` in ``patchList``.
        """
        coaddExposureList = []
        for tract, patch in self._iterPatches(patchList):
            # All subfilter refs belonging to this patch, in input order.
            coaddRefList = [coaddRef for coaddRef in coaddRefs
                            if _selectDataRef(coaddRef, tract, patch)]

            dcrModel = DcrModel.fromQuantum(coaddRefList,
                                            self.config.effectiveWavelength,
                                            self.config.bandwidth,
                                            self.config.numSubfilters)
            coaddExposureList.append(dcrModel.buildMatchedExposure(visitInfo=visitInfo))
        return coaddExposureList

492 

493 

def _selectDataRef(coaddRef, tract, patch):
    """Return whether ``coaddRef`` belongs to the given tract and patch.

    Parameters
    ----------
    coaddRef : object with a ``dataId`` mapping
        Reference whose ``dataId['tract']`` and ``dataId['patch']`` are tested.
    tract, patch
        Values to compare against.

    Returns
    -------
    selected : `bool`
        `True` if both the tract and the patch match.
    """
    # ``and`` short-circuits; bool() normalizes numpy scalar comparison results.
    return bool(coaddRef.dataId['tract'] == tract and coaddRef.dataId['patch'] == patch)
496 return condition