Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy 

23 

24import lsst.geom 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.afw.display as afwDisplay 

28 

29from lsst.pipe.base import Task, Struct, timeMethod 

30from lsst.pex.config import Config, Field, ListField, ConfigField 

31 

32afwDisplay.setDefaultMaskTransparency(75) 

33 

34 

35def getFrame(): 

36 """Produce a new frame number each time""" 

37 getFrame.frame += 1 

38 return getFrame.frame 

39 

40 

41getFrame.frame = 0 

42 

43 

44class FringeStatisticsConfig(Config): 

45 """Options for measuring fringes on an exposure""" 

46 badMaskPlanes = ListField(dtype=str, default=["SAT"], doc="Ignore pixels with these masks") 

47 stat = Field(dtype=int, default=int(afwMath.MEDIAN), doc="Statistic to use") 

48 clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold") 

49 iterations = Field(dtype=int, default=3, doc="Number of fitting iterations") 

50 rngSeedOffset = Field(dtype=int, default=0, 

51 doc="Offset to the random number generator seed (full seed includes exposure ID)") 

52 

53 

54class FringeConfig(Config): 

55 """Fringe subtraction options""" 

56 # TODO DM-28093: change the doc to specify that these are physical labels 

57 filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters") 

58 # TODO: remove in DM-27177 

59 useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.", 

60 deprecated=("Removed with no replacement (FilterLabel has no aliases)." 

61 "Will be removed after v22.")) 

62 num = Field(dtype=int, default=30000, doc="Number of fringe measurements") 

63 small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)") 

64 large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)") 

65 iterations = Field(dtype=int, default=20, doc="Number of fitting iterations") 

66 clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold") 

67 stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes") 

68 pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?") 

69 

70 

71class FringeTask(Task): 

72 """Task to remove fringes from a science exposure 

73 

74 We measure fringe amplitudes at random positions on the science exposure 

75 and at the same positions on the (potentially multiple) fringe frames 

76 and solve for the scales simultaneously. 

77 """ 

78 ConfigClass = FringeConfig 

79 _DefaultName = 'isrFringe' 

80 

81 def readFringes(self, dataRef, assembler=None): 

82 """Read the fringe frame(s), and pack data into a Struct 

83 

84 The current implementation assumes only a single fringe frame and 

85 will have to be updated to support multi-mode fringe subtraction. 

86 

87 This implementation could be optimised by persisting the fringe 

88 positions and fluxes. 

89 

90 Parameters 

91 ---------- 

92 dataRef : `daf.butler.butlerSubset.ButlerDataRef` 

93 Butler reference for the exposure that will have fringing 

94 removed. 

95 assembler : `lsst.ip.isr.AssembleCcdTask`, optional 

96 An instance of AssembleCcdTask (for assembling fringe 

97 frames). 

98 

99 Returns 

100 ------- 

101 fringeData : `pipeBase.Struct` 

102 Struct containing fringe data: 

103 - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof 

104 Calibration fringe files containing master fringe frames. 

105 - ``seed`` : `int`, optional 

106 Seed for random number generation. 

107 """ 

108 try: 

109 fringe = dataRef.get("fringe", immediate=True) 

110 except Exception as e: 

111 raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) 

112 

113 return self.loadFringes(fringe, assembler) 

114 

115 def loadFringes(self, fringeExp, expId=0, assembler=None): 

116 """Pack the fringe data into a Struct. 

117 

118 This method moves the struct parsing code into a butler 

119 generation agnostic handler. 

120 

121 Parameters 

122 ---------- 

123 fringeExp : `lsst.afw.exposure.Exposure` 

124 The exposure containing the fringe data. 

125 expId : `int`, optional 

126 Exposure id to be fringe corrected, used to set RNG seed. 

127 assembler : `lsst.ip.isr.AssembleCcdTask`, optional 

128 An instance of AssembleCcdTask (for assembling fringe 

129 frames). 

130 

131 Returns 

132 ------- 

133 fringeData : `pipeBase.Struct` 

134 Struct containing fringe data: 

135 - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof 

136 Calibration fringe files containing master fringe frames. 

137 - ``seed`` : `int`, optional 

138 Seed for random number generation. 

139 """ 

