Coverage for python / lsst / ip / diffim / dipoleMeasurement.py: 19%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-23 08:42 +0000

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23 

24import lsst.afw.image as afwImage 

25import lsst.geom as geom 

26import lsst.pex.config as pexConfig 

27import lsst.meas.deblender.baseline as deblendBaseline 

28from lsst.meas.base.pluginRegistry import register 

29from lsst.meas.base import SingleFrameMeasurementTask, SingleFrameMeasurementConfig, \ 

30 SingleFramePluginConfig, SingleFramePlugin 

31import lsst.afw.display as afwDisplay 

32from lsst.utils.logging import getLogger 

33 

# Public API of this module.
__all__ = (
    "DipoleMeasurementConfig",
    "DipoleMeasurementTask",
    "DipoleAnalysis",
    "DipoleDeblender",
    "SourceFlagChecker",
    "ClassificationDipoleConfig",
    "ClassificationDipolePlugin",
)

36 

37 

class ClassificationDipoleConfig(SingleFramePluginConfig):
    """Thresholds that decide whether a measured diaSource is classified as a dipole.

    Both limits are applied in `ClassificationDipolePlugin.measure`.
    """

    # The combined lobe S/N, sqrt(snPos**2 + snNeg**2), must exceed this value.
    minSn = pexConfig.Field(
        dtype=float,
        default=5.0 * np.sqrt(2),
        doc="Minimum quadrature sum of positive+negative lobe S/N to be considered a dipole",
    )
    # Each lobe's |flux| / (|flux_pos| + |flux_neg|) must be strictly below this value.
    maxFluxRatio = pexConfig.Field(
        dtype=float,
        default=0.65,
        doc="Maximum flux ratio in either lobe to be considered a dipole",
    )

48 

49 

@register("ip_diffim_ClassificationDipole")
class ClassificationDipolePlugin(SingleFramePlugin):
    """A plugin to classify whether a diaSource is a dipole.

    The classification is written to ``<name>_value`` (1.0 for a dipole,
    0.0 otherwise). ``<name>_flag`` is set when either lobe of the
    ``ip_diffim_PsfDipoleFlux`` measurement is flagged; the source is still
    classified in that case.
    """

    ConfigClass = ClassificationDipoleConfig

    @classmethod
    def getExecutionOrder(cls):
        """Return the execution-order value for this plugin.

        Returns
        -------
        result : `float`
            ``cls.APCORR_ORDER``. The plugin reads the
            ``ip_diffim_PsfDipoleFlux`` outputs, presumably so it must run
            after that plugin (confirm against meas_base ordering rules).
        """
        return cls.APCORR_ORDER

    def __init__(self, config, name, schema, metadata):
        SingleFramePlugin.__init__(self, config, name, schema, metadata)
        self.dipoleAnalysis = DipoleAnalysis()
        self.keyProbability = schema.addField(name + "_value", type="D",
                                              doc="Set to 1 for dipoles, else 0.")
        self.keyFlag = schema.addField(name + "_flag", type="Flag", doc="Set to 1 for any fatal failure.")

    def measure(self, measRecord, exposure):
        """Classify ``measRecord`` as a dipole (1.0) or not (0.0).

        A source is a dipole when the quadrature-summed lobe S/N exceeds
        ``config.minSn`` and neither lobe holds ``config.maxFluxRatio`` or
        more of the total absolute lobe flux.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record holding ``ip_diffim_PsfDipoleFlux`` outputs; the result is
            written into it.
        exposure : `lsst.afw.image.Exposure`
            Unused.
        """
        passesSn = self.dipoleAnalysis.getSn(measRecord) > self.config.minSn
        negFlux = np.abs(measRecord.get("ip_diffim_PsfDipoleFlux_neg_instFlux"))
        negFluxFlag = measRecord.get("ip_diffim_PsfDipoleFlux_neg_flag")
        posFlux = np.abs(measRecord.get("ip_diffim_PsfDipoleFlux_pos_instFlux"))
        posFluxFlag = measRecord.get("ip_diffim_PsfDipoleFlux_pos_flag")

        if negFluxFlag or posFluxFlag:
            self.fail(measRecord)
            # continue on to classify

        totalFlux = negFlux + posFlux

        # If negFlux or posFlux are NaN, or both are zero, the ratios are NaN
        # and the comparisons evaluate to False (not a dipole). Suppress the
        # RuntimeWarning numpy would otherwise emit for each such record.
        with np.errstate(divide="ignore", invalid="ignore"):
            passesFluxNeg = (negFlux / totalFlux) < self.config.maxFluxRatio
            passesFluxPos = (posFlux / totalFlux) < self.config.maxFluxRatio

        if (passesSn and passesFluxPos and passesFluxNeg):
            val = 1.0
        else:
            val = 0.0

        measRecord.set(self.keyProbability, val)

    def fail(self, measRecord, error=None):
        """Set the failure flag on ``measRecord``.

        ``error`` is accepted for the plugin interface but is not used.
        """
        measRecord.set(self.keyFlag, True)

98 

99 

class DipoleMeasurementConfig(SingleFrameMeasurementConfig):
    """Single-frame measurement configuration set up for dipole diaSources."""

    def setDefaults(self):
        SingleFrameMeasurementConfig.setDefaults(self)

        # A handful of base plugins plus the dipole fit and its classifier.
        self.plugins = [
            "base_CircularApertureFlux",
            "base_PixelFlags",
            "base_SkyCoord",
            "base_PsfFlux",
            "ip_diffim_PsfDipoleFlux",
            "ip_diffim_ClassificationDipole",
        ]

        # Leave these slots unset.
        for slotName in ("calibFlux", "modelFlux", "gaussianFlux", "shape"):
            setattr(self.slots, slotName, None)
        # Centroids come from the dipole fit.
        self.slots.centroid = "ip_diffim_PsfDipoleFlux"
        self.doReplaceWithNoise = False

119 

120 

class DipoleMeasurementTask(SingleFrameMeasurementTask):
    """Measurement of Sources, specifically ones from difference images, for characterization as dipoles.

    This class adds no code of its own. It is a `SingleFrameMeasurementTask`
    configured by `DipoleMeasurementConfig`, which enables the
    ``ip_diffim_PsfDipoleFlux`` and ``ip_diffim_ClassificationDipole``
    plugins. Constructor and run arguments are those inherited from
    `lsst.meas.base.SingleFrameMeasurementTask`.
    """
    ConfigClass = DipoleMeasurementConfig
    _DefaultName = "dipoleMeasurement"

134 

135 

136######### 

137# Other support classes

138######### 

139 

class SourceFlagChecker(object):
    """Callable that reports whether a diaSource is free of "bad" flags.

    Parameters
    ----------
    sources : `lsst.afw.table.SourceCatalog`
        Catalog whose schema and table are used to resolve the flag keys.
    badFlags : iterable of `str`, optional
        Additional flag field names to treat as bad. They are checked on top
        of the default pixel flags and the table's centroid-slot flag.
    """

    def __init__(self, sources, badFlags=None):
        extraFlags = [] if badFlags is None else list(badFlags)
        self.badFlags = [
            "base_PixelFlags_flag_edge",
            "base_PixelFlags_flag_nodata",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_saturatedCenter",
        ] + extraFlags
        schema = sources.getSchema()
        self.keys = [schema.find(flagName).key for flagName in self.badFlags]
        # The centroid slot's failure flag is always checked as well.
        self.keys.append(sources.table.getCentroidFlagKey())

    def __call__(self, source):
        """Return whether ``source`` has none of the checked flags set.

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            Source that will be examined.

        Returns
        -------
        good : `bool`
            False as soon as any checked flag is set, True otherwise.
        """
        return not any(source.get(flagKey) for flagKey in self.keys)

168 

169 

