Coverage for python/lsst/ip/isr/fringe.py: 20%

204 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-02 02:36 -0700

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy 

23 

24import lsst.geom 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.afw.display as afwDisplay 

28 

29from lsst.pipe.base import Task, Struct 

30from lsst.pex.config import Config, Field, ListField, ConfigField 

31from lsst.utils.timer import timeMethod 

32from .isrFunctions import checkFilter 

33 

# Module-level display setting: default mask-overlay transparency (75) used
# by the lsstDebug displays produced in this module.
afwDisplay.setDefaultMaskTransparency(75)

35 

36 

def getFrame():
    """Return a new display frame number on every call.

    The running count lives on the ``frame`` attribute of this function.
    It starts at 0, so the first call returns 1 and each subsequent call
    returns one more than the previous.
    """
    getFrame.frame = getFrame.frame + 1
    return getFrame.frame


getFrame.frame = 0

44 

45 

class FringeStatisticsConfig(Config):
    """Options controlling the statistics used to measure fringes on an
    exposure.
    """
    # Mask planes whose pixels are excluded from the aperture statistics.
    badMaskPlanes = ListField(
        dtype=str,
        default=["SAT"],
        doc="Ignore pixels with these masks",
    )
    # afwMath statistic property, stored as its integer value.
    stat = Field(
        dtype=int,
        default=int(afwMath.MEDIAN),
        doc="Statistic to use",
    )
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    iterations = Field(
        dtype=int,
        default=3,
        doc="Number of fitting iterations",
    )
    # Added to the exposure id (when one is available) to form the RNG seed.
    rngSeedOffset = Field(
        dtype=int,
        default=0,
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
    )

54 

55 

class FringeConfig(Config):
    """Fringe subtraction options."""
    # TODO DM-28093: change the doc to specify that these are physical labels
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    # The two message fragments are implicitly concatenated, so the first one
    # must end with a space to keep the sentences separated.
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")

71 

72 

