Coverage for python/lsst/summit/utils/spectrumExaminer.py: 8%

303 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-12 04:04 -0700

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ['SpectrumExaminer']

23 

24import numpy as np 

25import matplotlib.pyplot as plt 

26from matplotlib.offsetbox import AnchoredText 

27from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

28from scipy.optimize import curve_fit 

29from itertools import groupby 

30from astropy.stats import sigma_clip 

31import warnings 

32 

33from lsst.atmospec.processStar import ProcessStarTask 

34from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig 

35 

36from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER 

37from lsst.summit.utils.utils import getImageStats 

38 

39 

class SpectrumExaminer:
    """Task for the QUICK spectral extraction of single-star dispersed images.

    The brightest object in the exposure is found, a bounding box around its
    spectrum is cut out, and every row of the good part of that cutout is
    fitted with a Gaussian to follow the peak amplitude and FWHM along the
    spectrum. For a full description of the flow, see the run() method.

    Parameters
    ----------
    exp : `lsst.afw.image.Exposure`
        The dispersed image to examine (an afw Exposure, presumably; it must
        provide ``image``, ``getInfo()`` and ``visitInfo``).
    display : optional
        Display object used by the ``display*``/``eraseDisplay`` helpers;
        falsy to disable display output.
    debug : `bool`, optional
        If True, make extra diagnostic plots and add a debug textbox.
    savePlotAs : `str`, optional
        Default path to save the summary figure made by ``plot()`` to.
    **kwargs
        Forwarded to ``super().__init__``.
    """

    def __init__(self, exp, display=None, debug=False, savePlotAs=None, **kwargs):
        # NOTE: the only base class is ``object``, so any extra keyword
        # arguments make ``super().__init__`` raise a TypeError.
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box):
        """Convert a bounding box into the four line segments outlining it.

        The result can be drawn with::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : bbox-like
            Object with ``beginX``, ``endX``, ``beginY`` and ``endY``
            attributes.

        Returns
        -------
        lines : `list` [`list` [`tuple`]]
            The edges at y0, x0, x1 and y1 (in that order), each as a pair
            of (x, y) points.
        """
        x0, x1 = box.beginX, box.endX
        y0, y1 = box.beginY, box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self):
        """Erase the display, if one is set."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self):
        """Draw ``self.spectrumbbox`` on the display in red."""
        if self.display:
            lines = self.bboxToAwfDisplayLines(self.spectrumbbox)
            for line in lines:
                self.display.line(line, ctype='red')
        else:
            print("No display set")

    def displayStarLocation(self):
        """Mark the brightest object's centroid on the display."""
        if self.display:
            self.display.dot('x', *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot('o', *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold=5, windowSize=5):
        """Find the range of spectrum rows where the ridgeline is stable.

        ``self.ridgeLineLocations`` (the column of the brightest pixel in each
        row) is split into consecutive windows of ``windowSize`` rows, and
        the standard deviation of each window is computed. Windows whose
        scatter is below ``threshold`` count as good. The returned range
        starts two windows before the third good window and ends at the end
        of the window two after the third-from-last good window, clipped to
        ``[0, len(self.ridgeLineLocations)]``.

        Parameters
        ----------
        threshold : `float`, optional
            Maximum ridgeline standard deviation (pixels) for a good window.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `int`
            Start (inclusive) and end (exclusive) row of the good section.

        Raises
        ------
        IndexError
            If fewer than three windows are good.
        """
        length = len(self.ridgeLineLocations)
        nChunks = length // windowSize
        # range(nChunks + 1) also covers the trailing partial window. When
        # length is a multiple of windowSize that window is empty and its
        # std is NaN, which never passes the threshold test below.
        stddevs = [np.std(self.ridgeLineLocations[i*windowSize:(i+1)*windowSize])
                   for i in range(nChunks + 1)]

        goodPoints = np.where(np.asarray(stddevs) < threshold)[0]
        if len(goodPoints) < 3:
            raise IndexError(f"Only {len(goodPoints)} ridgeline window(s) have a scatter below "
                             f"threshold={threshold}; at least 3 are needed to find the good section.")
        minPoint = max((goodPoints[2] - 2) * windowSize, 0)
        maxPoint = min((goodPoints[-3] + 3) * windowSize, length)

        if self.debug:
            plt.plot(range(0, length+1, windowSize), stddevs)
            plt.hlines(threshold, 0, length, colors='r', ls='dashed')
            plt.vlines(minPoint, 0, max(stddevs)+10, colors='k', ls='dashed')
            plt.vlines(maxPoint, 0, max(stddevs)+10, colors='k', ls='dashed')
            plt.title(f'Ridgeline scatter, windowSize={windowSize}')

        return (minPoint, maxPoint)

    def fit(self):
        """Fit a Gaussian across every row of the good part of the spectrum.

        Each row of ``self.spectrumData[self.goodSlice]`` is fitted with
        ``a * exp(-(x - x0)**2 / (2 * sigma**2))`` (no constant background)
        using ``scipy.optimize.curve_fit``. The fit starts from that row's
        ridgeline position and value and an initial width of 7 pixels.
        Rows whose fit fails to converge, or whose fitted parameters are not
        all below 1e7, get NaN parameters.

        Sets ``self.parameters``, an ``(nRows, 3)`` array of
        ``(amplitude, centre, sigma)``, with amplitude and sigma made
        non-negative.
        """
        def gauss(x, a, x0, sigma):
            return a*np.exp(-(x-x0)**2/(2*sigma**2))

        data = self.spectrumData[self.goodSlice]
        # Slice the ridgeline exactly like the data so that each row is
        # paired with its own peak position. Indexing the unsliced array
        # would take the initial guess from a row ``goodSlice.start`` rows
        # earlier.
        ridgeLocations = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # don't subtract the row median or even a percentile - seems bad
        # fitting a const also seems bad - needs some better thought

        parameters = np.zeros((nRows, 3))
        xs = np.arange(nCols)
        initialWidth = 7
        for rowNum, (row, peakPos) in enumerate(zip(data, ridgeLocations)):
            amplitude = row[peakPos]
            try:
                pars, _ = curve_fit(gauss, xs, row, [amplitude, peakPos, initialWidth], maxfev=100)
            except RuntimeError:  # curve_fit failed to converge within maxfev
                pars = [np.nan] * 3
            # Treat absurdly large parameters as a failed fit too.
            if not np.all([p < 1e7 for p in pars]):
                pars = [np.nan] * 3
            parameters[rowNum] = pars

        # The sign of sigma is irrelevant in the model; report magnitudes.
        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    def plot(self, saveAs=None):
        """Make the summary figure for the examined spectrum.

        Panels: the good section of the spectrum (transposed) with a
        colorbar, a pixel histogram with under/overflow counts, ridgeline
        peak values vs fitted amplitudes, the fitted FWHM with its median and
        best values over the stable region, row sums, and a stats textbox
        (plus a debug textbox when ``self.debug`` is set).

        Parameters
        ----------
        saveAs : `str`, optional
            Path to save the figure to. Defaults to ``self.savePlotAs``; if
            neither is set, the figure is only shown.
        """
        sigmaToFwhm = 2.355  # ~2*sqrt(2*ln(2)) for a Gaussian
        fig = plt.figure(figsize=(10, 10))

        # spectrum
        ax0 = plt.subplot2grid((4, 4), (0, 0), colspan=3)
        ax0.tick_params(axis='x', top=True, bottom=False, labeltop=True, labelbottom=False)
        d = self.spectrumData[self.goodSlice].T
        vmin = np.percentile(d, 1)
        vmax = np.percentile(d, 99)
        pos = ax0.imshow(d, vmin=vmin, vmax=vmax, origin='lower')
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # spectrum histogram
        axHist = plt.subplot2grid((4, 4), (0, 3))
        data = self.spectrumData
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        axHist.hist(data[(data >= histMin) & (data <= histMax)].flatten(), bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale('log', nonpositive='clip')
        axHist.set_title('Spectrum pixel histogram')
        text = f"Underflow = {underflow}"
        text += f"\nOverflow = {overflow}"
        anchored_text = AnchoredText(text, loc=1, pad=0.5)
        axHist.add_artist(anchored_text)

        # peak fluxes
        ax1 = plt.subplot2grid((4, 4), (1, 0), colspan=3)
        ax1.plot(self.ridgeLineValues[self.goodSlice], label='Raw peak value')
        ax1.plot(self.parameters[:, 0], label='Fitted amplitude')
        ax1.axhline(self.continuumFlux98, ls='dashed', color='g')
        ax1.set_ylabel('Peak amplitude (ADU)')
        ax1.set_xlabel('Spectrum position (pixels)')
        ax1.legend(title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
                   loc="center right", framealpha=0.2, facecolor="black")
        ax1.set_title('Ridgeline plot')

        # FWHM
        ax2 = plt.subplot2grid((4, 4), (2, 0), colspan=3)
        fwhmValues = self.parameters[:, 2]*sigmaToFwhm
        amplitudes = self.parameters[:, 0]
        ax2.plot(fwhmValues, label="FWHM (pix)")
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        ax2.axhline(medianFwhm, ls='dashed', color='k',
                    label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls='dashed', color='r',
                    label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls='dashed', color='k', alpha=0.2)
        ax2.axvline(maxVal, ls='dashed', color='k', alpha=0.2)
        ymin = max(np.nanmin(fwhmValues)-5, 0)
        ymax = medianFwhm*2 if not np.isnan(medianFwhm) else 5*ymin
        ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel('FWHM (pixels)')
        ax2.set_xlabel('Spectrum position (pixels)')
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title('Spectrum FWHM')

        # row fluxes
        ax3 = plt.subplot2grid((4, 4), (3, 0), colspan=3)
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel('Total row flux (ADU)')
        ax3.set_xlabel('Spectrum position (pixels)')
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title('Row sums')

        # textbox top: star, image, rate and observation sections
        ax4 = plt.subplot2grid((4, 4), (1, 3), rowspan=2)
        text = ''.join(self.generateStatsTextboxContent(section) for section in range(4))
        stats_text = AnchoredText(text, loc="center", pad=0.5,
                                  prop=dict(size=10.5, ma="left", backgroundcolor="white",
                                            color="black", family='monospace'))
        ax4.add_artist(stats_text)
        ax4.axis('off')

        # textbox middle
        if self.debug:
            ax5 = plt.subplot2grid((4, 4), (2, 3))
            text = self.generateStatsTextboxContent(-1)
            stats_text = AnchoredText(text, loc="center", pad=0.5,
                                      prop=dict(size=10.5, ma="left", backgroundcolor="white",
                                                color="black", family='monospace'))
            ax5.add_artist(stats_text)
            ax5.axis('off')

        plt.tight_layout()

        # Save before showing, as the matplotlib docs recommend: a blocking
        # plt.show() closes the figure once its window is dismissed.
        savePath = saveAs or self.savePlotAs
        if savePath:
            fig.savefig(savePath)
        plt.show()

    def init(self):
        """Hook called at the end of ``__init__``; does nothing here."""
        pass

    def generateStatsTextboxContent(self, section, doPrint=True):
        """Build the text for one section of the summary-plot textbox.

        Parameters
        ----------
        section : `int`
            Section to make: 0 = star stats, 1 = image stats, 2 = rate stats,
            3 = observation info, -1 = debug info.
        doPrint : `bool`, optional
            Unused; kept for backwards compatibility.

        Returns
        -------
        text : `str` or `None`
            The section's lines joined with newlines. Sections 0-2 end with
            a blank line to separate them from the next section. Returns
            None for an unknown ``section``.
        """
        x, y = self.qfmResult.brightestObjCentroid

        info = self.exp.getInfo()
        vi = info.getVisitInfo()
        exptime = vi.getExposureTime()

        # The physical filter label combines the filter and the grating.
        filterParts = info.getFilterLabel().physicalLabel.split(FILTER_DELIMITER)
        filt = filterParts[0]
        grating = filterParts[1]

        airmass = vi.getBoresightAirmass()
        rotangle = vi.getBoresightRotAngle().asDegrees()

        azAlt = vi.getBoresightAzAlt()
        az = azAlt[0].asDegrees()
        el = azAlt[1].asDegrees()

        obj = self.exp.visitInfo.object

        lines = []

        if section == 0:
            lines.append("----- Star stats -----")
            lines.append(f"Star centroid @ {x:.0f}, {y:.0f}")
            lines.append(f"Star max pixel = {self.starPeakFlux:,.0f} ADU")
            lines.append(f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU")
            lines.extend(["", ""])  # section break
            return '\n'.join(lines)

        if section == 1:
            lines.append("------ Image stats ---------")
            imageMedian = np.median(self.exp.image.array)
            lines.append(f"Image median = {imageMedian:.2f} ADU")
            lines.append(f"Exposure time = {exptime:.2f} s")
            lines.extend(["", ""])  # section break
            return '\n'.join(lines)

        if section == 2:
            lines.append("------- Rate stats ---------")
            lines.append(f"Star max pixel = {self.starPeakFlux/exptime:,.0f} ADU/s")
            lines.append(f"Spectrum continuum = {self.continuumFlux98/exptime:,.1f} ADU/s")
            lines.extend(["", ""])  # section break
            return '\n'.join(lines)

        if section == 3:
            lines.append("----- Observation info -----")
            lines.append(f"object = {obj}")
            lines.append(f"filter = {filt}")
            lines.append(f"grating = {grating}")
            lines.append(f"rotpa = {rotangle:.1f}")

            lines.append(f"az = {az:.1f}")
            lines.append(f"el = {el:.1f}")
            lines.append(f"airmass = {airmass:.3f}")
            return '\n'.join(lines)

        if section == -1:  # special -1 for debug
            lines.append("---------- Debug -----------")
            lines.append(f"spectrum bbox: {self.spectrumbbox}")
            lines.append(f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}")
            return '\n'.join(lines)

        return None

    def run(self):
        """Examine the spectrum of the brightest star in the exposure.

        1. Find the brightest object with ``QuickFrameMeasurementTask`` and
           record the value of the pixel at its rounded centroid.
        2. Compute the spectrum bounding box with
           ``ProcessStarTask.calcSpectrumBBox`` and cut out that image data.
        3. For each row of the cutout, record the column and value of the
           brightest pixel (the ridgeline) and the row sum.
        4. Pick the rows where the ridgeline is stable
           (``calcGoodSpectrumSection``), and take the 90th and 98th
           percentiles of the ridgeline values as continuum estimates.
        5. Fit each good row with a Gaussian (``fit``) and plot the results
           (``plot``).
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        roundedCentroid = np.round(self.qfmResult.brightestObjCentroid)
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        # NOTE(review): the meaning of 200 is defined by calcSpectrumBBox
        # (presumably the box width in pixels) - confirm.
        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(self.exp,
                                                                  self.qfmResult.brightestObjCentroid,
                                                                  200)
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        self.ridgeLineValues = self.spectrumData[range(self.spectrumbbox.getHeight()),
                                                 self.ridgeLineLocations]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        coords = self.calcGoodSpectrumSection()
        self.goodSpectrumMinY = coords[0]
        self.goodSpectrumMaxY = coords[1]
        self.goodSlice = slice(coords[0], coords[1])

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues, minIndex, maxIndex):
        """Get the median and best-focus FWHM of a section of the spectrum.

        ``fwhmValues[minIndex:maxIndex]`` is sigma-clipped. Rejected points
        and NaNs are ignored when computing the statistics.

        Returns
        -------
        medianFwhm : `float`
            NaN-ignoring median of the clipped values.
        bestFocusFwhm : `float`
            NaN-ignoring 2nd percentile of the clipped values.
        """
        with warnings.catch_warnings():  # to suppress nan warnings, which are fine
            warnings.simplefilter("ignore")
            clippedValues = sigma_clip(fwhmValues[minIndex:maxIndex])
            # sigma_clip returns a masked array, and np.nanmedian and
            # np.nanpercentile don't honour masks. np.asarray would just drop
            # the mask and keep the rejected outliers, so replace masked
            # entries with NaN instead; the nan-functions then skip them.
            clippedValues = np.ma.filled(clippedValues.astype(float), np.nan)
            medianFwhm = np.nanmedian(clippedValues)
            bestFocusFwhm = np.nanpercentile(clippedValues, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(self, fwhmValues, amplitudes, smoothing=1, maxDifferential=4):
        """Find the section of the spectrum where the fitted FWHM is stable.

        The algorithm:
        - smooth the fwhmValues values,
        - differentiate,
        - take the longest contiguous region of 1s,
        - check the section corresponds to the top 25% in amplitude, to
          exclude the 2nd order; if not, pick the next longest run, etc.,
        - walk out from the ends of that run over bumps smaller than
          maxDifferential.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            FWHM for each spectrum row.
        amplitudes : `numpy.ndarray`
            Fitted amplitude for each spectrum row.
        smoothing : `int`, optional
            Width of the boxcar used to smooth ``fwhmValues``.
        maxDifferential : `int`, optional
            Largest gap value that the outward walk is allowed to cross.

        Returns
        -------
        startValue, endValue : `int`
            Start (inclusive) and end (exclusive) of the stable region,
            clipped to ``[0, len(fwhmValues)]``.
        """
        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing)/smoothing, mode='same')
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        # 1 - |diff| < 1  <=>  |diff| > 0: positions where the smoothed FWHM
        # changes (NaN differences compare False and are excluded).
        indices = np.where(1-np.abs(diff) < 1)[0]
        diffIndices = np.diff(indices)

        # [list(g) for k, g in groupby('AAAABBBCCD')] -->[['A', 'A', 'A', 'A'],
        # ... ['B', 'B', 'B'], ['C', 'C'], ['D']]
        indexLists = [list(g) for k, g in groupby(diffIndices)]
        listLengths = [len(lst) for lst in indexLists]

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)
        sortedListLengths = sorted(listLengths)

        # NOTE(review): listLengths.index() always returns the first run of a
        # given length, so equal-length runs are never tried individually, and
        # if no run passes the amplitude cut the last (shortest) one is kept.
        # Positions here are counted in diffIndices, not pixels. Also, if
        # indexLists is empty, longestListIndex is never bound.
        for listLength in sortedListLengths[::-1]:
            longestListLength = listLength
            longestListIndex = listLengths.index(longestListLength)
            longestListStartTruePosition = int(np.sum(listLengths[0:longestListIndex]))
            longestListStartTruePosition += int(longestListLength/2)  # we want the mid-run value
            if amplitudes[longestListStartTruePosition] > amplitudeThreshold:
                break

        startOfLongList = np.sum(listLengths[0:longestListIndex])
        endOfLongList = startOfLongList + longestListLength

        # Walk forwards over runs whose gap value is small enough.
        endValue = endOfLongList
        for lst in indexLists[longestListIndex+1:]:
            value = lst[0]
            if value > maxDifferential:
                break
            endValue += len(lst)

        # Walk backwards likewise; any overshoot below 0 is clipped below.
        startValue = startOfLongList
        for lst in indexLists[longestListIndex-1::-1]:
            value = lst[0]
            if value > maxDifferential:
                break
            startValue -= len(lst)

        startValue = int(max(0, startValue))
        endValue = int(min(len(fwhmValues), endValue))

        if not self.debug:
            return startValue, endValue

        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, len(fwhmValues))

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, 'r')
        plt.vlines(endValue, 0, 50, 'r')
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], 'r', ls='--')

        plt.vlines(startOfLongList, 0, 50, 'g')
        plt.vlines(endOfLongList, 0, 50, 'g')

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        plt.plot(diffIndices)
        plt.vlines(startValue, 0, 50, 'r')
        plt.vlines(endValue, 0, 50, 'r')

        plt.vlines(startOfLongList, 0, 50, 'g')
        plt.vlines(endOfLongList, 0, 50, 'g')
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue