Coverage for python/lsst/cp/pipe/makeBrighterFatterKernel.py: 14%

212 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-03 12:39 +0000

1# This file is part of cp_pipe. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""Calculation of brighter-fatter effect correlations and kernels.""" 

23 

# Names exported by ``from lsst.cp.pipe.makeBrighterFatterKernel import *``.
__all__ = [
    'BrighterFatterKernelSolveTask',
    'BrighterFatterKernelSolveConfig',
]

26 

27import numpy as np 

28 

29import lsst.afw.math as afwMath 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32import lsst.pipe.base.connectionTypes as cT 

33 

34from lsst.ip.isr import (BrighterFatterKernel) 

35from .utils import (funcPolynomial, irlsFit) 

36from ._lookupStaticCalibration import lookupStaticCalibration 

37 

38 

class BrighterFatterKernelSolveConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "exposure", "detector")):
    """Butler inputs and outputs for `BrighterFatterKernelSolveTask`."""

    # Exposures at the quantum's data ID, loaded lazily (``deferLoad``);
    # the task's ``run`` only checks whether any were found.
    dummy = cT.Input(
        name="raw",
        storageClass="Exposure",
        dimensions=("instrument", "exposure", "detector"),
        multiple=True,
        deferLoad=True,
        doc="Dummy exposure.",
    )

    # Camera geometry, found through the static-calibration lookup.
    camera = cT.PrerequisiteInput(
        name="camera",
        storageClass="Camera",
        dimensions=("instrument",),
        isCalibration=True,
        lookupFunction=lookupStaticCalibration,
        doc="Camera associated with this data.",
    )

    # The PTC whose covariances become the cross-correlations.
    inputPtc = cT.PrerequisiteInput(
        name="ptc",
        storageClass="PhotonTransferCurveDataset",
        dimensions=("instrument", "detector"),
        isCalibration=True,
        doc="Photon transfer curve dataset.",
    )

    # The measured kernel, written as a per-detector calibration.
    outputBFK = cT.Output(
        name="brighterFatterKernel",
        storageClass="BrighterFatterKernel",
        dimensions=("instrument", "detector"),
        isCalibration=True,
        doc="Output measured brighter-fatter kernel.",
    )

72 

73 

class BrighterFatterKernelSolveConfig(pipeBase.PipelineTaskConfig,
                                      pipelineConnections=BrighterFatterKernelSolveConnections):
    """Configuration for `BrighterFatterKernelSolveTask`."""

    # --- Output granularity ---
    level = pexConfig.ChoiceField(
        dtype=str,
        default="AMP",
        allowed={
            "AMP": "Every amplifier treated separately",
            "DETECTOR": "One kernel per detector",
        },
        doc="The level at which to calculate the brighter-fatter kernels",
    )
    ignoreAmpsForAveraging = pexConfig.ListField(
        dtype=str,
        default=[],
        doc="List of amp names to ignore when averaging the amplifier kernels "
            "into the detector kernel. Only relevant for level = DETECTOR",
    )

    # --- Input rejection and averaging ---
    xcorrCheckRejectLevel = pexConfig.Field(
        dtype=float,
        default=2.0,
        doc="Rejection level for the sum of the input cross-correlations. "
            "Arrays which sum to greater than this are discarded before the "
            "clipped mean is calculated.",
    )
    nSigmaClip = pexConfig.Field(
        dtype=float,
        default=5,
        doc="Number of sigma to clip when calculating means for the cross-correlation",
    )

    # --- Kernel construction ---
    forceZeroSum = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Force the correlation matrix to have zero sum by adjusting the (0,0) value?",
    )
    useAmatrix = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Use the PTC 'a' matrix (Astier et al. 2019 equation 20) instead "
            "of the average of measured covariances?",
    )

    # --- Successive over-relaxation solver ---
    maxIterSuccessiveOverRelaxation = pexConfig.Field(
        dtype=int,
        default=10000,
        doc="The maximum number of iterations allowed for the successive over-relaxation method",
    )
    eLevelSuccessiveOverRelaxation = pexConfig.Field(
        dtype=float,
        default=5.0e-14,
        doc="The target residual error for the successive over-relaxation method",
    )

    # --- Correlation modelling ---
    correlationQuadraticFit = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Use a quadratic fit to find the correlations instead of simple averaging?",
    )
    correlationModelRadius = pexConfig.Field(
        dtype=int,
        default=100,
        doc="Build a model of the correlation coefficients for radii larger "
            "than this value in pixels?",
    )
    correlationModelSlope = pexConfig.Field(
        dtype=float,
        default=-1.35,
        doc="Slope of the correlation model for radii larger than correlationModelRadius",
    )

