Coverage for python/lsst/ip/isr/brighterFatterKernel.py: 8%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

188 statements  

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22"""Brighter Fatter Kernel calibration definition.""" 

23 

24 

25__all__ = ['BrighterFatterKernel'] 

26 

27 

28import numpy as np 

29from astropy.table import Table 

30import lsst.afw.math as afwMath 

31from . import IsrCalib 

32 

33 

class BrighterFatterKernel(IsrCalib):
    """Calibration of brighter-fatter kernels for an instrument.

    ampKernels are the kernels for each amplifier in a detector, as
    generated by having level == 'AMP'

    detKernels are the kernels generated for each detector as a
    whole, as generated by having level == 'DETECTOR'

    makeDetectorKernelFromAmpwiseKernels is a method to generate the
    kernel for a detector, constructed by averaging together the
    ampwise kernels in the detector. The existing application code is
    only defined for kernels with level == 'DETECTOR', so this method
    is used if the supplied kernel was built with level == 'AMP'.

    Parameters
    ----------
    camera : `lsst.afw.cameraGeom.Camera`, optional
        Camera used to pre-populate the per-amplifier or per-detector
        containers. If supplied, `initFromCamera` is called with the
        ``detectorId`` found in ``kwargs`` (or `None`).
    level : `str`, optional
        Level the kernels will be generated for.
    log : `logging.Logger`, optional
        Log to write messages to.
    **kwargs :
        Parameters to pass to parent constructor.

    """
    _OBSTYPE = 'bfk'
    _SCHEMA = 'Brighter-fatter kernel'
    _VERSION = 1.0

    def __init__(self, camera=None, level=None, **kwargs):
        self.level = level

        # Things inherited from the PTC
        self.means = dict()
        self.variances = dict()
        self.rawXcorrs = dict()
        self.badAmps = list()
        # Default kernel dimensions; fromDict overwrites this from the
        # KERNEL_DX/KERNEL_DY metadata entries.
        self.shape = (17, 17)
        self.gain = dict()
        self.noise = dict()

        # Things calculated from the PTC
        self.meanXcorrs = dict()
        self.valid = dict()

        # Things that are used downstream
        self.ampKernels = dict()
        self.detKernels = dict()

        # The attributes above are defined before the parent constructor
        # runs, matching the original statement order.
        super().__init__(**kwargs)

        if camera:
            self.initFromCamera(camera, detectorId=kwargs.get('detectorId', None))

        self.requiredAttributes.update(['level', 'means', 'variances', 'rawXcorrs',
                                        'badAmps', 'gain', 'noise', 'meanXcorrs', 'valid',
                                        'ampKernels', 'detKernels'])

    def updateMetadata(self, setDate=False, **kwargs):
        """Update calibration metadata.

        This calls the base class's method after ensuring the required
        calibration keywords will be saved.

        Parameters
        ----------
        setDate : `bool`, optional
            Update the CALIBDATE fields in the metadata to the current
            time. Defaults to False.
        kwargs :
            Other keyword parameters to set in the metadata.
        """
        # These three keywords are what fromDict reads back to restore
        # ``level`` and ``shape``.
        kwargs['LEVEL'] = self.level
        kwargs['KERNEL_DX'] = self.shape[0]
        kwargs['KERNEL_DY'] = self.shape[1]

        super().updateMetadata(setDate=setDate, **kwargs)

    def initFromCamera(self, camera, detectorId=None):
        """Initialize kernel structure from camera.

        Parameters
        ----------
        camera : `lsst.afw.cameraGeom.Camera`
            Camera to use to define geometry.
        detectorId : `int`, optional
            Index of the detector to generate.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            The initialized calibration.

        Raises
        ------
        RuntimeError :
            Raised if no detectorId is supplied for a calibration with
            level='AMP'.
        """
        self._instrument = camera.getName()

        if detectorId is not None:
            detector = camera[detectorId]
            self._detectorId = detectorId
            self._detectorName = detector.getName()
            self._detectorSerial = detector.getSerial()

        if self.level == 'AMP':
            if detectorId is None:
                raise RuntimeError("A detectorId must be supplied if level='AMP'.")

            self.badAmps = []

            # Create an empty placeholder for every amplifier; gain and
            # read noise are seeded from the camera geometry.
            for amp in detector:
                ampName = amp.getName()
                self.means[ampName] = []
                self.variances[ampName] = []
                self.rawXcorrs[ampName] = []
                self.gain[ampName] = amp.getGain()
                self.noise[ampName] = amp.getReadNoise()
                self.meanXcorrs[ampName] = []
                self.ampKernels[ampName] = []
                self.valid[ampName] = []
        elif self.level == 'DETECTOR':
            if detectorId is None:
                # No specific detector requested: make a placeholder
                # for every detector in the camera.
                for det in camera:
                    detName = det.getName()
                    self.detKernels[detName] = []
            else:
                self.detKernels[self._detectorName] = []

        return self

    def getLengths(self):
        """Return the set of lengths needed for reshaping components.

        Returns
        -------
        kernelLength : `int`
            Product of the elements of self.shape.
        smallLength : `int`
            Size of an untiled covariance.
        nObs : `int`
            Number of observation pairs used in the kernel. This is 0
            unless level == 'AMP'.

        Raises
        ------
        RuntimeError :
            Raised if level == 'AMP' and the amplifiers do not all
            share one number of observations. This includes the case
            where no amplifiers are defined.
        """
        kernelLength = self.shape[0] * self.shape[1]
        smallLength = int((self.shape[0] - 1)*(self.shape[1] - 1)/4)
        if self.level == 'AMP':
            nObservations = {len(self.means[amp]) for amp in self.means}
            if len(nObservations) != 1:
                raise RuntimeError("Inconsistent number of observations found.")
            nObs = nObservations.pop()
        else:
            nObs = 0

        return (kernelLength, smallLength, nObs)

    @classmethod
    def fromDict(cls, dictionary):
        """Construct a calibration from a dictionary of properties.

        Parameters
        ----------
        dictionary : `dict`
            Dictionary of properties.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            Constructed calibration.

        Raises
        ------
        RuntimeError :
            Raised if the supplied dictionary is for a different
            calibration.
        """
        calib = cls()

        if calib._OBSTYPE != (found := dictionary['metadata']['OBSTYPE']):
            raise RuntimeError(f"Incorrect brighter-fatter kernel supplied. Expected {calib._OBSTYPE}, "
                               f"found {found}")

        calib.setMetadata(dictionary['metadata'])
        calib.calibInfoFromDict(dictionary)

        calib.level = dictionary['metadata'].get('LEVEL', 'AMP')
        calib.shape = (dictionary['metadata'].get('KERNEL_DX', 0),
                       dictionary['metadata'].get('KERNEL_DY', 0))

        # means must be populated before getLengths is called, because
        # getLengths derives nObs from it when level == 'AMP'.
        calib.means = {amp: np.array(values) for amp, values in dictionary['means'].items()}
        calib.variances = {amp: np.array(values) for amp, values in dictionary['variances'].items()}

        # Lengths for reshape:
        _, smallLength, nObs = calib.getLengths()
        smallShapeSide = int(np.sqrt(smallLength))
        rawShape = (nObs, smallShapeSide, smallShapeSide)

        calib.rawXcorrs = {amp: np.array(values).reshape(rawShape)
                           for amp, values in dictionary['rawXcorrs'].items()}

        calib.gain = dictionary['gain']
        calib.noise = dictionary['noise']

        # Iterate over the meanXcorrs entries themselves rather than the
        # rawXcorrs keys, so the two sets of keys need not be identical.
        calib.meanXcorrs = {amp: np.array(values).reshape(calib.shape)
                            for amp, values in dictionary['meanXcorrs'].items()}
        calib.ampKernels = {amp: np.array(values).reshape(calib.shape)
                            for amp, values in dictionary['ampKernels'].items()}
        calib.valid = {amp: bool(value) for amp, value in dictionary['valid'].items()}

        # Keep any explicitly listed bad amplifiers (toDict writes them
        # under 'badAmps'), then add every amplifier flagged invalid. The
        # normalised booleans are used so that 0 and numpy.bool_ values
        # are recognised, which an identity test against False misses.
        badAmps = list(dictionary.get('badAmps', []))
        for amp, isValid in calib.valid.items():
            if not isValid and amp not in badAmps:
                badAmps.append(amp)
        calib.badAmps = badAmps

        calib.detKernels = {det: np.array(values).reshape(calib.shape)
                            for det, values in dictionary['detKernels'].items()}

        calib.updateMetadata()
        return calib

    def toDict(self):
        """Return a dictionary containing the calibration properties.

        The dictionary should be able to be round-tripped through
        `fromDict`.

        Returns
        -------
        dictionary : `dict`
            Dictionary of properties.
        """
        self.updateMetadata()

        outDict = {}
        metadata = self.getMetadata()
        outDict['metadata'] = metadata

        # Lengths for ravel:
        kernelLength, smallLength, nObs = self.getLengths()

        outDict['means'] = {amp: np.array(self.means[amp]).tolist() for amp in self.means}
        outDict['variances'] = {amp: np.array(self.variances[amp]).tolist() for amp in self.variances}
        outDict['rawXcorrs'] = {amp: np.array(self.rawXcorrs[amp]).reshape(nObs*smallLength).tolist()
                                for amp in self.rawXcorrs}
        outDict['badAmps'] = self.badAmps
        outDict['gain'] = self.gain
        outDict['noise'] = self.noise

        # Kernels are wrapped in np.array (as rawXcorrs already is) so
        # that list-valued entries can be flattened as well.
        outDict['meanXcorrs'] = {amp: np.array(self.meanXcorrs[amp]).reshape(kernelLength).tolist()
                                 for amp in self.meanXcorrs}
        outDict['ampKernels'] = {amp: np.array(self.ampKernels[amp]).reshape(kernelLength).tolist()
                                 for amp in self.ampKernels}
        outDict['valid'] = self.valid

        outDict['detKernels'] = {det: np.array(self.detKernels[det]).reshape(kernelLength).tolist()
                                 for det in self.detKernels}
        return outDict

    @classmethod
    def fromTable(cls, tableList):
        """Construct calibration from a list of tables.

        This method uses the `fromDict` method to create the
        calibration, after constructing an appropriate dictionary from
        the input tables.

        Parameters
        ----------
        tableList : `list` [`astropy.table.Table`]
            List of tables to use to construct the brighter-fatter
            calibration. The first table holds the per-amplifier
            columns; an optional second table holds the per-detector
            kernels.

        Returns
        -------
        calib : `lsst.ip.isr.BrighterFatterKernel`
            The calibration defined in the tables.
        """
        ampTable = tableList[0]

        inDict = dict()
        inDict['metadata'] = ampTable.meta

        amps = ampTable['AMPLIFIER']

        # Each column is keyed by amplifier name, in table row order.
        inDict['means'] = dict(zip(amps, ampTable['MEANS']))
        inDict['variances'] = dict(zip(amps, ampTable['VARIANCES']))
        inDict['rawXcorrs'] = dict(zip(amps, ampTable['RAW_XCORRS']))
        inDict['gain'] = dict(zip(amps, ampTable['GAIN']))
        inDict['noise'] = dict(zip(amps, ampTable['NOISE']))
        inDict['meanXcorrs'] = dict(zip(amps, ampTable['MEAN_XCORRS']))
        inDict['ampKernels'] = dict(zip(amps, ampTable['KERNEL']))
        inDict['valid'] = {amp: bool(valid) for amp, valid in zip(amps, ampTable['VALID'])}

        inDict['badAmps'] = [amp for amp, valid in inDict['valid'].items() if not valid]

        if len(tableList) > 1:
            detTable = tableList[1]
            inDict['detKernels'] = dict(zip(detTable['DETECTOR'], detTable['KERNEL']))
        else:
            inDict['detKernels'] = {}

        return cls.fromDict(inDict)

    def toTable(self):
        """Construct a list of tables containing the information in this calibration.

        The list of tables should create an identical calibration
        after being passed to this class's fromTable method.

        Returns
        -------
        tableList : `list` [`astropy.table.Table`]
            List of tables containing the brighter-fatter kernel
            calibration information. The amplifier table is always
            present; a detector table is appended only when detector
            kernels exist.

        """
        tableList = []
        self.updateMetadata()

        # Lengths
        kernelLength, smallLength, nObs = self.getLengths()

        ampList = []
        meanList = []
        varianceList = []
        rawXcorrs = []
        gainList = []
        noiseList = []

        meanXcorrsList = []
        kernelList = []
        validList = []

        # The per-amplifier columns are only filled for level == 'AMP';
        # otherwise the amplifier table is written with empty columns.
        if self.level == 'AMP':
            for amp in self.means.keys():
                ampList.append(amp)
                meanList.append(self.means[amp])
                varianceList.append(self.variances[amp])
                rawXcorrs.append(np.array(self.rawXcorrs[amp]).reshape(nObs*smallLength).tolist())
                gainList.append(self.gain[amp])
                noiseList.append(self.noise[amp])

                meanXcorrsList.append(np.array(self.meanXcorrs[amp]).reshape(kernelLength).tolist())
                kernelList.append(np.array(self.ampKernels[amp]).reshape(kernelLength).tolist())
                # An amplifier listed in badAmps is stored as invalid
                # regardless of its own valid flag.
                validList.append(int(self.valid[amp] and not (amp in self.badAmps)))

        ampTable = Table({'AMPLIFIER': ampList,
                          'MEANS': meanList,
                          'VARIANCES': varianceList,
                          'RAW_XCORRS': rawXcorrs,
                          'GAIN': gainList,
                          'NOISE': noiseList,
                          'MEAN_XCORRS': meanXcorrsList,
                          'KERNEL': kernelList,
                          'VALID': validList,
                          })

        ampTable.meta = self.getMetadata().toDict()
        tableList.append(ampTable)

        if self.detKernels:
            detList = []
            kernelList = []
            for det in self.detKernels.keys():
                detList.append(det)
                kernelList.append(np.array(self.detKernels[det]).reshape(kernelLength).tolist())

            detTable = Table({'DETECTOR': detList,
                              'KERNEL': kernelList})
            detTable.meta = self.getMetadata().toDict()
            tableList.append(detTable)

        return tableList

    # Implementation methods
    def makeDetectorKernelFromAmpwiseKernels(self, detectorName, ampsToExclude=None):
        """Average the amplifier level kernels to create a detector level kernel.

        The average is a 5-sigma clipped mean taken element by element
        across the amplifier kernels. The index ordering of the result
        matches that of the input kernels.

        Parameters
        ----------
        detectorName : `str`
            Key under which the averaged kernel is stored in
            ``self.detKernels``.
        ampsToExclude : `list` [`str`], optional
            Names of amplifiers to leave out of the average. Defaults
            to excluding none.

        Raises
        ------
        RuntimeError :
            Raised if no amplifier kernels remain after exclusion.
        """
        excluded = () if ampsToExclude is None else ampsToExclude
        inKernels = np.array([kernel for amp, kernel in self.ampKernels.items()
                              if amp not in excluded])
        if len(inKernels) == 0:
            raise RuntimeError("No amplifier kernels available to average for "
                               f"detector {detectorName}.")

        avgKernel = np.zeros_like(inKernels[0])
        sctrl = afwMath.StatisticsControl()
        sctrl.setNumSigmaClip(5.0)
        for i in range(np.shape(avgKernel)[0]):
            for j in range(np.shape(avgKernel)[1]):
                # Index the amplifier stack directly: going through
                # np.transpose(inKernels)[i, j] would select element
                # (j, i) of every kernel and transpose the result.
                avgKernel[i, j] = afwMath.makeStatistics(inKernels[:, i, j],
                                                         afwMath.MEANCLIP, sctrl).getValue()

        self.detKernels[detectorName] = avgKernel

    def replaceDetectorKernelWithAmpKernel(self, ampName, detectorName):
        """Use a single amplifier kernel as the kernel for a detector.

        Parameters
        ----------
        ampName : `str`
            Key of the kernel in ``self.ampKernels`` to use.
        detectorName : `str`
            Key in ``self.detKernels`` to assign the kernel to.
        """
        self.detKernels[detectorName] = self.ampKernels[ampName]