Coverage for python/lsst/ip/isr/fringe.py: 18%

205 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-10 02:40 -0800

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["FringeStatisticsConfig", "FringeConfig", "FringeTask"] 

23 

24import numpy 

25 

26import lsst.geom 

27import lsst.afw.image as afwImage 

28import lsst.afw.math as afwMath 

29import lsst.afw.display as afwDisplay 

30 

31from lsst.pipe.base import Task, Struct 

32from lsst.pex.config import Config, Field, ListField, ConfigField 

33from lsst.utils.timer import timeMethod 

34from .isrFunctions import checkFilter 

35 

# Import-time side effect: set afwDisplay's default mask transparency to 75
# for the debug displays opened by this module (this is a module-level call on
# afwDisplay, so presumably it also affects other afwDisplay users in the
# process — confirm if that matters).
afwDisplay.setDefaultMaskTransparency(75)

37 

38 

def getFrame():
    """Return the next debug-display frame number.

    The counter lives on the function object itself (``getFrame.frame``),
    is incremented on every call, and the new value is returned.
    """
    nextFrame = getFrame.frame + 1
    getFrame.frame = nextFrame
    return nextFrame


# Frame counter; the first call to getFrame() returns 1.
getFrame.frame = 0

46 

47 

class FringeStatisticsConfig(Config):
    """Options controlling how fringe amplitudes are measured on an exposure."""

    # Pixels with any of these mask planes set are excluded (AND-mask) from
    # the aperture statistics.
    badMaskPlanes = ListField(
        doc="Ignore pixels with these masks",
        dtype=str,
        default=["SAT"],
    )
    # Statistic (an ``lsst.afw.math`` statistic flag) computed in each aperture.
    stat = Field(
        doc="Statistic to use",
        dtype=int,
        default=int(afwMath.MEDIAN),
    )
    # Sigma-clipping threshold handed to the StatisticsControl.
    clip = Field(
        doc="Sigma clip threshold",
        dtype=float,
        default=3.0,
    )
    # Number of clipping iterations handed to the StatisticsControl.
    iterations = Field(
        doc="Number of fitting iterations",
        dtype=int,
        default=3,
    )
    # Added to the exposure ID (when one is known) to form the RNG seed.
    rngSeedOffset = Field(
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
        dtype=int,
        default=0,
    )

56 

57 

class FringeConfig(Config):
    """Fringe subtraction options."""

    # TODO DM-28093: change the doc to specify that these are physical labels
    # Fringe subtraction is skipped unless the exposure's filter is listed here.
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    # The two message fragments are concatenated implicitly, so the first one
    # must end with a space.
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    # Number of random positions at which fringe amplitudes are measured.
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    # Half-sizes of the small (signal) and large (local background) apertures.
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    # Maximum number of clip-and-refit iterations when solving for the scales.
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    # Outlier-rejection threshold, in robust standard deviations.
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    # If True, subtract each fringe frame's sigma-clipped median before fitting.
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")

73 

74 