140 

141 

class BrighterFatterKernelSolveTask(pipeBase.PipelineTask, pipeBase.CmdLineTask):
    """Measure appropriate Brighter-Fatter Kernel from the PTC dataset.
    """

    ConfigClass = BrighterFatterKernelSolveConfig
    _DefaultName = 'cpBfkMeasure'

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Ensure that the input and output dimensions are passed along.

        Parameters
        ----------
        butlerQC : `lsst.daf.butler.butlerQuantumContext.ButlerQuantumContext`
            Butler to operate on.
        inputRefs : `lsst.pipe.base.connections.InputQuantizedConnection`
            Input data refs to load.
        outputRefs : `lsst.pipe.base.connections.OutputQuantizedConnection`
            Output data refs to persist.
        """
        inputs = butlerQC.get(inputRefs)

        # Use the dimensions to set calib/provenance information.
        inputs['inputDims'] = inputRefs.inputPtc.dataId.byName()

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, inputPtc, dummy, camera, inputDims):
        """Combine covariance information from PTC into brighter-fatter
        kernels.

        Parameters
        ----------
        inputPtc : `lsst.ip.isr.PhotonTransferCurveDataset`
            PTC data containing per-amplifier covariance measurements.
            This method does not modify it.
        dummy : sequence of `lsst.afw.image.Exposure` (or deferred handles)
            The exposure(s) used to select the appropriate PTC dataset.
            In almost all circumstances, one of the input exposures
            used to generate the PTC dataset is the best option.
            Only the length is inspected here.
        camera : `lsst.afw.cameraGeom.Camera`
            Camera to use for camera geometry information.
        inputDims : `lsst.daf.butler.DataCoordinate` or `dict`
            DataIds to use to populate the output calibration.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            The results struct containing:

            ``outputBFK``
                Resulting Brighter-Fatter Kernel
                (`lsst.ip.isr.BrighterFatterKernel`).

        Notes
        -----
        Amplifiers with a non-positive or non-finite gain, or whose inputs
        are all rejected, get zero kernels and are marked invalid. When
        ``config.level`` is ``DETECTOR``, the valid amplifiers not listed in
        ``config.ignoreAmpsForAveraging`` are combined into one detector
        kernel. That kernel is zero if no amplifier contributes.
        """
        if len(dummy) == 0:
            self.log.warning("No dummy exposure found.")

        detector = camera[inputDims['detector']]
        detName = detector.getName()

        if self.config.level == 'DETECTOR':
            # Inputs accumulated across amplifiers for the detector kernel.
            # Flux-scaled correlations feed the clipped-mean average. The
            # unscaled correlations (with their fluxes) feed the quadratic
            # fit, mirroring the per-amplifier choice below.
            detectorCorrList = list()
            detectorUnscaledCorrList = list()
            detectorFluxes = list()

        bfk = BrighterFatterKernel(camera=camera, detectorId=detector.getId(), level=self.config.level)
        bfk.rawMeans = inputPtc.rawMeans  # ADU
        bfk.rawVariances = inputPtc.rawVars  # ADU^2
        bfk.expIdMask = inputPtc.expIdMask

        # Use the PTC covariances as the cross-correlations. These
        # are scaled before the kernel is generated, which performs
        # the conversion. The container is copied so that replacing the
        # entries of bad amplifiers below does not alter inputPtc.
        bfk.rawXcorrs = dict(inputPtc.covariances)  # ADU^2
        bfk.badAmps = inputPtc.badAmps
        bfk.shape = (inputPtc.covMatrixSide*2 + 1, inputPtc.covMatrixSide*2 + 1)
        bfk.gain = inputPtc.gain
        bfk.noise = inputPtc.noise
        bfk.meanXcorrs = dict()
        bfk.valid = dict()

        for amp in detector:
            ampName = amp.getName()
            gain = bfk.gain[ampName]
            mask = inputPtc.expIdMask[ampName]
            # A NaN gain would slip past ``gain <= 0``, so check finiteness too.
            if not np.isfinite(gain) or gain <= 0:
                # We've received very bad data.
                self.log.warning("Impossible gain recieved from PTC for %s: %f. Skipping bad amplifier.",
                                 ampName, gain)
                bfk.meanXcorrs[ampName] = np.zeros(bfk.shape)
                bfk.ampKernels[ampName] = np.zeros(bfk.shape)
                bfk.rawXcorrs[ampName] = np.zeros((len(mask), inputPtc.covMatrixSide, inputPtc.covMatrixSide))
                bfk.valid[ampName] = False
                continue

            # Use inputPtc.expIdMask to get the means, variances,
            # and covariances that were not masked after PTC.
            fluxes = np.array(bfk.rawMeans[ampName])[mask]
            variances = np.array(bfk.rawVariances[ampName])[mask]
            xCorrList = np.array([np.array(xcorr) for xcorr in bfk.rawXcorrs[ampName]])[mask]

            fluxes = fluxes*gain  # Now in e^-
            variances = variances*gain*gain  # Now in e^2-

            # This should duplicate Coulton et al. 2017 Equation 22-29
            # (arxiv:1711.06273)
            scaledCorrList = list()
            corrList = list()
            truncatedFluxes = list()
            for xcorrNum, (xcorr, flux, var) in enumerate(zip(xCorrList, fluxes, variances), 1):
                q = np.array(xcorr) * gain * gain  # xcorr now in e^-
                q *= 2.0  # Remove factor of 1/2 applied in PTC.
                self.log.info("Amp: %s %d/%d Flux: %f Var: %f Q(0,0): %g Q(1,0): %g Q(0,1): %g",
                              ampName, xcorrNum, len(xCorrList), flux, var, q[0][0], q[1][0], q[0][1])

                # Normalize by the flux, which removes the (0,0)
                # component attributable to Poisson noise. This
                # contains the two "t I delta(x - x')" terms in
                # Coulton et al. 2017 equation 29
                q[0][0] -= 2.0*(flux)

                if q[0][0] > 0.0:
                    self.log.warning("Amp: %s %d skipped due to value of (variance-mean)=%f",
                                     ampName, xcorrNum, q[0][0])
                    # If we drop an element of ``scaledCorrList``
                    # (which is what this does), we need to ensure we
                    # drop the flux entry as well.
                    continue

                # This removes the "t (I_a^2 + I_b^2)" factor in
                # Coulton et al. 2017 equation 29.
                # The quadratic fit option needs the correlations unscaled
                q /= -2.0
                unscaled = self._tileArray(q)
                q /= flux**2
                scaled = self._tileArray(q)
                xcorrCheck = np.abs(np.sum(scaled))/np.sum(np.abs(scaled))
                if (xcorrCheck > self.config.xcorrCheckRejectLevel) or not (np.isfinite(xcorrCheck)):
                    self.log.warning("Amp: %s %d skipped due to value of triangle-inequality sum %f",
                                     ampName, xcorrNum, xcorrCheck)
                    continue

                scaledCorrList.append(scaled)
                corrList.append(unscaled)
                truncatedFluxes.append(flux)
                self.log.info("Amp: %s %d/%d Final: %g XcorrCheck: %f",
                              ampName, xcorrNum, len(xCorrList), q[0][0], xcorrCheck)

            fluxes = np.array(truncatedFluxes)

            if len(scaledCorrList) == 0:
                self.log.warning("Amp: %s All inputs rejected for amp!", ampName)
                bfk.meanXcorrs[ampName] = np.zeros(bfk.shape)
                bfk.ampKernels[ampName] = np.zeros(bfk.shape)
                bfk.valid[ampName] = False
                continue

            if self.config.useAmatrix:
                # Use the aMatrix, ignoring the meanXcorr generated above.
                preKernel = np.pad(self._tileArray(np.array(inputPtc.aMatrix[ampName])), ((1, 1)))
            elif self.config.correlationQuadraticFit:
                # Use a quadratic fit to the correlations as a
                # function of flux.
                preKernel = self.quadraticCorrelations(corrList, fluxes, f"Amp: {ampName}")
            else:
                # Use a simple average of the measured correlations.
                preKernel = self.averageCorrelations(scaledCorrList, f"Amp: {ampName}")

            center = int((bfk.shape[0] - 1) / 2)

            if self.config.forceZeroSum:
                totalSum = np.sum(preKernel)

                if self.config.correlationModelRadius < (preKernel.shape[0] - 1) / 2:
                    # Assume a correlation model of
                    # Corr(r) = -preFactor * r^(2 * slope)
                    preFactor = np.sqrt(preKernel[center, center + 1] * preKernel[center + 1, center])
                    slopeFactor = 2.0 * np.abs(self.config.correlationModelSlope)
                    totalSum += 2.0*np.pi*(preFactor / (slopeFactor*(center + 0.5))**slopeFactor)

                preKernel[center, center] -= totalSum
                self.log.info("%s Zero-Sum Scale: %g", ampName, totalSum)

            finalSum = np.sum(preKernel)
            bfk.meanXcorrs[ampName] = preKernel

            postKernel = self.successiveOverRelax(preKernel)
            bfk.ampKernels[ampName] = postKernel
            # Honour the documented exclusion list when building the
            # detector-level inputs.
            if self.config.level == 'DETECTOR' and ampName not in self.config.ignoreAmpsForAveraging:
                detectorCorrList.extend(scaledCorrList)
                detectorUnscaledCorrList.extend(corrList)
                detectorFluxes.extend(fluxes)
            bfk.valid[ampName] = True
            self.log.info("Amp: %s Sum: %g Center Info Pre: %g Post: %g",
                          ampName, finalSum, preKernel[center, center], postKernel[center, center])

        # Assemble a detector kernel?
        if self.config.level == 'DETECTOR':
            if len(detectorCorrList) == 0:
                # Every amplifier was invalid or ignored. There is nothing
                # to combine, and averageCorrelations would fail on an
                # empty list.
                self.log.warning("Det: %s No usable amplifier inputs; detector kernel set to zero.",
                                 detName)
                bfk.detKernels[detName] = np.zeros(bfk.shape)
            else:
                if self.config.correlationQuadraticFit:
                    # The quadratic fit expects unscaled correlations, as in
                    # the per-amplifier branch above.
                    preKernel = self.quadraticCorrelations(detectorUnscaledCorrList, detectorFluxes,
                                                           f"Det: {detName}")
                else:
                    preKernel = self.averageCorrelations(detectorCorrList, f"Det: {detName}")
                finalSum = np.sum(preKernel)
                center = int((bfk.shape[0] - 1) / 2)

                postKernel = self.successiveOverRelax(preKernel)
                bfk.detKernels[detName] = postKernel
                self.log.info("Det: %s Sum: %g Center Info Pre: %g Post: %g",
                              detName, finalSum, preKernel[center, center], postKernel[center, center])

        return pipeBase.Struct(
            outputBFK=bfk,
        )

    def averageCorrelations(self, xCorrList, name):
        """Average input correlations.

        Parameters
        ----------
        xCorrList : `list` [`numpy.array`]
            List of cross-correlations. These are expected to be
            square arrays.
        name : `str`
            Label for the input (amplifier or detector). It is not
            currently used.

        Returns
        -------
        meanXcorr : `numpy.array`, (N + 2, N + 2)
            Element ``[i, j]`` is the sigma-clipped mean (clipped at
            ``config.nSigmaClip``) of element ``[j, i]`` across the inputs.
            The result is then zero-padded by one element on each side.
        """
        meanXcorr = np.zeros_like(xCorrList[0])
        # Reversing the axes gives shape (N, N, nInputs), with
        # xCorrList[i, j] holding element [j, i] of every input.
        xCorrList = np.transpose(xCorrList)
        sctrl = afwMath.StatisticsControl()
        sctrl.setNumSigmaClip(self.config.nSigmaClip)
        for i in range(np.shape(meanXcorr)[0]):
            for j in range(np.shape(meanXcorr)[1]):
                meanXcorr[i, j] = afwMath.makeStatistics(xCorrList[i, j],
                                                         afwMath.MEANCLIP, sctrl).getValue()

        # To match previous definitions, pad by one element.
        meanXcorr = np.pad(meanXcorr, ((1, 1)))

        return meanXcorr

    def quadraticCorrelations(self, xCorrList, fluxList, name):
        """Measure a quadratic correlation model.

        Parameters
        ----------
        xCorrList : `list` [`numpy.array`]
            List of (unscaled) cross-correlations. These are expected to be
            square arrays.
        fluxList : `numpy.array`, (Nflux,)
            Associated list of fluxes, one per correlation.
        name : `str`
            Label for the input (amplifier or detector). It is not
            currently used.

        Returns
        -------
        meanXcorr : `numpy.array`, (N + 2, N + 2)
            Element ``[i, j]`` is the slope of a fit of element ``[j, i]``
            of the correlations against flux squared. The intercept is
            discarded. The result is zero-padded by one element on each side.
        """
        meanXcorr = np.zeros_like(xCorrList[0])
        fluxList = np.square(fluxList)
        xCorrList = np.array(xCorrList)

        for i in range(np.shape(meanXcorr)[0]):
            for j in range(np.shape(meanXcorr)[1]):
                # Fit correlation_i(x, y) = a0 + a1 * (flux_i)^2. The
                # i,j indices are inverted to apply the transposition,
                # as is done in the averaging case.
                linearFit, linearFitErr, chiSq, weights = irlsFit([0.0, 1e-4], fluxList,
                                                                  xCorrList[:, j, i], funcPolynomial,
                                                                  scaleResidual=False)
                meanXcorr[i, j] = linearFit[1]  # Discard the intercept.
                self.log.info("Quad fit meanXcorr[%d,%d] = %g", i, j, linearFit[1])

        # To match previous definitions, pad by one element.
        meanXcorr = np.pad(meanXcorr, ((1, 1)))

        return meanXcorr

    @staticmethod
    def _tileArray(in_array):
        """Given an input quarter-image, tile/mirror it and return full image.

        Given a square input of side-length n, of the form

            input = array([[1, 2, 3],
                           [4, 5, 6],
                           [7, 8, 9]])

        return an array of size 2n-1 as

            output = array([[ 9, 8, 7, 8, 9],
                            [ 6, 5, 4, 5, 6],
                            [ 3, 2, 1, 2, 3],
                            [ 6, 5, 4, 5, 6],
                            [ 9, 8, 7, 8, 9]])

        Parameters
        ----------
        in_array : array-like, (N, N)
            The square input quarter-array. ``in_array[0, 0]`` becomes the
            center of the output.

        Returns
        -------
        output : `np.array`, (2*N - 1, 2*N - 1)
            The full, tiled array, as float64.

        Raises
        ------
        ValueError
            Raised if ``in_array`` is not a non-empty, square 2-d array.
        """
        in_array = np.asarray(in_array)
        if in_array.ndim != 2 or in_array.shape[0] != in_array.shape[1] or in_array.shape[0] == 0:
            raise ValueError(f"Input must be a non-empty square 2-d array; got shape {in_array.shape}.")
        # Prepend the rows in reverse order, excluding row 0 (the mirror
        # axis). Then do the same for the columns.
        rows = np.concatenate((in_array[:0:-1], in_array), axis=0)
        output = np.concatenate((rows[:, :0:-1], rows), axis=1)
        return output.astype(np.float64)

    @staticmethod
    def _relaxPoints(func, resid, source, omega, dx, iStart, jStart):
        """Over-relax every second interior point of ``func`` in place.

        Visits rows ``iStart, iStart + 2, ...`` and columns
        ``jStart, jStart + 2, ...`` of the padded grid. At each point it
        stores the five-point Laplacian residual in ``resid`` and moves
        ``func`` by ``omega/4`` of that residual.
        """
        for i in range(iStart, func.shape[0] - 1, 2):
            for j in range(jStart, func.shape[1] - 1, 2):
                resid[i, j] = float(func[i, j - 1] + func[i, j + 1] + func[i - 1, j]
                                    + func[i + 1, j] - 4.0*func[i, j] - dx*dx*source[i - 1, j - 1])
                func[i, j] += omega*resid[i, j]*.25

    def successiveOverRelax(self, source, maxIter=None, eLevel=None):
        """An implementation of the successive over relaxation (SOR) method.

        A numerical method for solving a system of linear equations
        with faster convergence than the Gauss-Seidel method. Here it
        solves the discrete Poisson equation ``laplacian(func) = source``
        with zero boundary values.

        Parameters
        ----------
        source : `numpy.ndarray`, (N, N)
            The input array.
        maxIter : `int`, optional
            Maximum number of iterations to attempt before aborting.
            Defaults to ``config.maxIterSuccessiveOverRelaxation``.
        eLevel : `float`, optional
            The target error level, relative to the initial error, at
            which we deem convergence to have occurred. Defaults to
            ``config.eLevelSuccessiveOverRelaxation``.

        Returns
        -------
        output : `numpy.ndarray`, (N, N)
            The solution.

        Raises
        ------
        ValueError
            Raised if ``source`` is not square.
        """
        if maxIter is None:
            maxIter = self.config.maxIterSuccessiveOverRelaxation
        if eLevel is None:
            eLevel = self.config.eLevelSuccessiveOverRelaxation

        if source.shape[0] != source.shape[1]:
            raise ValueError("Input array must be square")
        # initialize, and set boundary conditions
        func = np.zeros([source.shape[0] + 2, source.shape[1] + 2])
        resid = np.zeros([source.shape[0] + 2, source.shape[1] + 2])
        rhoSpe = np.cos(np.pi/source.shape[0])  # Here a square grid is assumed

        # Calculate the initial error. func is still zero, so the residual
        # at each interior point is just -source.
        resid[1:-1, 1:-1] = -source
        inError = np.sum(np.abs(resid))

        if inError == 0:
            # A zero source is solved exactly by the zero function. Without
            # this check the relative test ``outError < 0`` could never pass
            # and every sweep would run before a spurious failure warning.
            self.log.info("Success: SuccessiveOverRelaxation source is identically zero; returning zeros.")
            return func[1: -1, 1: -1]

        # Iterate until convergence
        # We perform two sweeps per cycle,
        # updating 'odd' and 'even' points separately
        nIter = 0
        omega = 1.0
        dx = 1.0
        outError = inError
        while nIter < maxIter*2:
            if nIter % 2 == 0:
                self._relaxPoints(func, resid, source, omega, dx, 1, 1)
                self._relaxPoints(func, resid, source, omega, dx, 2, 2)
            else:
                self._relaxPoints(func, resid, source, omega, dx, 1, 2)
                self._relaxPoints(func, resid, source, omega, dx, 2, 1)
            outError = np.sum(np.abs(resid))
            if outError < inError*eLevel:
                break
            # Chebyshev acceleration of the relaxation factor.
            if nIter == 0:
                omega = 1.0/(1 - rhoSpe*rhoSpe/2.0)
            else:
                omega = 1.0/(1 - rhoSpe*rhoSpe*omega/4.0)
            nIter += 1

        if nIter >= maxIter*2:
            self.log.warning("Failure: SuccessiveOverRelaxation did not converge in %s iterations."
                             "\noutError: %s, inError: %s,", nIter//2, outError, inError*eLevel)
        else:
            self.log.info("Success: SuccessiveOverRelaxation converged in %s iterations."
                          "\noutError: %s, inError: %s", nIter//2, outError, inError*eLevel)
        return func[1: -1, 1: -1]