Coverage for python / lsst / ip / isr / brighterFatterKernel.py: 7%

343 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 09:00 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""Brighter Fatter Kernel calibration definition.""" 

23 

24 

# Public API of this module: the calibration class and the two
# correction functions that apply a brighter-fatter kernel to an exposure.
__all__ = [
    "BrighterFatterKernel",
    "brighterFatterCorrection",
    "fluxConservingBrighterFatterCorrection",
]

29 

30 

31import scipy 

32import numpy as np 

33from astropy.table import Table 

34 

35import lsst.afw.math as afwMath 

36import lsst.afw.image as afwImage 

37 

38from . import IsrCalib 

39from .isrFunctions import gainContext 

40 

41 

class BrighterFatterKernel(IsrCalib):
    """Calibration of brighter-fatter kernels for an instrument.

    ampKernels are the kernels for each amplifier in a detector, as
    generated by having ``level == 'AMP'``.

    detKernels holds the kernels generated for detectors as a whole, as
    generated by having ``level == 'DETECTOR'``.

    makeDetectorKernelFromAmpwiseKernels is a method to generate the
    kernel for a detector, constructed by averaging together the
    ampwise kernels in the detector. The existing application code is
    only defined for kernels with ``level == 'DETECTOR'``, so this method
    is used if the supplied kernel was built with ``level == 'AMP'``.

    Parameters
    ----------
    camera : `lsst.afw.cameraGeom.Camera`, optional
        Camera describing detector geometry.
    level : `str`, optional
        Level the kernels will be generated for (``'AMP'`` or
        ``'DETECTOR'``).
    log : `logging.Logger`, optional
        Log to write messages to.
    **kwargs :
        Parameters to pass to parent constructor.

    Notes
    -----
    Version 1.1 adds the `expIdMask` property, and substitutes
    `means` and `variances` for `rawMeans` and `rawVariances`
    from the PTC dataset.

    expIdMask : `dict`, [`str`,`numpy.ndarray`]
        Dictionary keyed by amp names containing the mask produced after
        outlier rejection.
    rawMeans : `dict`, [`str`, `numpy.ndarray`]
        Dictionary keyed by amp names containing the unmasked average of the
        means of the exposures in each flat pair.
    rawVariances : `dict`, [`str`, `numpy.ndarray`]
        Dictionary keyed by amp names containing the variance of the
        difference image of the exposures in each flat pair.
        Corresponds to rawVars of PTC.
    rawXcorrs : `dict`, [`str`, `numpy.ndarray`]
        Dictionary keyed by amp names containing an array of measured
        covariances per mean flux.
        Corresponds to covariances of PTC.
    badAmps : `list`
        List of bad amplifiers names.
    shape : `tuple`
        Tuple of the shape of the BFK kernels.
    gain : `dict`, [`str`,`float`]
        Dictionary keyed by amp names containing the fitted gains.
    noise : `dict`, [`str`,`float`]
        Dictionary keyed by amp names containing the fitted noise.
    meanXcorrs : `dict`, [`str`,`numpy.ndarray`]
        Dictionary keyed by amp names containing the averaged
        cross-correlations.
    valid : `dict`, [`str`,`bool`]
        Dictionary keyed by amp names containing validity of data.
    ampKernels : `dict`, [`str`, `numpy.ndarray`]
        Dictionary keyed by amp names containing the BF kernels.
    detKernels : `dict`
        Dictionary keyed by detector names containing the BF kernels.
    """
    _OBSTYPE = 'bfk'
    _SCHEMA = 'Brighter-fatter kernel'
    _VERSION = 1.1

    def __init__(self, camera=None, level=None, **kwargs):
        self.level = level

        # Things inherited from the PTC
        self.expIdMask = dict()
        self.rawMeans = dict()
        self.rawVariances = dict()
        self.rawXcorrs = dict()
        self.badAmps = list()
        self.shape = (17, 17)
        self.gain = dict()
        self.noise = dict()

        # Things calculated from the PTC
        self.meanXcorrs = dict()
        self.valid = dict()

        # Things that are used downstream
        self.ampKernels = dict()
        self.detKernels = dict()

        super().__init__(**kwargs)

        if camera:
            self.initFromCamera(camera, detectorId=kwargs.get('detectorId', None))

        self.requiredAttributes.update(['level', 'expIdMask', 'rawMeans', 'rawVariances', 'rawXcorrs',
                                        'badAmps', 'gain', 'noise', 'meanXcorrs', 'valid',
                                        'ampKernels', 'detKernels'])

    def updateMetadata(self, setDate=False, **kwargs):
        """Update calibration metadata.

        This calls the base class's method after ensuring the required
        calibration keywords will be saved.

        Parameters
        ----------
        setDate : `bool`, optional
            Update the CALIBDATE fields in the metadata to the current
            time. Defaults to False.
        kwargs :
            Other keyword parameters to set in the metadata.
        """
        kwargs['LEVEL'] = self.level
        # ``fromDict`` reads these back to rebuild ``self.shape``.
        kwargs['KERNEL_DX'] = self.shape[0]
        kwargs['KERNEL_DY'] = self.shape[1]

        super().updateMetadata(setDate=setDate, **kwargs)

    def initFromCamera(self, camera, detectorId=None):
        """Initialize kernel structure from camera.

        Parameters
        ----------
        camera : `lsst.afw.cameraGeom.Camera`
            Camera to use to define geometry.
        detectorId : `int`, optional
            Index of the detector to generate.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            The initialized calibration.

        Raises
        ------
        RuntimeError
            Raised if no detectorId is supplied for a calibration with
            ``level='AMP'``.
        """
        self._instrument = camera.getName()

        if detectorId is not None:
            detector = camera[detectorId]
            self._detectorId = detectorId
            self._detectorName = detector.getName()
            self._detectorSerial = detector.getSerial()

        if self.level == 'AMP':
            if detectorId is None:
                raise RuntimeError("A detectorId must be supplied if level='AMP'.")

            self.badAmps = []

            # Empty placeholders per amplifier; gain and noise are seeded
            # from the camera's nominal amplifier values.
            for amp in detector:
                ampName = amp.getName()
                self.expIdMask[ampName] = []
                self.rawMeans[ampName] = []
                self.rawVariances[ampName] = []
                self.rawXcorrs[ampName] = []
                self.gain[ampName] = amp.getGain()
                self.noise[ampName] = amp.getReadNoise()
                self.meanXcorrs[ampName] = []
                self.ampKernels[ampName] = []
                self.valid[ampName] = []
        elif self.level == 'DETECTOR':
            if detectorId is None:
                for det in camera:
                    detName = det.getName()
                    self.detKernels[detName] = []
            else:
                self.detKernels[self._detectorName] = []

        return self

    def getLengths(self):
        """Return the set of lengths needed for reshaping components.

        Returns
        -------
        kernelLength : `int`
            Product of the elements of self.shape.
        smallLength : `int`
            Size of an untiled covariance.
        nObs : `int`
            Number of observation pairs used in the kernel. This is 0 for
            non-AMP level calibrations, and for AMP level calibrations
            that do not yet contain any amplifiers.

        Raises
        ------
        RuntimeError
            Raised if amplifiers have differing numbers of observations.
        """
        kernelLength = self.shape[0] * self.shape[1]
        # The kernel is the tiled (mirrored) version of a quarter-sized
        # covariance, hence the (n - 1)/2 per side.
        smallLength = int((self.shape[0] - 1)*(self.shape[1] - 1)/4)
        if self.level == 'AMP':
            nObservations = {len(self.rawMeans[amp]) for amp in self.rawMeans}
            if len(nObservations) > 1:
                raise RuntimeError("Inconsistent number of observations found.")
            nObs = nObservations.pop() if nObservations else 0
        else:
            nObs = 0

        return (kernelLength, smallLength, nObs)

    @classmethod
    def fromDict(cls, dictionary):
        """Construct a calibration from a dictionary of properties.

        Parameters
        ----------
        dictionary : `dict`
            Dictionary of properties.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            Constructed calibration.

        Raises
        ------
        RuntimeError
            Raised if the supplied dictionary is for a different
            calibration.
            Raised if the version of the supplied dictionary is neither
            1.0 nor 1.1.
        """
        calib = cls()

        if calib._OBSTYPE != (found := dictionary['metadata']['OBSTYPE']):
            raise RuntimeError(f"Incorrect brighter-fatter kernel supplied. Expected {calib._OBSTYPE}, "
                               f"found {found}")
        calib.setMetadata(dictionary['metadata'])
        calib.calibInfoFromDict(dictionary)

        calib.level = dictionary['metadata'].get('LEVEL', 'AMP')
        calib.shape = (dictionary['metadata'].get('KERNEL_DX', 0),
                       dictionary['metadata'].get('KERNEL_DY', 0))

        calibVersion = dictionary['metadata']['bfk_VERSION']
        if calibVersion == 1.0:
            calib.log.debug("Old Version of brighter-fatter kernel found. Current version: "
                            f"{calib._VERSION}. The new attribute 'expIdMask' will be "
                            "populated with 'True' values, and the new attributes 'rawMeans' "
                            "and 'rawVariances' will be populated with the masked 'means' "
                            "and 'variances' values."
                            )
            # use 'means', because 'expIdMask' does not exist.
            calib.expIdMask = {amp: np.repeat(True, len(dictionary['means'][amp])) for amp in
                               dictionary['means']}
            calib.rawMeans = {amp: np.array(dictionary['means'][amp]) for amp in dictionary['means']}
            calib.rawVariances = {amp: np.array(dictionary['variances'][amp]) for amp in
                                  dictionary['variances']}
        elif calibVersion == 1.1:
            calib.expIdMask = {amp: np.array(dictionary['expIdMask'][amp]) for amp in dictionary['expIdMask']}
            calib.rawMeans = {amp: np.array(dictionary['rawMeans'][amp]) for amp in dictionary['rawMeans']}
            calib.rawVariances = {amp: np.array(dictionary['rawVariances'][amp]) for amp in
                                  dictionary['rawVariances']}
        else:
            raise RuntimeError(f"Unknown version for brighter-fatter kernel: {calibVersion}")

        # Lengths for reshape:
        _, smallLength, nObs = calib.getLengths()
        smallShapeSide = int(np.sqrt(smallLength))

        calib.rawXcorrs = {amp: np.array(dictionary['rawXcorrs'][amp]).reshape((nObs,
                                                                              smallShapeSide,
                                                                              smallShapeSide))
                           for amp in dictionary['rawXcorrs']}

        calib.gain = dictionary['gain']
        calib.noise = dictionary['noise']

        # Iterate each mapping over its own keys, so a missing entry in one
        # table cannot raise a KeyError in another.
        calib.meanXcorrs = {amp: np.array(dictionary['meanXcorrs'][amp]).reshape(calib.shape)
                            for amp in dictionary['meanXcorrs']}
        calib.ampKernels = {amp: np.array(dictionary['ampKernels'][amp]).reshape(calib.shape)
                            for amp in dictionary['ampKernels']}
        calib.valid = {amp: bool(value) for amp, value in dictionary['valid'].items()}

        # Amplifiers flagged invalid are bad. Test truthiness rather than
        # ``is False`` so numpy booleans are handled. Also keep any
        # explicitly listed bad amplifiers so ``toDict`` output round-trips.
        calib.badAmps = [amp for amp, valid in calib.valid.items() if not valid]
        for amp in dictionary.get('badAmps', []):
            if amp not in calib.badAmps:
                calib.badAmps.append(amp)

        calib.detKernels = {det: np.array(dictionary['detKernels'][det]).reshape(calib.shape)
                            for det in dictionary['detKernels']}

        calib.updateMetadata()
        return calib

    def toDict(self):
        """Return a dictionary containing the calibration properties.

        The dictionary should be able to be round-tripped through
        `fromDict`.

        Returns
        -------
        dictionary : `dict`
            Dictionary of properties.

        Raises
        ------
        ValueError
            Raised if the stored correlations for an amplifier cannot be
            coerced into the expected number of observations.
        """
        self.updateMetadata()

        outDict = {}
        metadata = self.getMetadata()
        outDict['metadata'] = metadata

        # Lengths for ravel:
        kernelLength, smallLength, nObs = self.getLengths()

        outDict['expIdMask'] = {amp: np.array(self.expIdMask[amp]).tolist() for amp in self.expIdMask}
        outDict['rawMeans'] = {amp: np.array(self.rawMeans[amp]).tolist() for amp in self.rawMeans}
        outDict['rawVariances'] = {amp: np.array(self.rawVariances[amp]).tolist() for amp in
                                   self.rawVariances}

        # Repack masked correlations (in place) before flattening.
        for amp in self.rawXcorrs:
            self._checkCorrelationShape(amp, nObs)

        outDict['rawXcorrs'] = {amp: np.array(self.rawXcorrs[amp]).reshape(nObs*smallLength).tolist()
                                for amp in self.rawXcorrs}
        outDict['badAmps'] = self.badAmps
        outDict['gain'] = self.gain
        outDict['noise'] = self.noise

        outDict['meanXcorrs'] = {amp: np.asarray(self.meanXcorrs[amp]).reshape(kernelLength).tolist()
                                 for amp in self.meanXcorrs}
        outDict['ampKernels'] = {amp: np.asarray(self.ampKernels[amp]).reshape(kernelLength).tolist()
                                 for amp in self.ampKernels}
        outDict['valid'] = self.valid

        outDict['detKernels'] = {det: np.asarray(self.detKernels[det]).reshape(kernelLength).tolist()
                                 for det in self.detKernels}
        return outDict

    @classmethod
    def fromTable(cls, tableList):
        """Construct calibration from a list of tables.

        This method uses the `fromDict` method to create the
        calibration, after constructing an appropriate dictionary from
        the input tables.

        Parameters
        ----------
        tableList : `list` [`astropy.table.Table`]
            List of tables to use to construct the brighter-fatter
            calibration. The first holds per-amplifier data; the optional
            second holds per-detector kernels.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            The calibration defined in the tables.

        Raises
        ------
        RuntimeError
            Raised if the table version is neither 1.0 nor 1.1.
        """
        ampTable = tableList[0]

        metadata = ampTable.meta
        inDict = dict()
        inDict['metadata'] = metadata

        amps = ampTable['AMPLIFIER']

        # Determine version for expected values. The ``fromDict``
        # method can unpack either, but the appropriate fields need to
        # be supplied.
        calibVersion = metadata['bfk_VERSION']

        if calibVersion == 1.0:
            # We expect to find ``means`` and ``variances`` for this
            # case, and will construct an ``expIdMask`` from these
            # parameters in the ``fromDict`` method.
            inDict['means'] = dict(zip(amps, ampTable['MEANS']))
            inDict['variances'] = dict(zip(amps, ampTable['VARIANCES']))
        elif calibVersion == 1.1:
            # This will have ``rawMeans`` and ``rawVariances``, which
            # are filtered via the ``expIdMask`` fields.
            inDict['expIdMask'] = dict(zip(amps, ampTable['EXP_ID_MASK']))
            inDict['rawMeans'] = dict(zip(amps, ampTable['RAW_MEANS']))
            inDict['rawVariances'] = dict(zip(amps, ampTable['RAW_VARIANCES']))
        else:
            raise RuntimeError(f"Unknown version for brighter-fatter kernel: {calibVersion}")

        inDict['rawXcorrs'] = dict(zip(amps, ampTable['RAW_XCORRS']))
        inDict['gain'] = dict(zip(amps, ampTable['GAIN']))
        inDict['noise'] = dict(zip(amps, ampTable['NOISE']))
        inDict['meanXcorrs'] = dict(zip(amps, ampTable['MEAN_XCORRS']))
        inDict['ampKernels'] = dict(zip(amps, ampTable['KERNEL']))
        inDict['valid'] = {amp: bool(valid) for amp, valid in zip(amps, ampTable['VALID'])}

        inDict['badAmps'] = [amp for amp, valid in inDict['valid'].items() if not valid]

        if len(tableList) > 1:
            detTable = tableList[1]
            inDict['detKernels'] = dict(zip(detTable['DETECTOR'], detTable['KERNEL']))
        else:
            inDict['detKernels'] = {}

        return cls.fromDict(inDict)

    def toTable(self):
        """Construct a list of tables containing the information in this
        calibration.

        The list of tables should create an identical calibration
        after being passed to this class's fromTable method.

        Returns
        -------
        tableList : `list` [`astropy.table.Table`]
            List of tables containing the brighter-fatter calibration
            information: the per-amplifier table, followed by a
            per-detector kernel table if any detector kernels exist.

        Raises
        ------
        ValueError
            Raised if the stored correlations for an amplifier cannot be
            coerced into the expected number of observations.
        """
        tableList = []
        self.updateMetadata()

        # Lengths
        kernelLength, smallLength, nObs = self.getLengths()

        ampList = []
        expIdMaskList = []
        rawMeanList = []
        rawVarianceList = []
        rawXcorrs = []
        gainList = []
        noiseList = []

        meanXcorrsList = []
        kernelList = []
        validList = []

        if self.level == 'AMP':
            for amp in self.rawMeans.keys():
                ampList.append(amp)
                expIdMaskList.append(self.expIdMask[amp])
                rawMeanList.append(self.rawMeans[amp])
                rawVarianceList.append(self.rawVariances[amp])

                # Repack masked correlations (in place) before flattening.
                self._checkCorrelationShape(amp, nObs)

                rawXcorrs.append(np.array(self.rawXcorrs[amp]).reshape(nObs*smallLength).tolist())
                gainList.append(self.gain[amp])
                noiseList.append(self.noise[amp])

                meanXcorrsList.append(np.asarray(self.meanXcorrs[amp]).reshape(kernelLength).tolist())
                kernelList.append(np.asarray(self.ampKernels[amp]).reshape(kernelLength).tolist())
                # Bad amplifiers are folded into the VALID column.
                validList.append(int(self.valid[amp] and not (amp in self.badAmps)))

        ampTable = Table({'AMPLIFIER': ampList,
                          'EXP_ID_MASK': expIdMaskList,
                          'RAW_MEANS': rawMeanList,
                          'RAW_VARIANCES': rawVarianceList,
                          'RAW_XCORRS': rawXcorrs,
                          'GAIN': gainList,
                          'NOISE': noiseList,
                          'MEAN_XCORRS': meanXcorrsList,
                          'KERNEL': kernelList,
                          'VALID': validList,
                          })

        ampTable.meta = self.getMetadata().toDict()
        tableList.append(ampTable)

        if len(self.detKernels):
            detList = []
            kernelList = []
            for det in self.detKernels.keys():
                detList.append(det)
                kernelList.append(np.asarray(self.detKernels[det]).reshape(kernelLength).tolist())

            detTable = Table({'DETECTOR': detList,
                              'KERNEL': kernelList})
            detTable.meta = self.getMetadata().toDict()
            tableList.append(detTable)

        return tableList

    def repackCorrelations(self, amp, correlationShape):
        """If the correlations were masked, they need to be repacked into the
        correct shape.

        Each masked-out exposure gets a NaN-filled correlation so that the
        result has one entry per element of ``self.expIdMask[amp]``.

        Parameters
        ----------
        amp : `str`
            Amplifier needing repacked.
        correlationShape : `tuple` [`int`], (3, )
            Current shape of the stored (masked) correlations; the last
            two elements give the size of each correlation.
        """
        repackedCorrelations = []
        idx = 0
        for maskValue in self.expIdMask[amp]:
            if maskValue:
                repackedCorrelations.append(self.rawXcorrs[amp][idx])
                idx += 1
            else:
                repackedCorrelations.append(np.full((correlationShape[1], correlationShape[2]), np.nan))
        self.rawXcorrs[amp] = repackedCorrelations

    def _checkCorrelationShape(self, amp, nObs):
        """Ensure ``self.rawXcorrs[amp]`` has one entry per observation,
        repacking masked correlations in place if needed.

        Parameters
        ----------
        amp : `str`
            Amplifier to check.
        nObs : `int`
            Expected number of observations.

        Raises
        ------
        ValueError
            Raised if the number of stored correlations matches neither
            ``nObs`` nor the number of unmasked exposures.
        """
        correlationShape = np.array(self.rawXcorrs[amp]).shape
        if nObs == correlationShape[0]:
            return
        nUnmasked = np.sum(self.expIdMask[amp])
        if correlationShape[0] == nUnmasked:
            # Only unmasked exposures were stored; pad masked ones with NaN.
            self.repackCorrelations(amp, correlationShape)
        else:
            raise ValueError(f"Could not coerce rawXcorrs into appropriate shape "
                             f"(have {correlationShape[0]} correlations, but expect to see "
                             f"{nObs}, or {nUnmasked} if only unmasked exposures are stored).")

    # Implementation methods
    def makeDetectorKernelFromAmpwiseKernels(self, detectorName, ampsToExclude=None):
        """Average the amplifier level kernels to create a detector level
        kernel. There is no change in index ordering/orientation from
        this averaging.

        Each kernel element is the 5-sigma clipped mean of that element
        over all valid, non-excluded amplifiers.

        Parameters
        ----------
        detectorName : `str`
            Detector for which the averaged kernel will be used.
        ampsToExclude : `list` [`str`], optional
            Amps that should not be included in the average.

        Raises
        ------
        RuntimeError
            Raised if no valid amplifier kernels remain to be averaged.
        """
        if ampsToExclude is None:
            ampsToExclude = []
        inKernels = np.array([self.ampKernels[amp] for amp in self.ampKernels
                              if self.valid[amp] and amp not in ampsToExclude])
        if len(inKernels) == 0:
            raise RuntimeError("No valid amplifier kernels available to build the kernel for "
                               f"detector {detectorName}.")
        avgKernel = np.zeros_like(inKernels[0])
        sctrl = afwMath.StatisticsControl()
        sctrl.setNumSigmaClip(5.0)
        for i, j in np.ndindex(avgKernel.shape):
            avgKernel[i, j] = afwMath.makeStatistics(inKernels[:, i, j],
                                                     afwMath.MEANCLIP, sctrl).getValue()

        self.detKernels[detectorName] = avgKernel

    def replaceDetectorKernelWithAmpKernel(self, ampName, detectorName):
        """Use a single amplifier's kernel as the kernel for a detector.

        Parameters
        ----------
        ampName : `str`
            Amplifier whose kernel should be used.
        detectorName : `str`
            Detector whose kernel is replaced.
        """
        self.detKernels[detectorName] = self.ampKernels[ampName]

587 

588 

def brighterFatterCorrection(exposure, kernel, maxIter, threshold, applyGain, gains=None):
    """Apply brighter fatter correction in place for the image.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure to have brighter-fatter correction applied. Modified
        by this method.
    kernel : `numpy.ndarray`
        Brighter-fatter kernel to apply, indexed (y, x) like the image
        array.
    maxIter : scalar
        Number of correction iterations to run.
    threshold : scalar
        Convergence threshold in terms of the sum of absolute
        deviations between an iteration and the previous one.
    applyGain : `Bool`
        If True, then the exposure values are scaled by the gain prior
        to correction.
    gains : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the gains to use.
        If gains is None, the nominal gains in the amplifier object are used.

    Returns
    -------
    diff : `float`
        Final difference between iterations achieved in correction.
        This is ``numpy.inf`` if fewer than two iterations were run, as
        no difference could be measured.
    iteration : `int`
        Index of the last iteration run (0 if ``maxIter`` < 1).

    Notes
    -----
    This correction takes a kernel that has been derived from flat
    field images to redistribute the charge. The gradient of the
    kernel is the deflection field due to the accumulated charge.

    Given the original image I(x) and the kernel K(x) we can compute
    the corrected image Ic(x) using the following equation:

    Ic(x) = I(x) + 0.5*d/dx(I(x)*d/dx(int( dy*K(x-y)*I(y))))

    To evaluate the derivative term we expand it as follows:

    0.5 * ( d/dx(I(x))*d/dx(int(dy*K(x-y)*I(y)))
        + I(x)*d^2/dx^2(int(dy* K(x-y)*I(y))) )

    Because we use the measured counts instead of the incident counts
    we apply the correction iteratively to reconstruct the original
    counts and the correction. We stop iterating when the summed
    difference between the current corrected image and the one from
    the previous iteration is below the threshold. We do not require
    convergence because the number of iterations is too large a
    computational cost. How we define the threshold still needs to be
    evaluated, the current default was shown to work reasonably well
    on a small set of images. For more information on the method see
    DocuShare Document-19407.

    The edges as defined by the kernel are not corrected because they
    have spurious values due to the convolution.
    """
    image = exposure.getMaskedImage().getImage()

    # The image needs to be units of electrons/holes
    with gainContext(exposure, image, applyGain, gains):

        # Array axes are (y, x); naming the kernel dimensions the same way
        # keeps the border trimming correct for non-square kernels.
        kernelArray = np.asarray(kernel, dtype=np.float64)
        kLy, kLx = kernelArray.shape
        tempImage = afwImage.ImageD(image, deep=True)

        nanIndex = np.isnan(tempImage.getArray())
        tempImage.getArray()[nanIndex] = 0.

        corr = np.zeros(image.array.shape, dtype=np.float64)
        prevImage = np.zeros(image.array.shape, dtype=np.float64)

        # Define boundary by convolution region. The region that the
        # correction will be calculated for is one fewer in each dimension
        # because of the second derivative terms.
        # NOTE: these need to use integer math, as we're using start:end as
        # numpy index ranges.
        startX = kLx//2
        endX = -kLx//2
        startY = kLy//2
        endY = -kLy//2

        # Returned if fewer than two iterations run (no difference measured).
        diff = np.inf
        iteration = 0

        for iteration in range(maxIter):

            outArray = scipy.signal.convolve(
                tempImage.array,
                kernelArray,
                mode="same",
                method="fft",
            )
            tmpArray = tempImage.getArray()

            with np.errstate(invalid="ignore", over="ignore"):
                # First derivative term
                gradTmp = np.gradient(tmpArray[startY:endY, startX:endX])
                gradOut = np.gradient(outArray[startY:endY, startX:endX])
                first = (gradTmp[0]*gradOut[0] + gradTmp[1]*gradOut[1])[1:-1, 1:-1]

                # Second derivative term
                diffOut20 = np.diff(outArray, 2, 0)[startY:endY, startX + 1:endX - 1]
                diffOut21 = np.diff(outArray, 2, 1)[startY + 1:endY - 1, startX:endX]
                second = tmpArray[startY + 1:endY - 1, startX + 1:endX - 1]*(diffOut20 + diffOut21)

                corr[startY + 1:endY - 1, startX + 1:endX - 1] = 0.5*(first + second)

                # Next iteration starts from the original image plus the
                # current correction estimate.
                tmpArray[:, :] = image.getArray()[:, :]
                tmpArray[nanIndex] = 0.
                tmpArray[startY:endY, startX:endX] += corr[startY:endY, startX:endX]

            if iteration > 0:
                diff = np.sum(np.abs(prevImage - tmpArray), dtype=np.float64)

                if diff < threshold:
                    break
            # Store every iteration's image so the next difference is
            # measured against the previous iteration (not against zeros).
            prevImage[:, :] = tmpArray[:, :]

        image.getArray()[startY + 1:endY - 1, startX + 1:endX - 1] += \
            corr[startY + 1:endY - 1, startX + 1:endX - 1]

    return diff, iteration

713 

714 

def transferFlux(cFunc, fStep, correctionMode=True):
    """Compute the flux-conserving correction implied by a deflection
    potential.

    For every pixel edge along both axes, an amount of flux proportional to
    the potential difference across that edge is moved from one side of
    the edge to the other.

    Parameters
    ----------
    cFunc : `np.array`
        Deflection potential, being the convolution of the flux F with the
        kernel K. Must be two-dimensional.
    fStep : `np.array`
        The array of flux values which act as the source of the flux
        transfer; must have the same shape as ``cFunc``.
    correctionMode : `bool`
        Defines if applying correction (True) or generating sims (False).

    Returns
    -------
    corr : `np.ndarray`
        BFE correction array (float64, same shape as ``cFunc``). Every
        amount removed from one pixel is added to its neighbour, so for
        finite inputs the array sums to zero up to rounding.

    Raises
    ------
    RuntimeError
        Raised if ``cFunc`` and ``fStep`` have different shapes.
    """
    if cFunc.shape != fStep.shape:
        raise RuntimeError(f'transferFlux: array shapes do not match: {cFunc.shape}, {fStep.shape}')

    # Negative sign when applying the correction, positive when simulating
    # the effect; the 0.5 gives the time-averaged solution.
    factor = -0.5 if correctionMode else 0.5

    corr = np.zeros(cFunc.shape, dtype=np.float64)
    # Unpacking the shape also rejects anything that is not 2-D.
    nRows, nCols = cFunc.shape

    for axis in (0, 1):
        # Offset from a pixel to its neighbour across the upper (axis 0)
        # or right (axis 1) edge.
        dRow, dCol = (1, 0) if axis == 0 else (0, 1)

        # Potential difference across each pixel's upper/right edge; the
        # last row/column has no neighbour and keeps a zero gradient.
        edgeGrad = np.zeros(cFunc.shape, dtype=np.float64)
        edgeGrad[:nRows - dRow, :nCols - dCol] += np.diff(cFunc, axis=axis)

        # Positive edges draw on the pixel's own flux, negative edges on
        # the neighbour's flux.
        for positive in (True, False):
            rows, cols = np.nonzero(edgeGrad > 0 if positive else edgeGrad < 0)
            nbrRows, nbrCols = rows + dRow, cols + dCol
            donor = fStep[rows, cols] if positive else fStep[nbrRows, nbrCols]
            flux = factor * donor * edgeGrad[rows, cols]
            # Debit one side of the edge and credit the other so the total
            # flux is conserved.
            corr[rows, cols] -= flux
            corr[nbrRows, nbrCols] += flux

    return corr

797 

798 

def fluxConservingBrighterFatterCorrection(exposure, kernel, maxIter, threshold, applyGain,
                                           gains=None, correctionMode=True):
    """Apply brighter fatter correction in place for the image.

    This version presents a modified version of the algorithm
    found in ``lsst.ip.isr.isrFunctions.brighterFatterCorrection``
    which conserves the image flux, resulting in improved
    correction of the cores of stars. The convolution has also been
    modified to mitigate edge effects.

    Parameters
    ----------
    exposure : `lsst.afw.image.Exposure`
        Exposure to have brighter-fatter correction applied. Modified
        by this method.
    kernel : `np.ndarray`
        Brighter-fatter kernel to apply.
    maxIter : scalar
        Number of correction iterations to run.
    threshold : scalar
        Convergence threshold in terms of the sum of absolute
        deviations between an iteration and the previous one.
    applyGain : `Bool`
        If True, then the exposure values are scaled by the gain prior
        to correction.
    gains : `dict` [`str`, `float`]
        A dictionary, keyed by amplifier name, of the gains to use.
        If gains is None, the nominal gains in the amplifier object are used.
    correctionMode : `Bool`
        If True (default) the function applies correction for BFE. If False,
        the code can instead be used to generate a simulation of BFE (sign
        change in the direction of the effect)

    Returns
    -------
    diff : `float`
        Final difference between iterations achieved in correction.
        This is ``numpy.inf`` if fewer than two iterations were run, as
        no difference could be measured.
    iteration : `int`
        Index of the last iteration run (0 if ``maxIter`` < 1).

    Notes
    -----
    Modified version of ``lsst.ip.isr.isrFunctions.brighterFatterCorrection``.

    This correction takes a kernel that has been derived from flat
    field images to redistribute the charge. The gradient of the
    kernel is the deflection field due to the accumulated charge.

    Given the original image I(x) and the kernel K(x) we can compute
    the corrected image Ic(x) using the following equation:

    Ic(x) = I(x) + 0.5*d/dx(I(x)*d/dx(int( dy*K(x-y)*I(y))))

    Improved algorithm at this step applies the divergence theorem to
    obtain a pixelised correction.

    Because we use the measured counts instead of the incident counts
    we apply the correction iteratively to reconstruct the original
    counts and the correction. We stop iterating when the summed
    difference between the current corrected image and the one from
    the previous iteration is below the threshold. We do not require
    convergence because the number of iterations is too large a
    computational cost. How we define the threshold still needs to be
    evaluated, the current default was shown to work reasonably well
    on a small set of images.

    Edges are handled in the convolution by padding. This is still not
    a physical model for the edge, but avoids discontinuity in the correction.

    Author of modified version: Lance.Miller@physics.ox.ac.uk
    (see DM-38555).
    """
    image = exposure.getMaskedImage().getImage()

    # The image needs to be units of electrons/holes
    with gainContext(exposure, image, applyGain, gains):

        # get kernel and its shape
        kLy, kLx = kernel.shape
        kernelImage = afwImage.ImageD(kLx, kLy)
        kernelImage.getArray()[:, :] = kernel
        tempImage = afwImage.ImageD(image, deep=True)

        nanIndex = np.isnan(tempImage.getArray())
        tempImage.getArray()[nanIndex] = 0.

        outImage = afwImage.ImageD(image.getDimensions())
        corr = np.zeros(image.array.shape, dtype=np.float64)
        prevImage = np.zeros(image.array.shape, dtype=np.float64)
        convCntrl = afwMath.ConvolutionControl(False, False, 1)
        fixedKernel = afwMath.FixedKernel(kernelImage)

        # set the padding amount
        # ensure we pad by an even amount larger than the kernel
        kLy = 2 * ((1+kLy)//2)
        kLx = 2 * ((1+kLx)//2)

        # The deflection potential only depends on the gradient of
        # the convolution, so we can subtract the mean, which then
        # allows us to pad the image with zeros and avoid wrap-around effects
        # (although still not handling the image edges with a physical model)
        # This wouldn't be great if there were a strong image gradient.
        imYdimension, imXdimension = tempImage.array.shape
        imean = np.mean(tempImage.getArray()[~nanIndex], dtype=np.float64)
        # subtract mean from image
        tempImage -= imean
        tempImage.array[nanIndex] = 0.0
        padArray = np.pad(tempImage.getArray(), ((0, kLy), (0, kLx)))
        outImage = afwImage.ImageD(np.pad(outImage.getArray(), ((0, kLy), (0, kLx))))
        # Convert array to afw image so afwMath.convolve works
        padImage = afwImage.ImageD(padArray.shape[1], padArray.shape[0])
        padImage.array[:] = padArray

        # Returned if fewer than two iterations run (no difference measured).
        diff = np.inf
        iteration = 0

        for iteration in range(maxIter):

            # create deflection potential, convolution of flux with kernel
            # using padded counts array
            afwMath.convolve(outImage, padImage, fixedKernel, convCntrl)
            tmpArray = tempImage.getArray()

            # trim convolution output back to original shape
            outArray = outImage.getArray()[:imYdimension, :imXdimension]

            # generate the correction array, with correctionMode set as input
            # NOTE(review): ``tmpArray`` here is the mean-subtracted image,
            # so the donor flux passed to ``transferFlux`` excludes the
            # image mean. The mean is only safe to drop for the potential;
            # confirm this is intended for the transferred flux.
            corr[...] = transferFlux(outArray, tmpArray, correctionMode=correctionMode)

            # update the arrays for the next iteration
            tmpArray[:, :] = image.getArray()[:, :]
            tmpArray += corr
            tmpArray[nanIndex] = 0.
            # update padded array
            # subtract mean
            tmpArray -= imean
            tempImage.array[nanIndex] = 0.
            # Copy the corrected image into the convolution input; without
            # this every iteration would convolve the uncorrected image.
            padArray = np.pad(tempImage.getArray(), ((0, kLy), (0, kLx)))
            padImage.array[:] = padArray

            if iteration > 0:
                diff = np.sum(np.abs(prevImage - tmpArray), dtype=np.float64)

                if diff < threshold:
                    break
            # Store every iteration's image so the next difference is
            # measured against the previous iteration (not against zeros).
            prevImage[:, :] = tmpArray[:, :]

        image.getArray()[:] += corr[:]

    return diff, iteration