Coverage for python/lsst/ip/isr/overscan.py: 13%

222 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-16 10:28 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import time 

24import lsst.afw.math as afwMath 

25import lsst.afw.image as afwImage 

26import lsst.pipe.base as pipeBase 

27import lsst.pex.config as pexConfig 

28 

29from .isr import fitOverscanImage 

30 

# Public names exported by ``from ... import *``.
__all__ = [
    "OverscanCorrectionTaskConfig",
    "OverscanCorrectionTask",
]

32 

33 

class OverscanCorrectionTaskConfig(pexConfig.Config):
    """Overscan correction options.
    """
    # Statistic (constant) or functional form (vector) used to model the
    # overscan bias level.
    fitType = pexConfig.ChoiceField(
        doc="The method for fitting the overscan bias level.",
        dtype=str,
        default="MEDIAN",
        allowed={
            "POLY": "Fit ordinary polynomial to the longest axis of the overscan region",
            "CHEB": "Fit Chebyshev polynomial to the longest axis of the overscan region",
            "LEG": "Fit Legendre polynomial to the longest axis of the overscan region",
            "NATURAL_SPLINE": "Fit natural spline to the longest axis of the overscan region",
            "CUBIC_SPLINE": "Fit cubic spline to the longest axis of the overscan region",
            "AKIMA_SPLINE": "Fit Akima spline to the longest axis of the overscan region",
            "MEAN": "Correct using the mean of the overscan region",
            "MEANCLIP": "Correct using a clipped mean of the overscan region",
            "MEDIAN": "Correct using the median of the overscan region",
            "MEDIAN_PER_ROW": "Correct using the median per row of the overscan region",
        },
    )
    # Polynomial order, or spline bin count, for the vector fit types.
    order = pexConfig.Field(
        doc="Order of polynomial to fit if overscan fit type is a polynomial, or "
            "number of spline knots if overscan fit type is a spline.",
        dtype=int,
        default=1,
    )
    # Clipping threshold used both for outlier rejection and for the
    # default StatisticsControl built by the task.
    numSigmaClip = pexConfig.Field(
        doc="Rejection threshold (sigma) for collapsing overscan before fit",
        dtype=float,
        default=3.0,
    )
    # Mask planes whose pixels are excluded from the overscan measurement.
    maskPlanes = pexConfig.ListField(
        doc="Mask planes to reject when measuring overscan",
        dtype=str,
        default=["BAD", "SAT"],
    )
    # Whether median-based measurements operate on integer-converted data.
    overscanIsInt = pexConfig.Field(
        doc="Treat overscan as an integer image for purposes of "
            "fitType=MEDIAN and fitType=MEDIAN_PER_ROW.",
        dtype=bool,
        default=True,
    )

76 

77 

