Coverage for python/lsst/ip/isr/fringe.py: 20%

217 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-16 02:30 -0700

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy 

23 

24import lsst.geom 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.afw.display as afwDisplay 

28 

29from lsst.pipe.base import Task, Struct 

30from lsst.pex.config import Config, Field, ListField, ConfigField 

31from lsst.utils.timer import timeMethod 

32from .isrFunctions import checkFilter 

33 

# Import-time side effect: set the default mask-plane transparency used by
# afwDisplay for any debug displays produced by this module (the value 75 is
# presumably a 0-100 percentage — confirm against afwDisplay docs).
afwDisplay.setDefaultMaskTransparency(75)

35 

36 

def getFrame():
    """Return a fresh display frame number.

    Each call bumps the counter stored on the ``getFrame.frame`` function
    attribute and returns the new value, so successive debug displays land
    in distinct frames.
    """
    nextFrame = getFrame.frame + 1
    getFrame.frame = nextFrame
    return nextFrame


# Counter backing getFrame(); the first call returns 1.
getFrame.frame = 0

44 

45 

class FringeStatisticsConfig(Config):
    """Configuration of the statistics used to measure fringe amplitudes
    on an exposure.
    """
    # Mask planes whose pixels are excluded from the aperture statistics.
    badMaskPlanes = ListField(
        doc="Ignore pixels with these masks",
        dtype=str,
        default=["SAT"],
    )
    # afwMath statistic flag passed to makeStatistics for each aperture.
    stat = Field(
        doc="Statistic to use",
        dtype=int,
        default=int(afwMath.MEDIAN),
    )
    clip = Field(
        doc="Sigma clip threshold",
        dtype=float,
        default=3.0,
    )
    iterations = Field(
        doc="Number of fitting iterations",
        dtype=int,
        default=3,
    )
    # Added to the exposure id (when one is supplied) to form the RNG seed.
    rngSeedOffset = Field(
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
        dtype=int,
        default=0,
    )

54 

55 

class FringeConfig(Config):
    """Fringe subtraction options"""
    # TODO DM-28093: change the doc to specify that these are physical labels
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             # Note the trailing space: the two literals are
                             # concatenated into a single message.
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    # Number of random positions at which fringe amplitudes are sampled.
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")

71 

72 