class FringeTask(Task):
    """Task to remove fringes from a science exposure.

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def loadFringes(self, fringeExp, expId=None, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.exposure.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:

            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`
                Seed for random number generation: ``rngSeedOffset`` plus
                ``expId`` (if given), reduced modulo ``2**32``.
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        seed = self.config.stats.rngSeedOffset
        if expId is not None:
            # Diagnostic output goes through the task logger, not stdout.
            self.log.debug("Fringe RNG seed offset: %s; exposure id: %s",
                           self.config.stats.rngSeedOffset, expId)
            seed += expId

        # Seed for numpy.random.RandomState must be convertible to a 32 bit
        # unsigned integer.
        seed %= 2**32

        return Struct(fringes=fringeExp, seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask.  Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes; modified in place.
        fringes : `lsst.afw.image.Exposure` or iterable thereof
            Calibration fringe files containing master fringe frames.  These
            are modified in place: their masks are OR-ed with the science
            mask (and later cleared by `subtract`), and their pixels have a
            pedestal removed when ``config.pedestal`` is set.
        seed : `int`, optional
            Seed for random number generation.  Defaults to
            ``config.stats.rngSeedOffset``.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        If the exposure's filter is not in ``config.filters``, nothing is
        done and `None` is returned instead.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        # Accept a single exposure or any iterable of them; a list is needed
        # below for indexing and len().
        fringes = list(fringes) if hasattr(fringes, '__iter__') else [fringes]

        # Pixels masked on the science exposure are also masked on each
        # fringe frame before measuring.
        scienceMask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= scienceMask
            if self.config.pedestal:
                self.removePedestal(fringe)

        positions = self.generatePositions(fringes[0], rng)
        # One row per position, one column per fringe frame.
        fluxes = numpy.empty((len(positions), len(fringes)))
        for i, fringe in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(fringe, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        return checkFilter(exposure, self.config.filters, log=self.log)

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the sigma-clipped MEDIAN of the whole masked image
        (independent of ``config.stats.stat``), using the clip and iteration
        settings from ``config.stats``.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from; modified in place.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe
        amplitudes.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.array`
            Integer array of shape ``(config.num, 2)``; each row is an
            ``(x, y)`` position with ``x`` in ``[large, width - large)`` and
            ``y`` in ``[large, height - large)``, so the large aperture stays
            inside the image.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        # x positions are drawn before y positions; keep this order so results
        # are reproducible for a given seed.
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure.

        The fringe amplitudes are measured as the statistic within a square
        aperture.  The statistic within a larger aperture is subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.array`
            Two-dimensional array of ``(x, y)`` rows giving the positions
            to sample for fringe amplitudes.
        title : `str`, optional
            Title used for debug out plots.

        Returns
        -------
        fringes : `numpy.array`
            Array of measured exposure values, one per row of ``positions``.
        """
        mi = exposure.getMaskedImage()
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(mi.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        statistic = self.config.stats.stat
        fringes = numpy.empty(len(positions))
        for i, (x, y) in enumerate(positions):
            small = measure(mi, x, y, self.config.small, statistic, stats)
            large = measure(mi, x, y, self.config.large, statistic, stats)
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            # Aperture overlay is disabled; change to True to draw the small
            # (green) and large (blue) apertures around every position.
            if False:
                unitBox = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        # Scale the unit box by the aperture size, then shift
                        # it to the position (not the other way around).
                        disp.line(unitBox*self.config.small + [[x, y]], ctype=afwDisplay.GREEN)
                        disp.line(unitBox*self.config.large + [[x, y]], ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Two-dimensional array of measured fringe values: one row per
            position, one column per fringe frame.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.  All
            zeros (one per fringe frame) if no usable positions remain.
        rms : `float`
            RMS error for the fit solution for this exposure (NaN if no
            usable positions remain).
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Capture the number of fringe frames now: ``fringes`` is rebound to
        # ever smaller row subsets below, but its column count is fixed.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user.

            There are no good pixels; return zero amplitudes (one per fringe
            frame, so `subtract` still accepts the result) and a NaN RMS.
            """
            self.log.warning("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes), numpy.nan

        # A position is only usable if the science value and every fringe
        # value there are finite; a single NaN in a fringe column would make
        # the percentiles computed by select() NaN and reject everything.
        good = numpy.where(numpy.logical_and(numpy.isfinite(science),
                                             numpy.all(numpy.isfinite(fringes), axis=1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome
        # values (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFringes):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        rms = None
        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            # Written as "not greater than" so that NaN residuals are kept.
            good = numpy.logical_not(numpy.abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if rms is None:
            # No rejection iterations ran (config.iterations <= 0); report
            # the scatter of the final fit instead of failing on an unbound
            # name.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _plotSolution(self, science, fringes, solution):
        """Plot the current fit for each fringe frame, then prompt the user.

        Debugging aid used by `solve` when the lsstDebug ``plot`` flag is set.
        For each fringe frame, plots the science values (with the other
        frames' fitted contributions removed) against that frame's values,
        and overlays the fitted line.  Then blocks on interactive input.

        Parameters
        ----------
        science : `numpy.array`
            Science values at the currently accepted positions.
        fringes : `numpy.array`
            Fringe values (positions x fringe frames) at the same positions.
        solution : `numpy.array`
            Current fringe amplitudes, one per fringe frame.
        """
        import matplotlib.pyplot as plot
        numFringes = fringes.shape[1]
        for j in range(numFringes):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best effort: raising the window only works on some backends.
                pass
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            for k in range(numFringes):
                if k != j:
                    adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
        while True:
            ans = input("Enter or c to continue [chp]").lower()
            if ans in ("", "c",):
                break
            if ans in ("p",):
                import pdb
                pdb.set_trace()
            elif ans in ("h", ):
                print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Design matrix of measured fringe values: one row per position,
            one column per fringe frame.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes; modified in place.
        fringes : `lsst.afw.image.Exposure` or iterable thereof
            Calibration fringe files containing master fringe frames.  Their
            masks are cleared in place before subtraction.
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError :
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        # Accept a single exposure or any iterable, as `run` does.
        fringes = list(fringes) if hasattr(fringes, '__iter__') else [fringes]
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            # We do not want to add the mask from the fringe to the image.
            f.getMaskedImage().getMask().getArray()[:] = 0
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())

449 

450 

def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Masked image to measure.
    x, y : `int` or `float`
        Center of the aperture in the local pixel coordinates of ``mi``;
        each is truncated with ``int()``.
    size : `int`
        Half-size of the aperture.  The box starts at
        ``(int(x) - size, int(y) - size)`` and is ``2*size`` pixels on a side.
    statistic : `int`
        ``lsst.afw.math`` statistic property to compute (e.g. ``MEDIAN``).
    stats : `lsst.afw.math.StatisticsControl`
        Controls for computing the statistic.

    Returns
    -------
    value : `float`
        Value of the statistic within the aperture.
    """
    # Truncate both coordinates the same way before offsetting, so the box
    # origin is computed consistently along x and y.
    x0 = int(x) - size
    y0 = int(y) - size
    bbox = lsst.geom.Box2I(lsst.geom.Point2I(x0, y0),
                           lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()

465 

466 

def stdev(vector):
    """Estimate the standard deviation robustly from the interquartile range.

    Parameters
    ----------
    vector : array-like
        Values to measure.

    Returns
    -------
    sigma : `float`
        ``0.74`` times the distance between the 25th and 75th percentiles
        (for Gaussian data this approximates the standard deviation).
    """
    lower, upper = numpy.percentile(vector, (25, 75))
    return 0.74*(upper - lower)

475 

476 

def select(vector, clip):
    """Flag values lying within ``clip`` robust standard deviations of the
    median.

    The robust standard deviation is ``0.74`` times the interquartile range.

    Parameters
    ----------
    vector : array-like
        Values to test.
    clip : `float`
        Rejection threshold, in robust standard deviations.

    Returns
    -------
    keep : `numpy.ndarray` of `bool`
        True where ``|value - median|`` is strictly less than the threshold.
    """
    lower, median, upper = numpy.percentile(vector, (25, 50, 75))
    threshold = clip*0.74*(upper - lower)
    distance = numpy.abs(vector - median)
    return distance < threshold