class FringeTask(Task):
    """Task to remove fringes from a science exposure.

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def loadFringes(self, fringeExp, expId=None, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.image.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:

            ``fringes``
                Calibration fringe files containing master fringe frames.
                ( : `lsst.afw.image.Exposure` or `list` thereof)
            ``seed``
                Seed for random number generation: the configured offset
                plus ``expId`` (if given), reduced modulo ``2**32``.
                (`int`)
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        if expId is None:
            seed = self.config.stats.rngSeedOffset
        else:
            # Report through the task logger instead of printing to stdout.
            self.log.debug("Fringe RNG seed offset: %s; exposure ID: %s",
                           self.config.stats.rngSeedOffset, expId)
            seed = self.config.stats.rngSeedOffset + expId

        # Seed for numpy.random.RandomState must be convertible to a 32 bit
        # unsigned integer.
        seed %= 2**32

        return Struct(fringes=fringeExp,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask.  Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        The science exposure is modified in place.  The fringe frames are
        modified too: the science mask is OR-ed into each fringe mask, the
        pedestal is removed if ``config.pedestal`` is set, and `subtract`
        clears each fringe mask.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        seed : `int`, optional
            Seed for random number generation; defaults to
            ``config.stats.rngSeedOffset``.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        If the exposure's filter is not in ``config.filters``, nothing is
        done and `None` is returned instead of the tuple.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        # Share the science mask with every fringe frame so the same pixels
        # are flagged when the fringe frames are measured.
        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        # All frames are measured at the same random positions, drawn using
        # the dimensions of the first fringe frame.
        positions = self.generatePositions(fringes[0], rng)
        fluxes = numpy.ndarray([self.config.num, len(fringes)])
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        # Delegates to the module-level ``checkFilter`` from isrFunctions.
        return checkFilter(exposure, self.config.filters, log=self.log)

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the sigma-clipped median of the whole masked image,
        using ``config.stats.clip`` and ``config.stats.iterations``.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from (modified in
            place).
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe
        amplitudes.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.array`
            Array of shape ``(config.num, 2)`` holding ``(x, y)`` integer
            positions, in local pixel coordinates, at least ``config.large``
            pixels from the start of each axis and strictly less than
            ``dimension - config.large``.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure.

        The fringe amplitudes are measured as the statistic within a square
        aperture. The statistic within a larger aperture are subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.array`
            Two-dimensional array containing the positions to sample
            for fringe amplitudes; rows are ``(x, y)``.
        title : `str`, optional
            Title used for debug out plots.

        Returns
        -------
        fringes : `numpy.array`
            Array of length ``config.num``: small-aperture statistic minus
            large-aperture statistic at each of the positions supplied.
        """
        maskedImage = exposure.getMaskedImage()
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(maskedImage.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        # Hoist loop-invariant configuration lookups.
        num = self.config.num
        small = self.config.small
        large = self.config.large
        statistic = self.config.stats.stat
        fringes = numpy.ndarray(num)

        for i in range(num):
            x, y = positions[i]
            smallValue = measure(maskedImage, x, y, small, statistic, stats)
            largeValue = measure(maskedImage, x, y, large, statistic, stats)
            fringes[i] = smallValue - largeValue

        import lsstDebug
        debugInfo = lsstDebug.Info(__name__)
        if debugInfo.display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            # Optional overlay of the measurement apertures; enable it by
            # setting ``displayPositions`` in the lsstDebug configuration.
            if debugInfo.displayPositions:
                corners = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        # Scale the unit square by the half-size, then shift
                        # it to the aperture centre.
                        disp.line(corners*small + [x, y], ctype=afwDisplay.GREEN)
                        disp.line(corners*large + [x, y], ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied; shape ``(num,)``.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied; shape ``(num, nFringes)``, one column per fringe frame.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.  If no
            good measurements remain, one zero per fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure (NaN if no good
            measurements remain).
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Number of fringe frames (columns).  This is recorded before any rows
        # are rejected so that a failed fit still yields one amplitude per
        # frame, which `subtract` requires.
        numFrames = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user.

            There are no good pixels; the amplitudes are all zero, one per
            fringe frame.
            """
            self.log.warning("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFrames), numpy.nan

        # Keep only positions where the science value and every fringe value
        # are finite.  A NaN left in any column would make that column's
        # percentiles NaN in ``select`` and reject every position.
        good = numpy.where(numpy.logical_and(numpy.isfinite(science), numpy.all(numpy.isfinite(fringes), 1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome
        # values (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFrames):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if self.config.iterations <= 0:
            # No clipping iterations ran, so ``rms`` was never computed;
            # evaluate it for the final solution instead.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _plotSolution(self, science, fringes, solution):
        """Interactively plot the current fit against each fringe frame.

        Debugging aid used by `solve` when lsstDebug ``plot`` is set.  For
        each frame it draws the science values (with the other frames' fitted
        contributions removed) against that frame's values, then waits for
        terminal input.
        """
        import matplotlib.pyplot as plot
        for j in range(fringes.shape[1]):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best effort: only possible with a Tk canvas.
                pass
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            others = set(range(fringes.shape[1]))
            others.discard(j)
            for k in others:
                adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
            while True:
                ans = input("Enter or c to continue [chp]").lower()
                if ans in ("", "c",):
                    break
                if ans in ("p",):
                    import pdb
                    pdb.set_trace()
                elif ans in ("h", ):
                    print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Linear least squares (direct SVD) with ``fringes`` as the design
        matrix and ``science`` as the data vector.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        ``science`` is modified in place, and each fringe frame's mask is
        zeroed before it is subtracted.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            # We do not want to add the mask from the fringe to the image.
            f.getMaskedImage().getMask().getArray()[:] = 0
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())

453 

454 

def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Masked image to measure.
    x, y : `int`
        Center of the aperture, in local (not parent) pixel coordinates.
    size : `int`
        Half-size of the aperture in pixels.  The box starts at
        ``(x - size, y - size)`` and spans ``2*size`` pixels on each axis.
    statistic : `int`
        Statistic to measure (an ``lsst.afw.math`` statistic flag).
    stats : `lsst.afw.math.StatisticsControl`
        Statistics control object.

    Returns
    -------
    value : `float`
        Value of the statistic within the aperture.
    """
    # Truncate each centre coordinate before offsetting so that x and y are
    # treated identically. Previously y was offset before truncation, which
    # rounds differently when y - size is negative or non-integral.
    x0 = int(x) - size
    y0 = int(y) - size
    bbox = lsst.geom.Box2I(lsst.geom.Point2I(x0, y0),
                           lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()

469 

470 

def stdev(vector):
    """Estimate a robust standard deviation from the interquartile range.

    The interquartile range (IQR) is scaled by 0.74, approximately
    ``1/1.349``, the IQR-to-sigma ratio of a Gaussian distribution.

    Parameters
    ----------
    vector : array-like
        Values to measure.

    Returns
    -------
    stdev : `float`
        Robust standard deviation estimate.
    """
    lower, upper = numpy.percentile(vector, (25, 75))
    interquartile = upper - lower
    return 0.74*interquartile

479 

480 

481def select(vector, clip): 

482 """Select values within 'clip' standard deviations of the median 

483 

484 Returns a boolean array. 

485 """ 

486 q1, q2, q3 = numpy.percentile(vector, (25, 50, 75)) 

487 return numpy.abs(vector - q2) < clip*0.74*(q3 - q1)