class FringeTask(Task):
    """Task to remove fringes from a science exposure

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def readFringes(self, dataRef, expId=None, assembler=None):
        """Read the fringe frame(s), and pack data into a Struct

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        Parameters
        ----------
        dataRef : `daf.butler.butlerSubset.ButlerDataRef`
            Butler reference for the exposure that will have fringing
            removed.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.

        Raises
        ------
        RuntimeError
            Raised if the fringe cannot be retrieved from the butler; the
            original exception is chained as the cause.
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) from e

        return self.loadFringes(fringe, expId=expId, assembler=assembler)

    def loadFringes(self, fringeExp, expId=None, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.exposure.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        seed = self.config.stats.rngSeedOffset
        if expId is not None:
            seed += expId

        # Seed for numpy.random.RandomState must be convertible to a 32 bit
        # unsigned integer.
        seed %= 2**32
        self.log.debug("Fringe RNG seed offset %s, exposure id %s -> seed %s",
                       self.config.stats.rngSeedOffset, expId, seed)

        return Struct(fringes=fringeExp,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Note that the fringe frames are modified in place: the science mask
        is OR-ed into their masks, the pedestal is optionally removed, and
        `subtract` clears their masks.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        seed : `int`, optional
            Seed for random number generation.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        Returns `None` (without modifying anything) when the exposure's
        filter is not in ``config.filters``.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        # Copy the science mask bits onto each fringe frame, so pixels flagged
        # with a bad plane on the science exposure are also ignored when the
        # fringe frame is measured.
        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        positions = self.generatePositions(fringes[0], rng)
        # One column of measured amplitudes per fringe frame.
        fluxes = numpy.empty((self.config.num, len(fringes)))
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference to the exposure. Used to find
            appropriate fringe data.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.
        """
        # Check before reading so that no fringe I/O happens for filters
        # that are not fringe-corrected.
        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        return self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        return checkFilter(exposure, self.config.filters, log=self.log)

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the sigma-clipped median of the fringe masked image,
        subtracted in place.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe
        amplitudes.

        Positions are drawn uniformly with a margin of ``config.large``
        pixels from the low edges and at least that many from the high
        edges, so the large (background) aperture stays on the image.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.array`
            Array of shape ``(config.num, 2)`` holding the ``(x, y)``
            pixel positions to sample for fringe amplitudes.

        Raises
        ------
        ValueError
            Raised if the exposure is too small to fit the background
            aperture.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        if width <= start or height <= start:
            raise ValueError("Exposure of size %dx%d is too small for fringe apertures of half-size %d" %
                             (exposure.getWidth(), exposure.getHeight(), self.config.large))
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure

        The fringe amplitudes are measured as the statistic within a square
        aperture. The statistic within a larger aperture are subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.array`
            Two-dimensional array containing the positions to sample
            for fringe amplitudes.
        title : `str`, optional
            Title used for debug out plots.

        Returns
        -------
        fringes : `numpy.array`
            Array of measured exposure values at each of the positions
            supplied.
        """
        maskedImage = exposure.getMaskedImage()
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(maskedImage.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        # Hoist config lookups out of the (typically 30000-iteration) loop.
        num = self.config.num
        smallSize = self.config.small
        largeSize = self.config.large
        statistic = self.config.stats.stat
        fringes = numpy.empty(num)

        for i in range(num):
            x, y = positions[i]
            small = measure(maskedImage, x, y, smallSize, statistic, stats)
            large = measure(maskedImage, x, y, largeSize, statistic, stats)
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Two-dimensional array of measured fringe values, one row per
            position and one column per fringe frame.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame. All
            zeros (one per fringe frame) if no good positions remain.
        rms : `float`
            RMS error for the fit solution for this exposure (NaN if no
            good positions remain).
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Capture the number of fringe frames now: ``fringes`` is filtered
        # row-wise below, and the empty result must have one entry per frame
        # so that `subtract` accepts it.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user

            There are no good pixels; the zero amplitudes make the
            subsequent subtraction a no-op.
            """
            self.log.warning("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes), numpy.nan

        # Keep only positions where the science value AND every fringe value
        # are finite; a NaN in any column would poison the percentiles and
        # the least-squares fit.
        good = numpy.where(numpy.logical_and(numpy.isfinite(science),
                                             numpy.all(numpy.isfinite(fringes), 1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome
        # values (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFringes):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        # Stays None only if no clipping iterations are configured.
        rms = None
        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if rms is None:
            # No clipping iterations ran; report the scatter of the final fit.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _plotSolution(self, science, fringes, solution):
        """Interactively plot the current fit, one figure per fringe frame.

        Debugging aid used by `solve` when ``lsstDebug`` ``plot`` is set;
        blocks on terminal input after each figure.
        """
        import matplotlib.pyplot as plot
        for j in range(fringes.shape[1]):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best effort: only works with a Tk backend.
                pass
            ax = fig.add_subplot(1, 1, 1)
            # Remove the contribution of all other fringe frames so the
            # science amplitude can be plotted against this frame alone.
            adjust = science.copy()
            others = set(range(fringes.shape[1]))
            others.discard(j)
            for k in others:
                adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
            while True:
                ans = input("Enter or c to continue [chp]").lower()
                if ans in ("", "c",):
                    break
                if ans in ("p",):
                    import pdb
                    pdb.set_trace()
                elif ans in ("h", ):
                    print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied (used as the design matrix).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError :
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            # We do not want to add the mask from the fringe to the image.
            # (This clears the fringe frame's mask in place.)
            f.getMaskedImage().getMask().getArray()[:] = 0
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())

517 

518 

def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Masked image to measure.
    x, y : `int` or `float`
        Position the aperture is built around; each coordinate is truncated
        to `int` the same way before the offset is applied.
    size : `int`
        Half-size of the aperture in pixels. The box starts at
        ``(int(x) - size, int(y) - size)`` and is ``2*size`` pixels on a side.
    statistic : `int`
        afwMath statistic flag to compute (e.g. ``afwMath.MEDIAN``).
    stats : `lsst.afw.math.StatisticsControl`
        Statistics control object (clipping, masking).

    Returns
    -------
    value : `float`
        Value of the statistic within the aperture.
    """
    # Truncate both coordinates identically before offsetting; previously
    # y was offset before truncation, which differs for fractional input.
    # NOTE(review): an even-width (2*size) box is not exactly centred on
    # (x, y); kept as-is to preserve existing measurements.
    x0 = int(x) - size
    y0 = int(y) - size
    bbox = lsst.geom.Box2I(lsst.geom.Point2I(x0, y0),
                           lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()

533 

534 

def stdev(vector):
    """Estimate the standard deviation robustly from the interquartile range.

    Parameters
    ----------
    vector : array-like
        Values to characterise.

    Returns
    -------
    sigma : `float`
        ``0.74`` times the interquartile range (for normally distributed
        data IQR/1.349 equals the standard deviation).
    """
    lowerQuartile, upperQuartile = numpy.percentile(vector, [25, 75])
    interquartileRange = upperQuartile - lowerQuartile
    return 0.74*interquartileRange

543 

544 

def select(vector, clip):
    """Select values within ``clip`` robust standard deviations of the median.

    The robust standard deviation is ``0.74`` times the interquartile range.
    The bound is inclusive, matching the ``not (abs(resid) > clip*rms)``
    clipping in `FringeTask.solve`. This means that data with zero
    interquartile range (e.g. constant values) are kept rather than all
    rejected.

    Parameters
    ----------
    vector : `numpy.ndarray`
        Values to test.
    clip : `float`
        Clipping threshold in units of the robust standard deviation.

    Returns
    -------
    selected : `numpy.ndarray` of `bool`
        True for values within the bound; NaN entries are never selected.
    """
    q1, q2, q3 = numpy.percentile(vector, (25, 50, 75))
    return numpy.abs(vector - q2) <= clip*0.74*(q3 - q1)