140 if assembler is not None: 

141 fringeExp = assembler.assembleCcd(fringeExp) 

142 

143 if expId is None: 

144 seed = self.config.stats.rngSeedOffset 

145 else: 

146 print(f"{self.config.stats.rngSeedOffset} {expId}") 

147 seed = self.config.stats.rngSeedOffset + expId 

148 

149 # Seed for numpy.random.RandomState must be convertable to a 32 bit unsigned integer 

150 seed %= 2**32 

151 

152 return Struct(fringes=fringeExp, 

153 seed=seed) 

154 

155 @timeMethod 

156 def run(self, exposure, fringes, seed=None): 

157 """Remove fringes from the provided science exposure. 

158 

159 Primary method of FringeTask. Fringes are only subtracted if the 

160 science exposure has a filter listed in the configuration. 

161 

162 Parameters 

163 ---------- 

164 exposure : `lsst.afw.image.Exposure` 

165 Science exposure from which to remove fringes. 

166 fringes : `lsst.afw.image.Exposure` or `list` thereof 

167 Calibration fringe files containing master fringe frames. 

168 seed : `int`, optional 

169 Seed for random number generation. 

170 

171 Returns 

172 ------- 

173 solution : `np.array` 

174 Fringe solution amplitudes for each input fringe frame. 

175 rms : `float` 

176 RMS error for the fit solution for this exposure. 

177 """ 

178 import lsstDebug 

179 display = lsstDebug.Info(__name__).display 

180 

181 if not self.checkFilter(exposure): 

182 self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.") 

183 return 

184 

185 if seed is None: 

186 seed = self.config.stats.rngSeedOffset 

187 rng = numpy.random.RandomState(seed=seed) 

188 

189 if not hasattr(fringes, '__iter__'): 

190 fringes = [fringes] 

191 

192 mask = exposure.getMaskedImage().getMask() 

193 for fringe in fringes: 

194 fringe.getMaskedImage().getMask().__ior__(mask) 

195 if self.config.pedestal: 

196 self.removePedestal(fringe) 

197 

198 positions = self.generatePositions(fringes[0], rng) 

199 fluxes = numpy.ndarray([self.config.num, len(fringes)]) 

200 for i, f in enumerate(fringes): 

201 fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame") 

202 

203 expFringes = self.measureExposure(exposure, positions, title="Science") 

204 solution, rms = self.solve(expFringes, fluxes) 

205 self.subtract(exposure, fringes, solution) 

206 if display: 

207 afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted") 

208 return solution, rms 

209 

210 @timeMethod 

211 def runDataRef(self, exposure, dataRef, assembler=None): 

212 """Remove fringes from the provided science exposure. 

213 

214 Retrieve fringes from butler dataRef provided and remove from 

215 provided science exposure. Fringes are only subtracted if the 

216 science exposure has a filter listed in the configuration. 

217 

218 Parameters 

219 ---------- 

220 exposure : `lsst.afw.image.Exposure` 

221 Science exposure from which to remove fringes. 

222 dataRef : `daf.persistence.butlerSubset.ButlerDataRef` 

223 Butler reference to the exposure. Used to find 

224 appropriate fringe data. 

225 assembler : `lsst.ip.isr.AssembleCcdTask`, optional 

226 An instance of AssembleCcdTask (for assembling fringe 

227 frames). 

228 

229 Returns 

230 ------- 

231 solution : `np.array` 

232 Fringe solution amplitudes for each input fringe frame. 

233 rms : `float` 

234 RMS error for the fit solution for this exposure. 

235 """ 

236 if not self.checkFilter(exposure): 

237 self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.") 

238 return 