class DipoleAnalysis(object):
    """Functor class that provides (S/N, position, orientation) of measured dipoles.

    All quantities are derived from the ``ip_diffim_PsfDipoleFlux`` fields
    of a source record.
    """

    def __init__(self):
        pass

    def __call__(self, source):
        """Parse information returned from dipole measurement

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            The source that will be examined

        Returns
        -------
        result : `tuple`
            ``(getSn(source), getCentroid(source), getOrientation(source))``.
        """
        return self.getSn(source), self.getCentroid(source), self.getOrientation(source)

    @staticmethod
    def _getLobeCentroids(source):
        """Return ``(negX, negY, posX, posY)`` of the two lobe centroids.

        Returns None if any coordinate is not finite. Failed or unmeasured
        centroids may be NaN as well as inf, so both are treated as missing.
        """
        coords = (
            source.get("ip_diffim_PsfDipoleFlux_neg_centroid_x"),
            source.get("ip_diffim_PsfDipoleFlux_neg_centroid_y"),
            source.get("ip_diffim_PsfDipoleFlux_pos_centroid_x"),
            source.get("ip_diffim_PsfDipoleFlux_pos_centroid_y"),
        )
        if not all(np.isfinite(c) for c in coords):
            return None
        return coords

    def getSn(self, source):
        """Get the total signal-to-noise of the dipole; total S/N is from positive and negative lobe

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            The source that will be examined

        Returns
        -------
        sn : `float`
            ``sqrt((posFlux/posErr)**2 + (negFlux/negErr)**2)``, or 0 when
            both lobe fluxes have the same sign (not a dipole).
        """
        posflux = source.get("ip_diffim_PsfDipoleFlux_pos_instFlux")
        posfluxErr = source.get("ip_diffim_PsfDipoleFlux_pos_instFluxErr")
        negflux = source.get("ip_diffim_PsfDipoleFlux_neg_instFlux")
        negfluxErr = source.get("ip_diffim_PsfDipoleFlux_neg_instFluxErr")

        # Not a dipole: both lobes have the same sign. Use ``==`` rather than
        # ``is``. An identity test gives the wrong answer when one comparison
        # yields a Python bool and the other a numpy bool.
        if (posflux < 0) == (negflux < 0):
            return 0

        return np.sqrt((posflux/posfluxErr)**2 + (negflux/negfluxErr)**2)

    def getCentroid(self, source):
        """Get the centroid of the dipole; average of positive and negative lobe

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            The source that will be examined

        Returns
        -------
        center : `lsst.geom.Point2D` or None
            Midpoint of the two lobe centroids; None if any lobe coordinate
            is not finite.
        """
        lobes = self._getLobeCentroids(source)
        if lobes is None:
            return None
        negCenX, negCenY, posCenX, posCenY = lobes

        center = geom.Point2D(0.5*(negCenX+posCenX),
                              0.5*(negCenY+posCenY))
        return center

    def getOrientation(self, source):
        """Calculate the orientation of dipole; vector from negative to positive lobe

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            The source that will be examined

        Returns
        -------
        angle : `lsst.geom.Angle` or None
            ``arctan2(dx, dy)`` of the neg->pos vector; None if any lobe
            coordinate is not finite.
        """
        lobes = self._getLobeCentroids(source)
        if lobes is None:
            return None
        negCenX, negCenY, posCenX, posCenY = lobes

        dx, dy = posCenX-negCenX, posCenY-negCenY
        # NOTE(review): np.arctan2 takes (y, x). Passing (dx, dy) measures the
        # angle from +y toward +x rather than counter-clockwise from +x. This
        # is kept unchanged because callers may rely on it; confirm the
        # intended convention.
        angle = geom.Angle(np.arctan2(dx, dy), geom.radians)
        return angle

    def displayDipoles(self, exposure, sources):
        """Display debugging information on the detected dipoles

        The exposure is always shown. Source markers are drawn only when the
        lsstDebug ``display`` and ``displayDiaSources`` options are both set:
        green for dipoles, red otherwise, and a yellow neg->pos lobe line.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Image the dipoles were measured on
        sources : `lsst.afw.table.SourceCatalog`
            The set of diaSources that were measured
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display
        displayDiaSources = lsstDebug.Info(__name__).displayDiaSources
        maskTransparency = lsstDebug.Info(__name__).maskTransparency
        if not maskTransparency:
            maskTransparency = 90
        disp = afwDisplay.Display(frame=lsstDebug.frame)
        disp.setMaskTransparency(maskTransparency)
        disp.mtv(exposure)

        if display and displayDiaSources:
            with disp.Buffering():
                for source in sources:
                    # The centroid slot is "ip_diffim_PsfDipoleFlux" (see
                    # DipoleMeasurementConfig), so its fields are <name>_x/_y.
                    # The previous "ipdiffim_DipolePsfFlux_*" names do not exist.
                    cenX = source.get("ip_diffim_PsfDipoleFlux_x")
                    cenY = source.get("ip_diffim_PsfDipoleFlux_y")
                    if not (np.isfinite(cenX) and np.isfinite(cenY)):
                        cenX, cenY = source.getCentroid()

                    isdipole = source.get("ip_diffim_ClassificationDipole_value")
                    if isdipole and np.isfinite(isdipole):
                        # Dipole
                        ctype = afwDisplay.GREEN
                    else:
                        # Not dipole
                        ctype = afwDisplay.RED

                    disp.dot("o", cenX, cenY, size=2, ctype=ctype)

                    lobes = self._getLobeCentroids(source)
                    if lobes is None:
                        continue
                    negCenX, negCenY, posCenX, posCenY = lobes

                    disp.line([(negCenX, negCenY), (posCenX, posCenY)], ctype=afwDisplay.YELLOW)

            lsstDebug.frame += 1

290 

291 

class DipoleDeblender(object):
    """Functor to deblend a source as a dipole, and return a new source with deblended footprints.

    This necessarily overrides some of the functionality from
    meas_algorithms/python/lsst/meas/algorithms/deblend.py since we
    need a single source that contains the blended peaks, not
    multiple children sources. This directly calls the core
    deblending code deblendBaseline.deblend (optionally _fitPsf for
    debugging).

    Not actively being used, but there is a unit test for it in
    dipoleAlgorithm.py.
    """

    def __init__(self):
        # Set up defaults to send to deblender

        # Always deblend as Psf: infinite chi^2 cuts never reject the PSF model.
        self.psfChisqCut1 = self.psfChisqCut2 = self.psfChisqCut2b = np.inf
        self.log = getLogger('lsst.ip.diffim.DipoleDeblender')
        # Gaussian sigma -> FWHM conversion factor, 2*sqrt(2*ln 2).
        self.sigma2fwhm = 2. * np.sqrt(2. * np.log(2.))

    def __call__(self, source, exposure):
        """Deblend ``source`` into its brightest and faintest peaks.

        Parameters
        ----------
        source : `lsst.afw.table.SourceRecord`
            Source whose footprint peaks are deblended. Its footprint's peak
            list is trimmed to the two selected peaks in place.
        exposure : `lsst.afw.image.Exposure`
            Image (with PSF) the source was detected on.

        Returns
        -------
        deblendedSource : `lsst.afw.table.SourceRecord`
            A copy of ``source``. If the footprint has fewer than two peaks,
            it is a plain copy. Otherwise its parent is set to
            ``source.getId()`` and its footprint peaks are replaced by the
            deblended peaks.
            NOTE(review): ``copyRecord`` may share the footprint object with
            ``source``; if so the peak edits are visible through both, so
            confirm before relying on ``source``'s peaks afterwards.
        """
        fp = source.getFootprint()
        peaks = fp.getPeaks()
        peaksF = [pk.getF() for pk in peaks]
        fbb = fp.getBBox()
        fmask = afwImage.Mask(fbb)
        fmask.setXY0(fbb.getMinX(), fbb.getMinY())
        fp.spans.setMask(fmask, 1)

        psf = exposure.getPsf()
        psfSigPix = psf.computeShape(psf.getAveragePosition()).getDeterminantRadius()
        psfFwhmPix = psfSigPix * self.sigma2fwhm
        # fmask, subimage and cpsf are only used by the disabled _fitPsf branch below.
        subimage = afwImage.ExposureF(exposure, bbox=fbb, deep=True)
        cpsf = deblendBaseline.CachingPsf(psf)

        # if fewer than 2 peaks, just return a copy of the source
        if len(peaks) < 2:
            return source.getTable().copyRecord(source)

        # Make sure we only deblend 2 peaks: take the brightest and faintest.
        # Sort on the peak value alone. Sorting (value, peak) tuples falls back
        # to comparing the peak records when two values tie, which raises
        # TypeError if the records define no ordering.
        speaks = sorted(peaks, key=lambda p: p.getPeakValue())
        dpeaks = [speaks[0], speaks[-1]]

        # and only set these peaks in the footprint (peaks is mutable)
        peaks.clear()
        for peak in dpeaks:
            peaks.append(peak)

        if True:
            # Call top-level deblend task
            fpres = deblendBaseline.deblend(fp, exposure.getMaskedImage(), psf, psfFwhmPix,
                                            log=self.log,
                                            psfChisqCut1=self.psfChisqCut1,
                                            psfChisqCut2=self.psfChisqCut2,
                                            psfChisqCut2b=self.psfChisqCut2b)
        else:
            # Debugging path, disabled by ``if True`` above: call the
            # lower-level _fitPsf directly.

            # Prepare results structure
            fpres = deblendBaseline.DeblenderResult(fp, exposure.getMaskedImage(), psf, psfFwhmPix, self.log)

            for pki, (pk, pkres, pkF) in enumerate(zip(dpeaks, fpres.deblendedParents[0].peaks, peaksF)):
                self.log.debug('Peak %i', pki)
                deblendBaseline._fitPsf(fp, fmask, pk, pkF, pkres, fbb, dpeaks, peaksF, self.log,
                                        cpsf, psfFwhmPix,
                                        subimage.image,
                                        subimage.variance,
                                        self.psfChisqCut1, self.psfChisqCut2, self.psfChisqCut2b)

        deblendedSource = source.getTable().copyRecord(source)
        deblendedSource.setParent(source.getId())
        peakList = deblendedSource.getFootprint().getPeaks()
        peakList.clear()

        for peak in fpres.deblendedParents[0].peaks:
            # Positive fitted flux marks the positive lobe.
            if peak.psfFitFlux > 0:
                suffix = "pos"
            else:
                suffix = "neg"
            c = peak.psfFitCenter
            self.log.info("deblended.centroid.dipole.psf.%s %f %f",
                          suffix, c[0], c[1])
            self.log.info("deblended.chi2dof.dipole.%s %f",
                          suffix, peak.psfFitChisq / peak.psfFitDof)
            self.log.info("deblended.flux.dipole.psf.%s %f",
                          suffix, peak.psfFitFlux * np.sum(peak.templateImage.array))
            peakList.append(peak.peak)
        return deblendedSource