Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy 

23 

24import lsst.geom 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.afw.display as afwDisplay 

28 

29from lsst.pipe.base import Task, Struct, timeMethod 

30from lsst.pex.config import Config, Field, ListField, ConfigField 

31from .isrFunctions import checkFilter 

32 

# Import-time side effect: set afwDisplay's default mask transparency to 75.
# NOTE(review): presumably intended for the debug displays this module opens
# (see `getFrame` users); it applies process-wide once this module is imported.
afwDisplay.setDefaultMaskTransparency(75)

34 

35 

def getFrame():
    """Return a new display frame number.

    Each call increments the counter stored on the function object
    (``getFrame.frame``) and returns the updated value, so successive
    calls yield 1, 2, 3, ...

    Returns
    -------
    frame : `int`
        The newly allocated frame number.
    """
    nextFrame = getFrame.frame + 1
    getFrame.frame = nextFrame
    return nextFrame


# Counter state lives on the function itself; the first call returns 1.
getFrame.frame = 0

43 

44 

class FringeStatisticsConfig(Config):
    """Configuration for the aperture statistics used to measure fringes."""

    # Mask planes whose pixels are excluded (via the StatisticsControl and-mask)
    # when measuring aperture statistics.
    badMaskPlanes = ListField(
        doc="Ignore pixels with these masks",
        dtype=str,
        default=["SAT"],
    )
    # lsst.afw.math statistic property, stored as an int, computed in each aperture.
    stat = Field(
        doc="Statistic to use",
        dtype=int,
        default=int(afwMath.MEDIAN),
    )
    # Number of sigma for clipping inside the statistic computation.
    clip = Field(
        doc="Sigma clip threshold",
        dtype=float,
        default=3.0,
    )
    # Number of clipping iterations inside the statistic computation.
    iterations = Field(
        doc="Number of fitting iterations",
        dtype=int,
        default=3,
    )
    # Added to the exposure ID (when one is supplied) to form the RNG seed.
    rngSeedOffset = Field(
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
        dtype=int,
        default=0,
    )

53 

54 

class FringeConfig(Config):
    """Fringe subtraction options"""
    # TODO DM-28093: change the doc to specify that these are physical labels
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    # The two message fragments are concatenated, so the first one needs a
    # trailing space to keep the sentences separated.
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    # Number of random positions at which fringe amplitudes are measured.
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    # Half-size of the aperture measuring the fringe signal.
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    # Half-size of the aperture whose statistic is subtracted as local background;
    # also the margin kept from the image edge when generating positions.
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    # Maximum number of clip-and-refit iterations when solving for the scales.
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    # Residual clipping threshold, in units of the robust RMS.
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    # If True, subtract the clipped median from each fringe frame before measuring.
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")

70 

71 