class OverscanCorrectionTask(pipeBase.Task):
    """Correction task for overscan.

    This class contains a number of utilities that are easier to
    understand and use when they are not embedded in nested if/else
    loops.

    Parameters
    ----------
    statControl : `lsst.afw.math.StatisticsControl`, optional
        Statistics control object.  If not supplied, one is constructed
        from ``config.numSigmaClip`` and ``config.maskPlanes``.
    """
    ConfigClass = OverscanCorrectionTaskConfig
    _DefaultName = "overscan"

    def __init__(self, statControl=None, **kwargs):
        super().__init__(**kwargs)
        # Cleared from the interactive debug prompt ('x') to suppress any
        # further debug displays for this task instance.
        self.allowDebug = True

        if statControl is not None:
            self.statControl = statControl
        else:
            self.statControl = afwMath.StatisticsControl()
            self.statControl.setNumSigmaClip(self.config.numSigmaClip)
            self.statControl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.maskPlanes))

    def run(self, ampImage, overscanImage, amp=None):
        """Measure and remove an overscan from an amplifier image.

        Both ``ampImage`` and ``overscanImage`` are modified in place.

        Parameters
        ----------
        ampImage : `lsst.afw.image.Image` or `lsst.afw.image.MaskedImage`
            Image data that will have the overscan removed.
        overscanImage : `lsst.afw.image.Image` or `lsst.afw.image.MaskedImage`
            Overscan data that the overscan is measured from.
        amp : `lsst.afw.cameraGeom.Amplifier`, optional
            Amplifier to use for debugging purposes.

        Returns
        -------
        overscanResults : `lsst.pipe.base.Struct`
            Result struct with components:

            ``imageFit``
                Value or fit subtracted from the amplifier image data
                (scalar or `lsst.afw.image.Image`).
            ``overscanFit``
                Value or fit subtracted from the overscan image data
                (scalar or `lsst.afw.image.Image`).
            ``overscanImage``
                Image of the overscan region with the overscan
                correction applied (`lsst.afw.image.Image`). This
                quantity is used to estimate the amplifier read noise
                empirically.
            ``edgeMask``
                Mask with ``SUSPECT`` set on amplifier rows/columns where
                the vector overscan solution was extrapolated
                (`lsst.afw.image.Mask` or `None`).  `None` for the
                constant fit types, or when ``ampImage`` has no mask.

        Raises
        ------
        RuntimeError
            Raised if an invalid overscan type is set.
        """
        if self.config.fitType in ('MEAN', 'MEANCLIP', 'MEDIAN'):
            overscanResult = self.measureConstantOverscan(overscanImage)
            overscanValue = overscanResult.overscanValue
            offImage = overscanValue
            overscanModel = overscanValue
            maskSuspect = None
        elif self.config.fitType in ('MEDIAN_PER_ROW', 'POLY', 'CHEB', 'LEG',
                                     'NATURAL_SPLINE', 'CUBIC_SPLINE', 'AKIMA_SPLINE'):
            overscanResult = self.measureVectorOverscan(overscanImage)
            overscanValue = overscanResult.overscanValue
            maskArray = overscanResult.maskArray
            isTransposed = overscanResult.isTransposed

            offImage = afwImage.ImageF(ampImage.getDimensions())
            offArray = offImage.getArray()
            overscanModel = afwImage.ImageF(overscanImage.getDimensions())
            overscanArray = overscanModel.getArray()

            # Only masked images can carry the SUSPECT edge flags.
            if hasattr(ampImage, 'getMask'):
                maskSuspect = afwImage.Mask(ampImage.getDimensions())
            else:
                maskSuspect = None

            # Broadcast the 1-D overscan vector along the axis it was
            # collapsed over.
            if isTransposed:
                offArray[:, :] = overscanValue[np.newaxis, :]
                overscanArray[:, :] = overscanValue[np.newaxis, :]
                if maskSuspect is not None:
                    maskSuspect.getArray()[:, maskArray] |= ampImage.getMask().getPlaneBitMask("SUSPECT")
            else:
                offArray[:, :] = overscanValue[:, np.newaxis]
                overscanArray[:, :] = overscanValue[:, np.newaxis]
                if maskSuspect is not None:
                    maskSuspect.getArray()[maskArray, :] |= ampImage.getMask().getPlaneBitMask("SUSPECT")
        else:
            raise RuntimeError('%s : %s an invalid overscan type' %
                               ("overscanCorrection", self.config.fitType))

        self.debugView(overscanImage, overscanValue, amp)

        ampImage -= offImage
        if maskSuspect is not None:
            ampImage.getMask().getArray()[:, :] |= maskSuspect.getArray()[:, :]
        overscanImage -= overscanModel
        return pipeBase.Struct(imageFit=offImage,
                               overscanFit=overscanModel,
                               overscanImage=overscanImage,
                               edgeMask=maskSuspect)

    @staticmethod
    def integerConvert(image):
        """Return an integer version of the input image.

        Parameters
        ----------
        image : `numpy.ndarray`, `lsst.afw.image.Image` or `MaskedImage`
            Image to convert to integers.

        Returns
        -------
        outI : `numpy.ndarray`, `lsst.afw.image.Image` or `MaskedImage`
            The integer converted image.  For a `MaskedImage` the mask and
            variance planes of the input are reused.

        Raises
        ------
        RuntimeError
            Raised if the input image could not be converted.
        """
        if hasattr(image, "image"):
            # Is a maskedImage:
            imageI = image.image.convertI()
            outI = afwImage.MaskedImageI(imageI, image.mask, image.variance)
        elif hasattr(image, "convertI"):
            # Is an Image:
            outI = image.convertI()
        elif hasattr(image, "astype"):
            # Is a numpy array:
            outI = image.astype(int)
        else:
            # Format the message explicitly; exception constructors do not
            # apply printf-style arguments.
            raise RuntimeError("Could not convert this to integers: %s %s %s" %
                               (image, type(image), dir(image)))
        return outI

    # Constant methods
    def measureConstantOverscan(self, image):
        """Measure a constant overscan value.

        Parameters
        ----------
        image : `lsst.afw.image.Image` or `lsst.afw.image.MaskedImage`
            Image data to measure the overscan from.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Overscan result with entries:
            - ``overscanValue``: Overscan value to subtract (`float`)
            - ``maskArray``: Placeholder for a mask array (`None`)
            - ``isTransposed``: Orientation of the overscan (`bool`)
        """
        # The median is computed on integer data unless the user has
        # disabled that via ``config.overscanIsInt``.
        if self.config.fitType == 'MEDIAN' and self.config.overscanIsInt:
            calcImage = self.integerConvert(image)
        else:
            calcImage = image

        fitType = afwMath.stringToStatisticsProperty(self.config.fitType)
        overscanValue = afwMath.makeStatistics(calcImage, fitType, self.statControl).getValue()

        return pipeBase.Struct(overscanValue=overscanValue,
                               maskArray=None,
                               isTransposed=False)

    # Vector correction utilities
    def getImageArray(self, image):
        """Extract the numpy array from the input image.

        Parameters
        ----------
        image : `lsst.afw.image.Image` or `lsst.afw.image.MaskedImage`
            Image data to pull array from.

        Returns
        -------
        calcImage : `numpy.ndarray` or `numpy.ma.masked_array`
            Image data array for numpy operating.  For a masked image,
            pixels with any bit of the statistics-control AND mask set are
            masked.
        """
        if hasattr(image, "getImage"):
            calcImage = image.getImage().getArray()
            calcImage = np.ma.masked_where(image.getMask().getArray() & self.statControl.getAndMask(),
                                           calcImage)
        else:
            calcImage = image.getArray()
        return calcImage

    @staticmethod
    def transpose(imageArray):
        """Transpose input numpy array if necessary.

        The array is transposed when its first axis is the shortest, so
        that the long axis of the overscan is always numpy axis 0.

        Parameters
        ----------
        imageArray : `numpy.ndarray`
            Image data to transpose.

        Returns
        -------
        imageArray : `numpy.ndarray`
            Transposed image data.
        isTransposed : `bool`
            Indicates whether the input data was transposed.
        """
        if np.argmin(imageArray.shape) == 0:
            return np.transpose(imageArray), True
        else:
            return imageArray, False

    def maskOutliers(self, imageArray):
        """Mask outliers in a row of overscan data from a robust sigma
        clipping procedure.

        Parameters
        ----------
        imageArray : `numpy.ndarray`
            Image to filter along numpy axis=1.

        Returns
        -------
        maskedArray : `numpy.ma.masked_array`
            Masked image marking outliers.  Any mask already present on
            ``imageArray`` is retained.
        """
        # NOTE(review): np.percentile does not honour a MaskedArray's mask,
        # so already-masked pixels still contribute to the row statistics.
        lq, median, uq = np.percentile(imageArray, [25.0, 50.0, 75.0], axis=1)
        axisMedians = median
        axisStdev = 0.74*(uq - lq)  # robust stdev (IQR / 1.349)

        diff = np.abs(imageArray - axisMedians[:, np.newaxis])
        return np.ma.masked_where(diff > self.statControl.getNumSigmaClip()
                                  * axisStdev[:, np.newaxis], imageArray)

    @staticmethod
    def collapseArray(maskedArray):
        """Collapse overscan array (and mask) to a 1-D vector of values.

        Parameters
        ----------
        maskedArray : `numpy.ma.masked_array`
            Masked array of input overscan data.

        Returns
        -------
        collapsed : `numpy.ma.masked_array`
            Single dimensional overscan data, combined with the mean.
            Rows that are entirely masked remain masked, but their data
            value is replaced by the unmasked mean of that row.
        """
        collapsed = np.mean(maskedArray, axis=1)
        if collapsed.mask.sum() > 0:
            collapsed.data[collapsed.mask] = np.mean(maskedArray.data[collapsed.mask], axis=1)
        return collapsed

    def collapseArrayMedian(self, maskedArray):
        """Collapse overscan array (and mask) to a 1-D vector of using the
        correct integer median of row-values.

        Parameters
        ----------
        maskedArray : `numpy.ma.masked_array`
            Masked array of input overscan data.

        Returns
        -------
        collapsed : `numpy.ndarray`
            Single dimensional overscan data, combined with the afwMath
            median.  Rows with no unmasked pixels are set to NaN.
        """
        integerMI = self.integerConvert(maskedArray)

        collapsed = []
        fitType = afwMath.stringToStatisticsProperty('MEDIAN')
        for row in integerMI:
            newRow = row.compressed()
            if len(newRow) > 0:
                rowMedian = afwMath.makeStatistics(newRow, fitType, self.statControl).getValue()
            else:
                rowMedian = np.nan
            collapsed.append(rowMedian)

        return np.array(collapsed)

    def splineFit(self, indices, collapsed, numBins):
        """Wrapper function to match spline fit API to polynomial fit API.

        Parameters
        ----------
        indices : `numpy.ndarray`
            Locations to evaluate the spline.
        collapsed : `numpy.ndarray` or `numpy.ma.masked_array`
            Collapsed overscan values corresponding to the spline
            evaluation points.  Masked entries are excluded from the bins.
        numBins : `int`
            Number of bins to use in constructing the spline.

        Returns
        -------
        interp : `lsst.afw.math.Interpolate` or `float`
            Interpolation object for later evaluation.  If fewer than five
            bins contain valid data, a scalar fallback value is returned
            instead: the first valid bin value, or 0.0 if there is none.
        """
        # Read the mask without modifying the caller's array.
        good = ~np.ma.getmaskarray(collapsed)
        data = np.ma.getdata(collapsed)

        numPerBin, binEdges = np.histogram(indices, bins=numBins,
                                           weights=good.astype(int))
        # Empty bins give 0/0; those bins are discarded below.
        with np.errstate(invalid="ignore"):
            values = np.histogram(indices, bins=numBins,
                                  weights=data*good)[0]/numPerBin
            binCenters = np.histogram(indices, bins=numBins,
                                      weights=indices*good)[0]/numPerBin

        valid = numPerBin > 0
        if len(binCenters[valid]) < 5:
            self.log.warning("Cannot do spline fitting for overscan: %s valid points.",
                             len(binCenters[valid]))
            # Return a scalar value if we have one, otherwise
            # return zero. This amplifier is hopefully already
            # masked.
            if len(values[valid]) != 0:
                return float(values[valid][0])
            else:
                return 0.0

        interp = afwMath.makeInterpolate(binCenters.astype(float)[valid],
                                         values.astype(float)[valid],
                                         afwMath.stringToInterpStyle(self.config.fitType))
        return interp

    @staticmethod
    def splineEval(indices, interp):
        """Wrapper function to match spline evaluation API to polynomial fit
        API.

        Parameters
        ----------
        indices : `numpy.ndarray`
            Locations to evaluate the spline.
        interp : `lsst.afw.math.interpolate`
            Interpolation object to use.

        Returns
        -------
        values : `numpy.ndarray`
            Evaluated spline values at each index.
        """

        return interp.interpolate(indices.astype(float))

    @staticmethod
    def maskExtrapolated(collapsed):
        """Create mask if edges are extrapolated.

        Flags the leading and trailing runs of masked entries in
        ``collapsed``, where the solution had no data and was extrapolated.
        Masked entries in the interior, between two unmasked entries, are
        not flagged.  If every entry is masked, every entry is flagged.

        Parameters
        ----------
        collapsed : `numpy.ndarray` or `numpy.ma.masked_array`
            1-D array to check the edges of.  A plain ndarray is treated as
            having no masked entries.

        Returns
        -------
        maskArray : `numpy.ndarray`
            Boolean numpy array of pixels to mask.
        """
        mask = np.ma.getmaskarray(collapsed)
        maskArray = np.zeros(len(mask), dtype=bool)
        if not mask.any():
            return maskArray

        goodIndices = np.flatnonzero(~mask)
        if len(goodIndices) == 0:
            maskArray[:] = True
            return maskArray

        # Everything before the first and after the last good entry.
        maskArray[:goodIndices[0]] = True
        maskArray[goodIndices[-1] + 1:] = True
        return maskArray

    def measureVectorOverscan(self, image):
        """Calculate the 1-d vector overscan from the input overscan image.

        Parameters
        ----------
        image : `lsst.afw.image.MaskedImage`
            Image containing the overscan data.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Overscan result with entries:
            - ``overscanValue`` : `numpy.ndarray`
                1-D overscan vector to subtract.
            - ``maskArray`` : `numpy.ndarray` [ `bool` ]
                Rows that should be masked as ``SUSPECT`` when the
                overscan solution is applied.
            - ``isTransposed`` : `bool`
                Indicates if the overscan data was transposed during
                calculation, noting along which axis the overscan should be
                subtracted.
        """
        calcImage = self.getImageArray(image)

        # operate on numpy-arrays from here
        calcImage, isTransposed = self.transpose(calcImage)
        masked = self.maskOutliers(calcImage)

        startTime = time.perf_counter()

        if self.config.fitType == 'MEDIAN_PER_ROW':
            # NOTE(review): this path always works on integer data
            # (MaskedImageI), independent of ``config.overscanIsInt``.
            mi = afwImage.MaskedImageI(image.getBBox())
            masked = masked.astype(int)
            if isTransposed:
                masked = masked.transpose()

            mi.image.array[:, :] = masked.data[:, :]
            # A mask with empty shape is ``nomask``: nothing to copy.
            if bool(masked.mask.shape):
                # Outlier flags are written as mask value 1; presumably that
                # bit is one of config.maskPlanes so fitOverscanImage
                # rejects them — TODO confirm.
                mi.mask.array[:, :] = masked.mask[:, :]

            overscanVector = fitOverscanImage(mi, self.config.maskPlanes, isTransposed)
            maskArray = self.maskExtrapolated(overscanVector)
        else:
            collapsed = self.collapseArray(masked)

            # Positions along the overscan, rescaled onto [-1, 1).
            num = len(collapsed)
            indices = 2.0*np.arange(num)/float(num) - 1.0

            poly = np.polynomial
            fitter, evaler = {
                'POLY': (poly.polynomial.polyfit, poly.polynomial.polyval),
                'CHEB': (poly.chebyshev.chebfit, poly.chebyshev.chebval),
                'LEG': (poly.legendre.legfit, poly.legendre.legval),
                'NATURAL_SPLINE': (self.splineFit, self.splineEval),
                'CUBIC_SPLINE': (self.splineFit, self.splineEval),
                'AKIMA_SPLINE': (self.splineFit, self.splineEval)
            }[self.config.fitType]

            # These are the polynomial coefficients, or an
            # interpolation object.
            coeffs = fitter(indices, collapsed, self.config.order)

            if isinstance(coeffs, float):
                # splineFit signals failure by returning a scalar fallback.
                self.log.warning("Using fallback value %f due to fitter failure. Amplifier will be masked.",
                                 coeffs)
                overscanVector = np.full_like(indices, coeffs)
                maskArray = np.ones(len(collapsed), dtype=bool)
            else:
                # Otherwise we can just use things as normal.
                overscanVector = evaler(indices, coeffs)
                maskArray = self.maskExtrapolated(collapsed)
        endTime = time.perf_counter()
        self.log.info("Overscan measurement took %ss for %s", endTime - startTime, self.config.fitType)
        return pipeBase.Struct(overscanValue=np.array(overscanVector),
                               maskArray=maskArray,
                               isTransposed=isTransposed)

    def debugView(self, image, model, amp=None):
        """Debug display for the final overscan solution.

        Only active when ``lsstDebug.Info(__name__).display`` is set and
        debugging has not been disabled from the interactive prompt.

        Parameters
        ----------
        image : `lsst.afw.image.Image`
            Input image the overscan solution was determined from.
        model : `numpy.ndarray` or `float`
            Overscan model determined for the image.
        amp : `lsst.afw.cameraGeom.Amplifier`, optional
            Amplifier to extract diagnostic information.
        """
        if not self.allowDebug:
            return
        import lsstDebug
        if not lsstDebug.Info(__name__).display:
            return

        calcImage = self.getImageArray(image)
        calcImage, isTransposed = self.transpose(calcImage)
        masked = self.maskOutliers(calcImage)
        collapsed = self.collapseArray(masked)

        num = len(collapsed)
        indices = 2.0 * np.arange(num)/float(num) - 1.0

        if np.ma.is_masked(collapsed):
            collapsedMask = collapsed.mask
        else:
            collapsedMask = np.array(num*[np.ma.nomask])

        import matplotlib.pyplot as plot
        figure = plot.figure(1)
        figure.clear()
        axes = figure.add_axes((0.1, 0.1, 0.8, 0.8))
        # Unmasked points in black, fully-masked rows in blue, model in red.
        axes.plot(indices[~collapsedMask], collapsed[~collapsedMask], 'k+')
        if collapsedMask.sum() > 0:
            axes.plot(indices[collapsedMask], collapsed.data[collapsedMask], 'b+')
        if isinstance(model, np.ndarray):
            plotModel = model
        else:
            plotModel = np.zeros_like(indices)
            plotModel += model
        axes.plot(indices, plotModel, 'r-')
        plot.xlabel("centered/scaled position along overscan region")
        plot.ylabel("pixel value/fit value")
        if amp:
            # Both ends of the data range come from the raw data bbox.
            plot.title(f"{amp.getName()} DataX: "
                       f"[{amp.getRawDataBBox().getBeginX()}:{amp.getRawDataBBox().getEndX()}] "
                       f"OscanX: [{amp.getRawHorizontalOverscanBBox().getBeginX()}:"
                       f"{amp.getRawHorizontalOverscanBBox().getEndX()}] {self.config.fitType}")
        else:
            plot.title("No amp supplied.")
        figure.show()
        prompt = "Press Enter or c to continue [chp]..."
        while True:
            ans = input(prompt).lower()
            if ans in ("", " ", "c",):
                break
            elif ans in ("p", ):
                import pdb
                pdb.set_trace()
            elif ans in ('x', ):
                self.allowDebug = False
                break
            elif ans in ("h", ):
                print("[h]elp [c]ontinue [p]db e[x]itDebug")
        plot.close()