Coverage for python/lsst/ip/isr/brighterFatterKernel.py: 8%

202 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-23 11:27 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""Brighter Fatter Kernel calibration definition.""" 

23 

24 

25__all__ = ['BrighterFatterKernel'] 

26 

27 

28import numpy as np 

29from astropy.table import Table 

30import lsst.afw.math as afwMath 

31from . import IsrCalib 

32 

33 

class BrighterFatterKernel(IsrCalib):
    """Calibration of brighter-fatter kernels for an instrument.

    ampKernels are the kernels for each amplifier in a detector, as
    generated by having level == 'AMP'.

    detKernels are the kernels generated for a detector as a
    whole, as generated by having level == 'DETECTOR'.

    makeDetectorKernelFromAmpwiseKernels is a method to generate the
    kernel for a detector, constructed by averaging together the
    ampwise kernels in the detector. The existing application code is
    only defined for kernels with level == 'DETECTOR', so this method
    is used if the supplied kernel was built with level == 'AMP'.

    Parameters
    ----------
    camera : `lsst.afw.cameraGeom.Camera`, optional
        Camera used to initialize the per-amplifier / per-detector
        structure via `initFromCamera`.
    level : `str`
        Level the kernels will be generated for.
    log : `logging.Logger`, optional
        Log to write messages to.
    **kwargs :
        Parameters to pass to parent constructor.

    Notes
    -----
    TODO: DM-35260
    Document what is stored in the BFK calibration.

    Version 1.1 adds the `expIdMask` property, and replaces the
    `means` and `variances` attributes with `rawMeans` and
    `rawVariances` from the PTC dataset.
    """
    _OBSTYPE = 'bfk'
    _SCHEMA = 'Brighter-fatter kernel'
    _VERSION = 1.1

    def __init__(self, camera=None, level=None, **kwargs):
        self.level = level

        # Things inherited from the PTC
        self.expIdMask = dict()
        self.rawMeans = dict()
        self.rawVariances = dict()
        self.rawXcorrs = dict()
        self.badAmps = list()
        self.shape = (17, 17)
        self.gain = dict()
        self.noise = dict()

        # Things calculated from the PTC
        self.meanXcorrs = dict()
        self.valid = dict()

        # Things that are used downstream
        self.ampKernels = dict()
        self.detKernels = dict()

        super().__init__(**kwargs)

        if camera:
            self.initFromCamera(camera, detectorId=kwargs.get('detectorId', None))

        self.requiredAttributes.update(['level', 'expIdMask', 'rawMeans', 'rawVariances', 'rawXcorrs',
                                        'badAmps', 'gain', 'noise', 'meanXcorrs', 'valid',
                                        'ampKernels', 'detKernels'])

    def updateMetadata(self, setDate=False, **kwargs):
        """Update calibration metadata.

        This calls the base class's method after ensuring the required
        calibration keywords will be saved.

        Parameters
        ----------
        setDate : `bool`, optional
            Update the CALIBDATE fields in the metadata to the current
            time. Defaults to False.
        kwargs :
            Other keyword parameters to set in the metadata.
        """
        kwargs['LEVEL'] = self.level
        kwargs['KERNEL_DX'] = self.shape[0]
        kwargs['KERNEL_DY'] = self.shape[1]

        super().updateMetadata(setDate=setDate, **kwargs)

    def initFromCamera(self, camera, detectorId=None):
        """Initialize kernel structure from camera.

        Parameters
        ----------
        camera : `lsst.afw.cameraGeom.Camera`
            Camera to use to define geometry.
        detectorId : `int`, optional
            Index of the detector to generate.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            The initialized calibration.

        Raises
        ------
        RuntimeError :
            Raised if no detectorId is supplied for a calibration with
            level='AMP'.
        """
        self._instrument = camera.getName()

        if detectorId is not None:
            detector = camera[detectorId]
            self._detectorId = detectorId
            self._detectorName = detector.getName()
            self._detectorSerial = detector.getSerial()

        if self.level == 'AMP':
            if detectorId is None:
                raise RuntimeError("A detectorId must be supplied if level='AMP'.")

            self.badAmps = []

            for amp in detector:
                ampName = amp.getName()
                self.expIdMask[ampName] = []
                self.rawMeans[ampName] = []
                self.rawVariances[ampName] = []
                self.rawXcorrs[ampName] = []
                self.gain[ampName] = amp.getGain()
                self.noise[ampName] = amp.getReadNoise()
                self.meanXcorrs[ampName] = []
                self.ampKernels[ampName] = []
                self.valid[ampName] = []
        elif self.level == 'DETECTOR':
            if detectorId is None:
                # No specific detector requested: allocate every detector.
                for det in camera:
                    detName = det.getName()
                    self.detKernels[detName] = []
            else:
                self.detKernels[self._detectorName] = []

        return self

    def getLengths(self):
        """Return the set of lengths needed for reshaping components.

        Returns
        -------
        kernelLength : `int`
            Product of the elements of self.shape.
        smallLength : `int`
            Size of an untiled covariance.
        nObs : `int`
            Number of observation pairs used in the kernel. This is 0
            for non-AMP levels and for an AMP-level calibration that has
            no amplifiers yet.

        Raises
        ------
        RuntimeError :
            Raised if amplifiers hold different numbers of observations.
        """
        kernelLength = self.shape[0] * self.shape[1]
        smallLength = int((self.shape[0] - 1)*(self.shape[1] - 1)/4)
        if self.level == 'AMP':
            nObservations = {len(self.rawMeans[amp]) for amp in self.rawMeans}
            if len(nObservations) > 1:
                raise RuntimeError("Inconsistent number of observations found.")
            # An empty calibration has nothing to reshape.
            nObs = nObservations.pop() if nObservations else 0
        else:
            nObs = 0

        return (kernelLength, smallLength, nObs)

    @classmethod
    def fromDict(cls, dictionary):
        """Construct a calibration from a dictionary of properties.

        Parameters
        ----------
        dictionary : `dict`
            Dictionary of properties.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            Constructed calibration.

        Raises
        ------
        RuntimeError :
            Raised if the supplied dictionary is for a different
            calibration.

        Notes
        -----
        Version 1.0 dictionaries are accepted with a warning: the
        ``expIdMask`` attribute is filled with `True` values, and
        ``rawMeans`` / ``rawVariances`` are filled from the stored
        ``means`` / ``variances``.
        """
        calib = cls()

        if calib._OBSTYPE != (found := dictionary['metadata']['OBSTYPE']):
            raise RuntimeError(f"Incorrect brighter-fatter kernel supplied. Expected {calib._OBSTYPE}, "
                               f"found {found}")
        calib.setMetadata(dictionary['metadata'])
        calib.calibInfoFromDict(dictionary)

        calib.level = dictionary['metadata'].get('LEVEL', 'AMP')
        calib.shape = (dictionary['metadata'].get('KERNEL_DX', 0),
                       dictionary['metadata'].get('KERNEL_DY', 0))

        calibVersion = dictionary['metadata']['bfk_VERSION']
        if calibVersion == 1.0:
            calib.log.warning("Old Version of brighter-fatter kernel found. Current version: "
                              f"{calib._VERSION}. The new attribute 'expIdMask' will be "
                              "populated with 'True' values, and the new attributes 'rawMeans' "
                              "and 'rawVariances' will be populated with the masked 'means' "
                              "and 'variances' values.")
            # use 'means', because 'expIdMask' does not exist.
            calib.expIdMask = {amp: np.repeat(True, len(dictionary['means'][amp])) for amp in
                               dictionary['means']}
            calib.rawMeans = {amp: np.array(dictionary['means'][amp]) for amp in dictionary['means']}
            calib.rawVariances = {amp: np.array(dictionary['variances'][amp]) for amp in
                                  dictionary['variances']}
        else:
            calib.expIdMask = {amp: np.array(dictionary['expIdMask'][amp]) for amp in dictionary['expIdMask']}
            calib.rawMeans = {amp: np.array(dictionary['rawMeans'][amp]) for amp in dictionary['rawMeans']}
            calib.rawVariances = {amp: np.array(dictionary['rawVariances'][amp]) for amp in
                                  dictionary['rawVariances']}

        # Lengths for reshape:
        _, smallLength, nObs = calib.getLengths()
        smallShapeSide = int(np.sqrt(smallLength))

        calib.rawXcorrs = {amp: np.array(dictionary['rawXcorrs'][amp]).reshape((nObs,
                                                                                 smallShapeSide,
                                                                                 smallShapeSide))
                           for amp in dictionary['rawXcorrs']}

        calib.gain = dictionary['gain']
        calib.noise = dictionary['noise']

        calib.meanXcorrs = {amp: np.array(dictionary['meanXcorrs'][amp]).reshape(calib.shape)
                            for amp in dictionary['meanXcorrs']}
        calib.ampKernels = {amp: np.array(dictionary['ampKernels'][amp]).reshape(calib.shape)
                            for amp in dictionary['ampKernels']}
        calib.valid = {amp: bool(value) for amp, value in dictionary['valid'].items()}
        # Use the bool-cast flags so numpy booleans / integer 0 are
        # recognized as invalid too (an identity test against False is not).
        calib.badAmps = [amp for amp, valid in calib.valid.items() if not valid]

        calib.detKernels = {det: np.array(dictionary['detKernels'][det]).reshape(calib.shape)
                            for det in dictionary['detKernels']}

        calib.updateMetadata()
        return calib

    def toDict(self):
        """Return a dictionary containing the calibration properties.

        The dictionary should be able to be round-tripped through
        `fromDict`.

        Returns
        -------
        dictionary : `dict`
            Dictionary of properties.
        """
        self.updateMetadata()

        outDict = {}
        metadata = self.getMetadata()
        outDict['metadata'] = metadata

        # Lengths for ravel:
        kernelLength, smallLength, nObs = self.getLengths()

        outDict['expIdMask'] = {amp: np.array(self.expIdMask[amp]).tolist() for amp in self.expIdMask}
        outDict['rawMeans'] = {amp: np.array(self.rawMeans[amp]).tolist() for amp in self.rawMeans}
        outDict['rawVariances'] = {amp: np.array(self.rawVariances[amp]).tolist() for amp in
                                   self.rawVariances}
        outDict['rawXcorrs'] = {amp: np.array(self.rawXcorrs[amp]).reshape(nObs*smallLength).tolist()
                                for amp in self.rawXcorrs}
        outDict['badAmps'] = self.badAmps
        outDict['gain'] = self.gain
        outDict['noise'] = self.noise

        outDict['meanXcorrs'] = {amp: np.asarray(self.meanXcorrs[amp]).reshape(kernelLength).tolist()
                                 for amp in self.meanXcorrs}
        outDict['ampKernels'] = {amp: np.asarray(self.ampKernels[amp]).reshape(kernelLength).tolist()
                                 for amp in self.ampKernels}
        outDict['valid'] = self.valid

        outDict['detKernels'] = {det: np.asarray(self.detKernels[det]).reshape(kernelLength).tolist()
                                 for det in self.detKernels}
        return outDict

    @classmethod
    def fromTable(cls, tableList):
        """Construct calibration from a list of tables.

        This method uses the `fromDict` method to create the
        calibration, after constructing an appropriate dictionary from
        the input tables.

        Parameters
        ----------
        tableList : `list` [`astropy.table.Table`]
            List of tables to use to construct the brighter-fatter
            calibration. The first table holds the per-amplifier data;
            an optional second table holds the per-detector kernels.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            The calibration defined in the tables.
        """
        ampTable = tableList[0]

        metadata = ampTable.meta
        inDict = dict()
        inDict['metadata'] = metadata

        amps = ampTable['AMPLIFIER']

        if metadata.get('bfk_VERSION') == 1.0:
            # fromDict reads 'means'/'variances' for version 1.0 and
            # derives expIdMask/rawMeans/rawVariances from them.
            # NOTE(review): assumes version 1.0 tables stored these as
            # 'MEANS'/'VARIANCES' columns — confirm against a v1.0 file.
            inDict['means'] = {amp: mean for amp, mean in zip(amps, ampTable['MEANS'])}
            inDict['variances'] = {amp: var for amp, var in zip(amps, ampTable['VARIANCES'])}
        else:
            expIdMaskList = ampTable['EXP_ID_MASK']
            rawMeanList = ampTable['RAW_MEANS']
            rawVarianceList = ampTable['RAW_VARIANCES']
            inDict['expIdMask'] = {amp: mask for amp, mask in zip(amps, expIdMaskList)}
            inDict['rawMeans'] = {amp: mean for amp, mean in zip(amps, rawMeanList)}
            inDict['rawVariances'] = {amp: var for amp, var in zip(amps, rawVarianceList)}

        rawXcorrs = ampTable['RAW_XCORRS']
        gainList = ampTable['GAIN']
        noiseList = ampTable['NOISE']

        meanXcorrs = ampTable['MEAN_XCORRS']
        ampKernels = ampTable['KERNEL']
        validList = ampTable['VALID']

        inDict['rawXcorrs'] = {amp: kernel for amp, kernel in zip(amps, rawXcorrs)}
        inDict['gain'] = {amp: gain for amp, gain in zip(amps, gainList)}
        inDict['noise'] = {amp: noise for amp, noise in zip(amps, noiseList)}
        inDict['meanXcorrs'] = {amp: kernel for amp, kernel in zip(amps, meanXcorrs)}
        inDict['ampKernels'] = {amp: kernel for amp, kernel in zip(amps, ampKernels)}
        inDict['valid'] = {amp: bool(valid) for amp, valid in zip(amps, validList)}

        inDict['badAmps'] = [amp for amp, valid in inDict['valid'].items() if valid is False]

        if len(tableList) > 1:
            detTable = tableList[1]
            inDict['detKernels'] = {det: kernel for det, kernel
                                    in zip(detTable['DETECTOR'], detTable['KERNEL'])}
        else:
            inDict['detKernels'] = {}

        return cls.fromDict(inDict)

    def toTable(self):
        """Construct a list of tables containing the information in this
        calibration.

        The list of tables should create an identical calibration
        after being passed to this class's fromTable method.

        Returns
        -------
        tableList : `list` [`astropy.table.Table`]
            List of tables containing the brighter-fatter kernel
            calibration information: the per-amplifier table first,
            followed by a per-detector kernel table if any detector
            kernels exist.
        """
        tableList = []
        self.updateMetadata()

        # Lengths
        kernelLength, smallLength, nObs = self.getLengths()

        ampList = []
        expIdMaskList = []
        rawMeanList = []
        rawVarianceList = []
        rawXcorrs = []
        gainList = []
        noiseList = []

        meanXcorrsList = []
        kernelList = []
        validList = []

        if self.level == 'AMP':
            for amp in self.rawMeans.keys():
                ampList.append(amp)
                expIdMaskList.append(self.expIdMask[amp])
                rawMeanList.append(self.rawMeans[amp])
                rawVarianceList.append(self.rawVariances[amp])
                rawXcorrs.append(np.array(self.rawXcorrs[amp]).reshape(nObs*smallLength).tolist())
                gainList.append(self.gain[amp])
                noiseList.append(self.noise[amp])

                meanXcorrsList.append(np.asarray(self.meanXcorrs[amp]).reshape(kernelLength).tolist())
                kernelList.append(np.asarray(self.ampKernels[amp]).reshape(kernelLength).tolist())
                # An amp is written as valid only if flagged valid AND not
                # listed in badAmps.
                validList.append(int(self.valid[amp] and not (amp in self.badAmps)))

        ampTable = Table({'AMPLIFIER': ampList,
                          'EXP_ID_MASK': expIdMaskList,
                          'RAW_MEANS': rawMeanList,
                          'RAW_VARIANCES': rawVarianceList,
                          'RAW_XCORRS': rawXcorrs,
                          'GAIN': gainList,
                          'NOISE': noiseList,
                          'MEAN_XCORRS': meanXcorrsList,
                          'KERNEL': kernelList,
                          'VALID': validList,
                          })

        ampTable.meta = self.getMetadata().toDict()
        tableList.append(ampTable)

        if len(self.detKernels):
            detList = []
            kernelList = []
            for det in self.detKernels.keys():
                detList.append(det)
                kernelList.append(np.asarray(self.detKernels[det]).reshape(kernelLength).tolist())

            detTable = Table({'DETECTOR': detList,
                              'KERNEL': kernelList})
            detTable.meta = self.getMetadata().toDict()
            tableList.append(detTable)

        return tableList

    # Implementation methods
    def makeDetectorKernelFromAmpwiseKernels(self, detectorName, ampsToExclude=None):
        """Average the amplifier level kernels to create a detector level
        kernel.

        Each pixel of the detector kernel is the 5-sigma clipped mean
        (``afwMath.MEANCLIP``) of that pixel across the included
        amplifier kernels.

        Parameters
        ----------
        detectorName : `str`
            Key under which the averaged kernel is stored in
            ``self.detKernels``.
        ampsToExclude : `list` [`str`], optional
            Amplifier names whose kernels are left out of the average.
        """
        if ampsToExclude is None:
            ampsToExclude = []
        inKernels = np.array([self.ampKernels[amp] for amp in
                              self.ampKernels if amp not in ampsToExclude])
        # Move the amplifier axis to the end so averagingList[i, j] holds
        # pixel (i, j) of every amplifier kernel. np.transpose would
        # reverse all axes and hand back pixel (j, i) instead.
        averagingList = np.moveaxis(inKernels, 0, -1)
        avgKernel = np.zeros_like(inKernels[0])
        sctrl = afwMath.StatisticsControl()
        sctrl.setNumSigmaClip(5.0)
        for i in range(avgKernel.shape[0]):
            for j in range(avgKernel.shape[1]):
                avgKernel[i, j] = afwMath.makeStatistics(averagingList[i, j],
                                                         afwMath.MEANCLIP, sctrl).getValue()

        self.detKernels[detectorName] = avgKernel

    def replaceDetectorKernelWithAmpKernel(self, ampName, detectorName):
        """Use a single amplifier's kernel as the detector kernel.

        Parameters
        ----------
        ampName : `str`
            Amplifier whose kernel (from ``self.ampKernels``) is used.
        detectorName : `str`
            Key under which the kernel is stored in ``self.detKernels``.
        """
        self.detKernels[detectorName] = self.ampKernels[ampName]