class FringeTask(Task):
    """Task to remove fringes from a science exposure.

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def readFringes(self, dataRef, assembler=None):
        """Read the fringe frame(s), and pack data into a Struct.

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        Parameters
        ----------
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference for the exposure that will have fringing
            removed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation. No exposure ID is
                available here, so it is derived from
                ``config.stats.rngSeedOffset`` alone.

        Raises
        ------
        RuntimeError
            Raised if the ``fringe`` dataset cannot be read from the butler.
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) from e

        # ``assembler`` must be passed by keyword: the second positional
        # parameter of ``loadFringes`` is ``expId``.
        return self.loadFringes(fringe, assembler=assembler)

    def loadFringes(self, fringeExp, expId=0, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.exposure.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
            `None` is treated like 0.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation:
                ``(config.stats.rngSeedOffset + expId) % 2**32``.
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        seed = self.config.stats.rngSeedOffset
        if expId is not None:
            seed += expId

        # Seed for numpy.random.RandomState must be convertible to a 32 bit unsigned integer
        seed %= 2**32

        return Struct(fringes=fringeExp,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        The fringe frames are modified in place: the science exposure's mask
        is OR-ed into each fringe mask, and, if ``config.pedestal`` is set,
        the pedestal is subtracted from each fringe frame.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        seed : `int`, optional
            Seed for random number generation. Defaults to
            ``config.stats.rngSeedOffset``.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        These are returned as a tuple ``(solution, rms)``; `None` is
        returned instead when the exposure's filter is not listed in
        ``config.filters``.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            # Pixels masked on the science exposure become masked on each fringe frame too.
            fringe.getMaskedImage().getMask().__ior__(mask)
            if self.config.pedestal:
                self.removePedestal(fringe)

        # Every frame is sampled at the same positions, drawn from the first fringe frame's dimensions.
        positions = self.generatePositions(fringes[0], rng)
        fluxes = numpy.empty([self.config.num, len(fringes)])
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference to the exposure. Used to find
            appropriate fringe data.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        These are returned as a tuple; `None` is returned instead when the
        filter is not listed in ``config.filters``.
        """
        # Check before reading so the fringe dataset is not fetched needlessly.
        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        return self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        # Delegates to the module-level ``checkFilter`` imported from isrFunctions.
        return checkFilter(exposure, self.config.filters, log=self.log)

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the sigma-clipped median of the whole masked image,
        using the clip/iteration settings from ``config.stats``.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from (modified in place).
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe amplitudes.

        Positions are kept at least ``config.large`` pixels from the lower
        image edges and below ``size - config.large`` on each axis, so the
        large aperture stays on the image.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.array`
            Array of shape ``(config.num, 2)`` of integer ``(x, y)``
            positions to sample for fringe amplitudes.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure.

        The fringe amplitudes are measured as the statistic within a square
        aperture. The statistic within a larger aperture are subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.array`
            Two-dimensional array containing the positions to sample
            for fringe amplitudes.
        title : `str`, optional
            Title used for debug out plots.

        Returns
        -------
        fringes : `numpy.array`
            Array of measured exposure values at each of the positions
            supplied.
        """
        maskedImage = exposure.getMaskedImage()

        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(maskedImage.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        # Loop invariants hoisted out of the per-position loop.
        num = self.config.num
        small = self.config.small
        large = self.config.large
        statistic = self.config.stats.stat

        fringes = numpy.empty(num)
        for i in range(num):
            x, y = positions[i]
            smallValue = measure(maskedImage, x, y, small, statistic, stats)
            largeValue = measure(maskedImage, x, y, large, statistic, stats)
            # The large-aperture statistic acts as the local background.
            fringes[i] = smallValue - largeValue

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            # Aperture overlay is disabled; change the condition to draw the
            # small (green) and large (blue) boxes around every position.
            if False:
                unitBox = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        # Scale the unit box first, then offset it to the position.
                        disp.line(unitBox*small + [x, y], ctype=afwDisplay.GREEN)
                        disp.line(unitBox*large + [x, y], ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Two-dimensional array of measured fringe values: one row per
            position and one column per fringe frame.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame. All
            zeros (one per fringe frame) if no usable measurements remain.
        rms : `float`
            RMS error for the fit solution for this exposure; NaN if no
            usable measurements remain.
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Rows are rejected below, but the number of fringe frames (columns)
        # never changes; capture it before any rejection.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user.

            There are no good pixels. Return one zero amplitude per fringe
            frame, so the caller's ``subtract`` still receives a solution
            of the right length.
            """
            self.log.warn("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes, dtype=int), numpy.nan

        # Keep only positions where the science value and every fringe value
        # are finite; a single NaN column would otherwise poison the
        # percentile-based rejection below.
        finite = numpy.logical_and(numpy.isfinite(science),
                                   numpy.all(numpy.isfinite(fringes), axis=1))
        science = science[finite]
        fringes = fringes[finite]
        if len(science) == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome values
        # (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFringes):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        rms = None
        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            # logical_not(... > ...) keeps points whose residual is NaN.
            good = numpy.logical_not(numpy.abs(resid) > self.config.clip*rms)
            newNum = good.sum()
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, newNum)
            self.log.debug("Solution %d: %s", i, solution)
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if rms is None:
            # No clipping iterations were configured; report the scatter of the final fit.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _plotSolution(self, science, fringes, solution):
        """Interactively plot science amplitudes against each fringe frame.

        For each fringe frame, the contribution of the other frames (scaled
        by ``solution``) is removed from the science values and plotted
        against that frame's values, with the fitted line overlaid. After
        all figures are drawn, this waits for user input (continue, help,
        or drop into pdb).

        Parameters
        ----------
        science : `numpy.array`
            Measured science values at the retained positions.
        fringes : `numpy.array`
            Measured fringe values, one column per fringe frame.
        solution : `np.array`
            Current fringe amplitudes.
        """
        import matplotlib.pyplot as plot
        numFringes = fringes.shape[1]
        for j in range(numFringes):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best effort: only works with a Tk backend.
                pass
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            for k in range(numFringes):
                if k != j:
                    adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
        while True:
            ans = input("Enter or c to continue [chp]").lower()
            if ans in ("", "c",):
                break
            if ans in ("p",):
                import pdb
                pdb.set_trace()
            elif ans in ("h", ):
                print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Uses a linear least-squares fit (direct SVD) with ``fringes`` as the
        design matrix and ``science`` as the data vector.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes (modified in place).
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError :
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            # science -= s * fringe, applied to the masked image.
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())

510 

511 

def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        MaskedImage to measure.
    x, y : `int` or `float`
        Aperture position in pixel coordinates local to ``mi``; each is
        converted with `int` before the aperture is offset.
    size : `int`
        Half-width of the aperture. The box starts at
        ``(int(x) - size, int(y) - size)`` and is ``2*size`` pixels on a side.
    statistic : `int`
        `lsst.afw.math` statistic property to compute.
    stats : `lsst.afw.math.StatisticsControl`
        Control object for the statistic.

    Returns
    -------
    value : `float`
        Value of the statistic within the aperture.
    """
    # Convert both coordinates the same way before applying the offset.
    corner = lsst.geom.Point2I(int(x) - size, int(y) - size)
    bbox = lsst.geom.Box2I(corner, lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()

526 

527 

def stdev(vector):
    """Estimate the standard deviation of ``vector`` from its interquartile range.

    Parameters
    ----------
    vector : array-like
        Values to characterise.

    Returns
    -------
    sigma : `float`
        ``0.74 * (75th percentile - 25th percentile)``. NOTE(review): the
        0.74 factor presumably approximates 1/1.349, the IQR-to-sigma
        ratio of a Gaussian.
    """
    lower, upper = numpy.percentile(vector, [25, 75])
    interquartile = upper - lower
    return 0.74*interquartile

536 

537 

def select(vector, clip):
    """Flag values lying within ``clip`` robust standard deviations of the median.

    The robust standard deviation is ``0.74`` times the interquartile range.
    The comparison is strict, so when the interquartile range is zero no
    value is selected.

    Parameters
    ----------
    vector : `numpy.ndarray`
        Values to test.
    clip : `float`
        Rejection threshold in units of the robust standard deviation.

    Returns
    -------
    selected : `numpy.ndarray` of `bool`
        True where ``abs(value - median) < clip*0.74*IQR``.
    """
    lower, median, upper = numpy.percentile(vector, [25, 50, 75])
    threshold = clip*0.74*(upper - lower)
    distance = numpy.abs(vector - median)
    return distance < threshold

544 return numpy.abs(vector - q2) < clip*0.74*(q3 - q1)