239 fringeStruct = self.readFringes(dataRef, assembler=assembler) 

240 return self.run(exposure, **fringeStruct.getDict()) 

241 

242 def checkFilter(self, exposure): 

243 """Check whether we should fringe-subtract the science exposure. 

244 

245 Parameters 

246 ---------- 

247 exposure : `lsst.afw.image.Exposure` 

248 Exposure to check the filter of. 

249 

250 Returns 

251 ------- 

252 needsFringe : `bool` 

253 If True, then the exposure has a filter listed in the 

254 configuration, and should have the fringe applied. 

255 """ 

256 filterObj = afwImage.Filter(exposure.getFilter().getId()) 

257 # TODO: remove this check along with the config option in DM-27177 

258 if self.config.useFilterAliases: 

259 filterNameSet = set(filterObj.getAliases() + [filterObj.getName()]) 

260 else: 

261 filterNameSet = set([filterObj.getName(), ]) 

262 return bool(len(filterNameSet.intersection(self.config.filters))) 

263 

264 def removePedestal(self, fringe): 

265 """Remove pedestal from fringe exposure. 

266 

267 Parameters 

268 ---------- 

269 fringe : `lsst.afw.image.Exposure` 

270 Fringe data to subtract the pedestal value from. 

271 """ 

272 stats = afwMath.StatisticsControl() 

273 stats.setNumSigmaClip(self.config.stats.clip) 

274 stats.setNumIter(self.config.stats.iterations) 

275 mi = fringe.getMaskedImage() 

276 pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue() 

277 self.log.info("Removing fringe pedestal: %f", pedestal) 

278 mi -= pedestal 

279 

280 def generatePositions(self, exposure, rng): 

281 """Generate a random distribution of positions for measuring fringe amplitudes. 

282 

283 Parameters 

284 ---------- 

285 exposure : `lsst.afw.image.Exposure` 

286 Exposure to measure the positions on. 

287 rng : `numpy.random.RandomState` 

288 Random number generator to use. 

289 

290 Returns 

291 ------- 

292 positions : `numpy.array` 

293 Two-dimensional array containing the positions to sample 

294 for fringe amplitudes. 

295 """ 

296 start = self.config.large 

297 num = self.config.num 

298 width = exposure.getWidth() - self.config.large 

299 height = exposure.getHeight() - self.config.large 

300 return numpy.array([rng.randint(start, width, size=num), 

301 rng.randint(start, height, size=num)]).swapaxes(0, 1) 

302 

303 @timeMethod 

304 def measureExposure(self, exposure, positions, title="Fringe"): 

305 """Measure fringe amplitudes for an exposure 

306 

307 The fringe amplitudes are measured as the statistic within a square 

308 aperture. The statistic within a larger aperture are subtracted so 

309 as to remove the background. 

310 

311 Parameters 

312 ---------- 

313 exposure : `lsst.afw.image.Exposure` 

314 Exposure to measure the positions on. 

315 positions : `numpy.array` 

316 Two-dimensional array containing the positions to sample 

317 for fringe amplitudes. 

318 title : `str`, optional 

319 Title used for debug out plots. 

320 

321 Returns 

322 ------- 

323 fringes : `numpy.array` 

324 Array of measured exposure values at each of the positions 

325 supplied. 

326 """ 

327 stats = afwMath.StatisticsControl() 

328 stats.setNumSigmaClip(self.config.stats.clip) 

329 stats.setNumIter(self.config.stats.iterations) 

330 stats.setAndMask(exposure.getMaskedImage().getMask().getPlaneBitMask(self.config.stats.badMaskPlanes)) 

331 

332 num = self.config.num 

333 fringes = numpy.ndarray(num) 

334 

335 for i in range(num): 

336 x, y = positions[i] 

337 small = measure(exposure.getMaskedImage(), x, y, self.config.small, self.config.stats.stat, stats) 

338 large = measure(exposure.getMaskedImage(), x, y, self.config.large, self.config.stats.stat, stats) 

