Coverage for python/lsst/ip/isr/fringe.py: 18%

214 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-12 03:09 -0700

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy 

23 

24import lsst.geom 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.afw.display as afwDisplay 

28 

29from lsst.pipe.base import Task, Struct, timeMethod 

30from lsst.pex.config import Config, Field, ListField, ConfigField 

31from .isrFunctions import checkFilter 

32 

# Default mask-overlay transparency for the debug displays created below
# (value passed straight to afwDisplay; presumably a percentage — confirm).
afwDisplay.setDefaultMaskTransparency(75)


def getFrame():
    """Return a new display frame number on every call.

    The running counter is stored on the function object itself as
    ``getFrame.frame``; each call bumps it by one and returns the new value,
    so successive calls yield 1, 2, 3, ...
    """
    frame = getFrame.frame + 1
    getFrame.frame = frame
    return frame


# Starting value of the counter used by getFrame().
getFrame.frame = 0

43 

44 

class FringeStatisticsConfig(Config):
    """Configuration of the per-aperture statistics used to measure fringes."""

    # Mask planes turned into the StatisticsControl "and" mask, so flagged
    # pixels are excluded from every aperture measurement.
    badMaskPlanes = ListField(
        dtype=str,
        default=["SAT"],
        doc="Ignore pixels with these masks",
    )
    # afwMath statistic property (stored as an int) evaluated in each aperture.
    stat = Field(
        dtype=int,
        default=int(afwMath.MEDIAN),
        doc="Statistic to use",
    )
    # Sigma-clipping threshold handed to StatisticsControl.
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    # Clipping iterations handed to StatisticsControl.
    iterations = Field(
        dtype=int,
        default=3,
        doc="Number of fitting iterations",
    )
    # Added to the exposure id (when one is supplied) to form the RNG seed.
    rngSeedOffset = Field(
        dtype=int,
        default=0,
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
    )

53 

54 

class FringeConfig(Config):
    """Fringe subtraction options."""
    # TODO DM-28093: change the doc to specify that these are physical labels
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    # The two message fragments are joined implicitly, so the first needs a
    # trailing space to keep the sentences separated.
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    # Number of random positions sampled on the science and fringe frames.
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    # Aperture half-sizes: the fringe amplitude at a position is the
    # statistic in the small box minus the statistic in the large box.
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    # Outlier-rejection settings for the amplitude fit (distinct from the
    # per-aperture clipping settings in ``stats``).
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")

70 

71 