339 fringes[i] = small - large 

340 

341 import lsstDebug 

342 display = lsstDebug.Info(__name__).display 

343 if display: 

344 disp = afwDisplay.Display(frame=getFrame()) 

345 disp.mtv(exposure, title=title) 

346 if False: 

347 with disp.Buffering(): 

348 for x, y in positions: 

349 corners = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]]) + [[x, y]] 

350 disp.line(corners*self.config.small, ctype=afwDisplay.GREEN) 

351 disp.line(corners*self.config.large, ctype=afwDisplay.BLUE) 

352 

353 return fringes 

354 

355 @timeMethod 

356 def solve(self, science, fringes): 

357 """Solve for the scale factors with iterative clipping. 

358 

359 Parameters 

360 ---------- 

361 science : `numpy.array` 

362 Array of measured science image values at each of the 

363 positions supplied. 

364 fringes : `numpy.array` 

365 Array of measured fringe values at each of the positions 

366 supplied. 

367 

368 Returns 

369 ------- 

370 solution : `np.array` 

371 Fringe solution amplitudes for each input fringe frame. 

372 rms : `float` 

373 RMS error for the fit solution for this exposure. 

374 """ 

375 import lsstDebug 

376 doPlot = lsstDebug.Info(__name__).plot 

377 

378 origNum = len(science) 

379 

380 def emptyResult(msg=""): 

381 """Generate an empty result for return to the user 

382 

383 There are no good pixels; doesn't matter what we return. 

384 """ 

385 self.log.warn("Unable to solve for fringes: no good pixels%s", msg) 

386 out = [0] 

387 if len(fringes) > 1: 

388 out = out*len(fringes) 

389 return numpy.array(out), numpy.nan 

390 

391 good = numpy.where(numpy.logical_and(numpy.isfinite(science), numpy.any(numpy.isfinite(fringes), 1))) 

392 science = science[good] 

393 fringes = fringes[good] 

394 oldNum = len(science) 

395 if oldNum == 0: 

396 return emptyResult() 

397 

398 # Up-front rejection to get rid of extreme, potentially troublesome values 

399 # (e.g., fringe apertures that fall on objects). 

400 good = select(science, self.config.clip) 

401 for ff in range(fringes.shape[1]): 

402 good &= select(fringes[:, ff], self.config.clip) 

403 science = science[good] 

404 fringes = fringes[good] 

405 oldNum = len(science) 

406 if oldNum == 0: 

407 return emptyResult(" after initial rejection") 

408 

409 for i in range(self.config.iterations): 

410 solution = self._solve(science, fringes) 

411 resid = science - numpy.sum(solution*fringes, 1) 

412 rms = stdev(resid) 

413 good = numpy.logical_not(abs(resid) > self.config.clip*rms) 

414 self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum()) 

415 self.log.debug("Solution %d: %s", i, solution) 

416 newNum = good.sum() 

417 if newNum == 0: 

418 return emptyResult(" after %d rejection iterations" % i) 

419 

420 if doPlot: 

421 import matplotlib.pyplot as plot 

422 for j in range(fringes.shape[1]): 

423 fig = plot.figure(j) 

424 fig.clf() 

425 try: 

426 fig.canvas._tkcanvas._root().lift() # == Tk's raise 

427 except Exception: 

428 pass 

429 ax = fig.add_subplot(1, 1, 1) 

430 adjust = science.copy() 

431 others = set(range(fringes.shape[1])) 

432 others.discard(j) 

433 for k in others: 

434 adjust -= solution[k]*fringes[:, k] 

435 ax.plot(fringes[:, j], adjust, 'r.') 

436 xmin = fringes[:, j].min() 

437 xmax = fringes[:, j].max() 

438 ymin = solution[j]*xmin 

439 ymax = solution[j]*xmax 

440 ax.plot([xmin, xmax], [ymin, ymax], 'b-') 

441 ax.set_title("Fringe %d: %f" % (j, solution[j])) 

442 ax.set_xlabel("Fringe amplitude") 

443 ax.set_ylabel("Science amplitude") 

444 ax.set_autoscale_on(False) 

445 ax.set_xbound(lower=xmin, upper=xmax) 

446 ax.set_ybound(lower=ymin, upper=ymax) 

447 fig.show() 

448 while True: 

449 ans = input("Enter or c to continue [chp]").lower() 

450 if ans in ("", "c",): 

451 break 

452 if ans in ("p",): 

453 import pdb 

454 pdb.set_trace() 

455 elif ans in ("h", ): 

456 print("h[elp] c[ontinue] p[db]") 

457 

458 if newNum == oldNum: 

459 # Not gaining 

460 break 

461 oldNum = newNum 

462 good = numpy.where(good) 

463 science = science[good] 

464 fringes = fringes[good] 

465 

466 # Final solution without rejection 

467 solution = self._solve(science, fringes) 

468 self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum) 

469 return solution, rms 

470 

471 def _solve(self, science, fringes): 

472 """Solve for the scale factors. 

473 

474 Parameters 

475 ---------- 

476 science : `numpy.array` 

477 Array of measured science image values at each of the 

478 positions supplied. 

479 fringes : `numpy.array` 

480 Array of measured fringe values at each of the positions 

481 supplied. 

482 

483 Returns 

484 ------- 

485 solution : `np.array` 

486 Fringe solution amplitudes for each input fringe frame. 

487 """ 

488 return afwMath.LeastSquares.fromDesignMatrix(fringes, science, 

489 afwMath.LeastSquares.DIRECT_SVD).getSolution() 

490 

491 def subtract(self, science, fringes, solution): 

492 """Subtract the fringes. 

493 

494 Parameters 

495 ---------- 

496 science : `lsst.afw.image.Exposure` 

497 Science exposure from which to remove fringes. 

498 fringes : `lsst.afw.image.Exposure` or `list` thereof 

499 Calibration fringe files containing master fringe frames. 

500 solution : `np.array` 

501 Fringe solution amplitudes for each input fringe frame. 

502 

503 Raises 

504 ------ 

505 RuntimeError : 

506 Raised if the number of fringe frames does not match the 

507 number of measured amplitudes. 

508 """ 

509 if len(solution) != len(fringes): 

510 raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" % 

511 (len(fringes), len(solution))) 

512 

513 for s, f in zip(solution, fringes): 

514 science.getMaskedImage().scaledMinus(s, f.getMaskedImage()) 

515 

516 

517def measure(mi, x, y, size, statistic, stats): 

518 """Measure a statistic within an aperture 

519 

520 @param mi MaskedImage to measure 

521 @param x, y Center for aperture 

522 @param size Size of aperture 

523 @param statistic Statistic to measure 

524 @param stats StatisticsControl object 

525 @return Value of statistic within aperture 

526 """ 

527 bbox = lsst.geom.Box2I(lsst.geom.Point2I(int(x) - size, int(y - size)), 

528 lsst.geom.Extent2I(2*size, 2*size)) 

529 subImage = mi.Factory(mi, bbox, afwImage.LOCAL) 

530 return afwMath.makeStatistics(subImage, statistic, stats).getValue() 

531 

532 

533def stdev(vector): 

534 """Calculate a robust standard deviation of an array of values 

535 

536 @param vector Array of values 

537 @return Standard deviation 

538 """ 

539 q1, q3 = numpy.percentile(vector, (25, 75)) 

540 return 0.74*(q3 - q1) 

541 

542 

543def select(vector, clip): 

544 """Select values within 'clip' standard deviations of the median 

545 

546 Returns a boolean array. 

547 """ 

548 q1, q2, q3 = numpy.percentile(vector, (25, 50, 75)) 

549 return numpy.abs(vector - q2) < clip*0.74*(q3 - q1)