class FringeTask(Task):
    """Task to remove fringes from a science exposure.

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def readFringes(self, dataRef, expId=None, assembler=None):
        """Read the fringe frame(s), and pack data into a Struct.

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        Parameters
        ----------
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference for the exposure that will have fringing
            removed.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.

        Raises
        ------
        RuntimeError
            Raised if the fringe cannot be retrieved via ``dataRef``.
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            # Chain the original failure so its traceback is preserved.
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) from e

        return self.loadFringes(fringe, expId=expId, assembler=assembler)

    def loadFringes(self, fringeExp, expId=None, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.image.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation; ``rngSeedOffset`` plus
                ``expId`` (when given), reduced modulo ``2**32``.
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        if expId is None:
            seed = self.config.stats.rngSeedOffset
        else:
            # Report the seed ingredients through the task logger rather
            # than writing to stdout.
            self.log.debug("Fringe RNG seed offset: %s; exposure id: %s",
                           self.config.stats.rngSeedOffset, expId)
            seed = self.config.stats.rngSeedOffset + expId

        # Seed for numpy.random.RandomState must be convertible to a 32 bit
        # unsigned integer.
        seed %= 2**32

        return Struct(fringes=fringeExp,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        The science exposure is modified in place. The input fringe frames
        are also modified: the science mask is OR-ed into their masks, an
        optional pedestal is removed, and their masks are cleared during
        subtraction.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        seed : `int`, optional
            Seed for random number generation. Defaults to
            ``config.stats.rngSeedOffset``.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        Nothing is returned (`None`) if the filter check fails.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            # Propagate the science mask so that pixels flagged on the
            # science exposure are also honoured when measuring the fringe.
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        positions = self.generatePositions(fringes[0], rng)
        # One column of amplitudes per fringe frame; every element is
        # assigned below, so uninitialised storage is fine.
        fluxes = numpy.empty((self.config.num, len(fringes)))
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference to the exposure. Used to find
            appropriate fringe data.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        Nothing is returned (`None`) if the filter check fails.
        """
        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return
        # No expId is passed here, so the seed is just config.stats.rngSeedOffset.
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        return self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        return checkFilter(exposure, self.config.filters, log=self.log)

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the sigma-clipped median of the whole masked image,
        subtracted in place.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe amplitudes.

        Positions keep at least ``config.large`` pixels from each edge so the
        large (background) aperture stays inside the image.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.array`
            Integer array of shape ``(config.num, 2)`` holding ``(x, y)``
            positions to sample for fringe amplitudes.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure.

        The fringe amplitudes are measured as the statistic within a square
        aperture. The statistic within a larger aperture is subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.array`
            Two-dimensional array containing the positions to sample
            for fringe amplitudes.
        title : `str`, optional
            Title used for debug output plots.

        Returns
        -------
        fringes : `numpy.array`
            Array of measured exposure values at each of the positions
            supplied.
        """
        maskedImage = exposure.getMaskedImage()

        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(maskedImage.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        num = self.config.num
        fringes = numpy.empty(num)

        for i in range(num):
            x, y = positions[i]
            small = measure(maskedImage, x, y, self.config.small, self.config.stats.stat, stats)
            large = measure(maskedImage, x, y, self.config.large, self.config.stats.stat, stats)
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            # Overlaying every aperture is slow for large ``config.num``;
            # set this to True to draw them.
            drawApertures = False
            if drawApertures:
                # Closed unit-square outline: scale by the aperture half-size
                # first, then shift to the measurement position.
                unitSquare = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        disp.line(unitSquare*self.config.small + [x, y], ctype=afwDisplay.GREEN)
                        disp.line(unitSquare*self.config.large + [x, y], ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Two-dimensional array of measured fringe values, one row per
            position and one column per fringe frame.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame. All
            zeros (one per fringe frame) if no usable positions remain.
        rms : `float`
            RMS error for the fit solution for this exposure (NaN if no
            usable positions remain).
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Number of fringe frames. Capture it before ``fringes`` is filtered
        # row-wise below, so the empty result always has one entry per frame.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user.

            There are no good pixels; return a zero amplitude for every
            fringe frame so the result still matches the number of frames.
            """
            self.log.warning("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes, dtype=int), numpy.nan

        # Keep only positions where the science value and *every* fringe
        # value are finite; a NaN surviving in any column would make the
        # percentile-based rejection below discard all positions.
        good = numpy.where(numpy.logical_and(numpy.isfinite(science),
                                             numpy.all(numpy.isfinite(fringes), 1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome values
        # (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(fringes.shape[1]):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        rms = None
        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if rms is None:
            # No rejection iterations ran (config.iterations == 0); report
            # the robust scatter of the final fit instead.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _plotSolution(self, science, fringes, solution):
        """Interactively plot the current fit for each fringe frame (debug only).

        For each frame, the science amplitudes (minus the fitted contribution
        of the other frames) are plotted against that frame's amplitudes
        together with the fitted line; then waits for terminal input.

        Parameters
        ----------
        science : `numpy.array`
            Science amplitudes currently in the fit.
        fringes : `numpy.array`
            Fringe amplitudes currently in the fit (positions x frames).
        solution : `numpy.array`
            Current amplitude solution, one entry per fringe frame.
        """
        import matplotlib.pyplot as plot
        for j in range(fringes.shape[1]):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best-effort window raise; only works for the Tk backend.
                pass
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            others = set(range(fringes.shape[1]))
            others.discard(j)
            for k in others:
                adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
        while True:
            ans = input("Enter or c to continue [chp]").lower()
            if ans in ("", "c",):
                break
            if ans in ("p",):
                import pdb
                pdb.set_trace()
            elif ans in ("h", ):
                print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Linear least squares (direct SVD) of ``science`` against the columns
        of ``fringes``.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied (the design matrix).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes (modified in place).
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames. Their
            masks are cleared in place.
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError :
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            # We do not want to add the mask from the fringe to the image.
            f.getMaskedImage().getMask().getArray()[:] = 0
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())

514 

515 

def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Masked image to measure.
    x, y : `int`
        Aperture centre, in pixel coordinates local to ``mi``.
    size : `int`
        Aperture half-size; the box covers ``[x - size, x + size)`` by
        ``[y - size, y + size)``.
    statistic : `int`
        ``lsst.afw.math`` statistic property to compute.
    stats : `lsst.afw.math.StatisticsControl`
        Statistics control object.

    Returns
    -------
    value : `float`
        Value of the statistic within the aperture.
    """
    # Truncate each coordinate to an integer pixel first and then offset,
    # so that x and y are treated identically.
    bbox = lsst.geom.Box2I(lsst.geom.Point2I(int(x) - size, int(y) - size),
                           lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()

530 

531 

def stdev(vector):
    """Return a robust standard-deviation estimate of ``vector``.

    The estimate is 0.74 times the interquartile range (75th minus 25th
    percentile), which is insensitive to outliers.

    Parameters
    ----------
    vector : array-like
        Values to characterise.

    Returns
    -------
    sigma : `float`
        Robust standard deviation.
    """
    lower, upper = numpy.percentile(vector, (25, 75))
    interquartile = upper - lower
    return 0.74*interquartile

540 

541 

def select(vector, clip):
    """Flag values lying within ``clip`` robust standard deviations of the median.

    The robust standard deviation is 0.74 times the interquartile range.
    The comparison is strict, so if the interquartile range is zero no value
    is selected.

    Parameters
    ----------
    vector : `numpy.ndarray`
        Values to test.
    clip : `float`
        Rejection threshold in robust standard deviations.

    Returns
    -------
    keep : `numpy.ndarray` of `bool`
        True where the value is within the threshold.
    """
    lower, median, upper = numpy.percentile(vector, (25, 50, 75))
    threshold = clip*0.74*(upper - lower)
    distance = numpy.abs(vector - median)
    return distance < threshold