Coverage for python/lsst/summit/utils/spectrumExaminer.py: 9%

302 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-19 06:11 -0700

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = [
    "SpectrumExaminer",
]

23 

24import warnings 

25from itertools import groupby 

26 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from astropy.stats import sigma_clip 

30from matplotlib.offsetbox import AnchoredText 

31from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

32from scipy.optimize import curve_fit 

33 

34from lsst.atmospec.processStar import ProcessStarTask 

35from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER 

36from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig 

37from lsst.summit.utils.utils import getImageStats 

38 

39 

class SpectrumExaminer:
    """Task for the QUICK spectral extraction of single-star dispersed images.

    For a full description of how this task works, see the ``run()`` method.

    Parameters
    ----------
    exp : exposure-like
        The dispersed-image exposure to examine (presumably an
        ``lsst.afw.image.Exposure``). It must provide ``image``,
        ``visitInfo`` and ``filter``.
    display : display-like, optional
        Display used by the ``display*``/``eraseDisplay`` helpers, through
        its ``erase``, ``line`` and ``dot`` methods.
    debug : `bool`, optional
        If True, draw extra diagnostic plots and a debug text box.
    savePlotAs : `str`, optional
        If set, the summary figure made by ``plot()`` is saved to this path.
    **kwargs
        Forwarded to ``super().__init__``.
    """

    def __init__(self, exp, display=None, debug=False, savePlotAs=None, **kwargs):
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box):
        """Convert a bounding box into the four line segments of its outline.

        The result can be drawn with::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : bbox-like
            Object exposing ``beginX``, ``endX``, ``beginY`` and ``endY``.

        Returns
        -------
        lines : `list` [`list` [`tuple`]]
            Four ``[(x, y), (x, y)]`` segments: the edge at ``beginY``, the
            edge at ``beginX``, the edge at ``endX`` and the edge at
            ``endY``, in that order.
        """
        x0, x1 = box.beginX, box.endX
        y0, y1 = box.beginY, box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self):
        """Erase the display, if one was supplied."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self):
        """Outline the spectrum bounding box in red on the display.

        Requires ``run()`` to have set ``self.spectrumbbox``. Prints a
        message instead if no display was supplied.
        """
        if self.display:
            lines = self.bboxToAwfDisplayLines(self.spectrumbbox)
            for line in lines:
                self.display.line(line, ctype="red")
        else:
            print("No display set")

    def displayStarLocation(self):
        """Mark the measured star centroid with an "x" and an "o" on the display.

        Requires ``run()`` to have set ``self.qfmResult``. Prints a message
        instead if no display was supplied.
        """
        if self.display:
            self.display.dot("x", *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot("o", *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold=5, windowSize=5):
        """Find the row range over which the ridgeline is well behaved.

        ``self.ridgeLineLocations`` is split into consecutive windows of
        ``windowSize`` rows. The scatter (standard deviation) of the
        ridgeline column is computed in each window, and windows with a
        scatter below ``threshold`` are "good". The returned range starts at
        the first row of the window two before the third good window. It
        ends (exclusive) after the window two beyond the third-from-last
        good window. The range is clipped to ``[0, len(ridgeLineLocations)]``.

        Parameters
        ----------
        threshold : `float`, optional
            Maximum ridgeline scatter (pixels) for a window to be good.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `int`
            Start and exclusive end row of the good section.

        Raises
        ------
        IndexError
            If fewer than three windows pass the threshold.
        """
        length = len(self.ridgeLineLocations)
        chunks = length // windowSize
        # One extra window so that a partial trailing chunk is included. When
        # ``length`` is an exact multiple of ``windowSize`` that window is
        # empty; it gets a NaN scatter, so it is never "good".
        stddevs = []
        for i in range(chunks + 1):
            window = self.ridgeLineLocations[i * windowSize : (i + 1) * windowSize]
            stddevs.append(np.std(window) if len(window) else np.nan)

        goodPoints = np.where(np.asarray(stddevs) < threshold)[0]
        if len(goodPoints) < 3:
            raise IndexError(
                f"Only {len(goodPoints)} ridgeline window(s) have a scatter below {threshold}; "
                "at least 3 are needed to find the good spectrum section"
            )
        minPoint = (goodPoints[2] - 2) * windowSize
        maxPoint = (goodPoints[-3] + 3) * windowSize
        minPoint = max(minPoint, 0)
        maxPoint = min(maxPoint, length)
        if self.debug:
            plt.plot(range(0, length + 1, windowSize), stddevs)
            plt.hlines(threshold, 0, length, colors="r", ls="dashed")
            plt.vlines(minPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.vlines(maxPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.title(f"Ridgeline scatter, windowSize={windowSize}")

        return (minPoint, maxPoint)

    def fit(self):
        """Fit a Gaussian profile across each row of the good spectrum section.

        For every row of ``self.spectrumData[self.goodSlice]`` the model
        ``a * exp(-(x - x0)**2 / (2 * sigma**2))`` (no constant term) is
        fitted against column index. The fit is seeded with that row's
        ridgeline column and value and an initial sigma of 7. Rows whose fit
        fails, or whose fitted parameters are not all below 1e7, get NaN
        parameters.

        Sets ``self.parameters``, an ``(nRows, 3)`` array of
        ``(a, x0, sigma)`` per row, with ``a`` and ``sigma`` made
        non-negative.
        """

        def gauss(x, a, x0, sigma):
            return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

        data = self.spectrumData[self.goodSlice]
        # The full ridgeline array is indexed by row of the whole cutout.
        # Slice it the same way as the data so that row N of ``data`` is
        # seeded from its own ridgeline position.
        ridgeLocations = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # don't subtract the row median or even a percentile - seems bad
        # fitting a const also seems bad - needs some better thought
        initialWidth = 7

        parameters = np.zeros((nRows, 3))
        xs = np.arange(nCols)
        for rowNum, (row, peakPos) in enumerate(zip(data, ridgeLocations)):
            amplitude = row[peakPos]
            try:
                pars, _ = curve_fit(gauss, xs, row, [amplitude, peakPos, initialWidth], maxfev=100)
            except (RuntimeError, ValueError):
                # RuntimeError: no convergence within maxfev.
                # ValueError: e.g. non-finite values in the row.
                pars = [np.nan] * 3
            if not np.all([p < 1e7 for p in pars]):
                pars = [np.nan] * 3
            parameters[rowNum] = pars

        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    @staticmethod
    def _addStatsTextbox(ax, text):
        """Draw ``text`` as a centred, monospace box in ``ax`` and hide its axes."""
        statsText = AnchoredText(
            text,
            loc="center",
            pad=0.5,
            prop=dict(size=10.5, ma="left", backgroundcolor="white", color="black", family="monospace"),
        )
        ax.add_artist(statsText)
        ax.axis("off")

    def plot(self):
        """Draw the summary figure, save it if requested, then show it.

        The panels are:

        - the good spectrum section as an image, with a colour bar;
        - a histogram of all spectrum-cutout pixels;
        - ridgeline peak values and fitted amplitudes;
        - the fitted FWHM with its stable region and median/best values;
        - the row sums;
        - text boxes of star, image, rate and observation stats, plus a
          debug box when ``self.debug``.

        Requires the products of ``run()`` up to and including ``fit()``.
        """
        # Compute the FWHM summary before creating the figure. In debug mode
        # getStableFwhmRegion opens and shows figures of its own, which must
        # not interleave with building the summary figure.
        fwhmValues = self.parameters[:, 2] * 2.355  # 2.355 ~= 2*sqrt(2 ln 2): Gaussian sigma -> FWHM
        amplitudes = self.parameters[:, 0]
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        fig = plt.figure(figsize=(10, 10))

        # spectrum
        ax0 = plt.subplot2grid((4, 4), (0, 0), colspan=3, fig=fig)
        ax0.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False)
        goodData = self.spectrumData[self.goodSlice].T
        vmin = np.percentile(goodData, 1)
        vmax = np.percentile(goodData, 99)
        pos = ax0.imshow(goodData, vmin=vmin, vmax=vmax, origin="lower")
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # spectrum histogram
        axHist = plt.subplot2grid((4, 4), (0, 3), fig=fig)
        data = self.spectrumData
        # NOTE(review): the lower cut (0.001th percentile) does not mirror the
        # upper one (99.99th). Kept as-is; confirm this is intended.
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        axHist.hist(data[(data >= histMin) & (data <= histMax)].flatten(), bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale("log", nonpositive="clip")
        axHist.set_title("Spectrum pixel histogram")
        text = f"Underflow = {underflow}"
        text += f"\nOverflow = {overflow}"
        anchored_text = AnchoredText(text, loc=1, pad=0.5)
        axHist.add_artist(anchored_text)

        # peak fluxes
        ax1 = plt.subplot2grid((4, 4), (1, 0), colspan=3, fig=fig)
        ax1.plot(self.ridgeLineValues[self.goodSlice], label="Raw peak value")
        ax1.plot(self.parameters[:, 0], label="Fitted amplitude")
        ax1.axhline(self.continuumFlux98, ls="dashed", color="g")
        ax1.set_ylabel("Peak amplitude (ADU)")
        ax1.set_xlabel("Spectrum position (pixels)")
        ax1.legend(
            title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
            loc="center right",
            framealpha=0.2,
            facecolor="black",
        )
        ax1.set_title("Ridgeline plot")

        # FWHM
        ax2 = plt.subplot2grid((4, 4), (2, 0), colspan=3, fig=fig)
        ax2.plot(fwhmValues, label="FWHM (pix)")
        ax2.axhline(medianFwhm, ls="dashed", color="k", label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls="dashed", color="r", label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls="dashed", color="k", alpha=0.2)
        ax2.axvline(maxVal, ls="dashed", color="k", alpha=0.2)
        ymin = max(np.nanmin(fwhmValues) - 5, 0)
        if not np.isnan(medianFwhm):
            ymax = medianFwhm * 2
        else:
            ymax = 5 * ymin
        # With no usable FWHM values these limits are NaN or degenerate, and
        # set_ylim would raise or warn. In that case leave autoscaling alone.
        if np.isfinite(ymin) and np.isfinite(ymax) and ymax > ymin:
            ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel("FWHM (pixels)")
        ax2.set_xlabel("Spectrum position (pixels)")
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title("Spectrum FWHM")

        # row fluxes
        ax3 = plt.subplot2grid((4, 4), (3, 0), colspan=3, fig=fig)
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel("Total row flux (ADU)")
        ax3.set_xlabel("Spectrum position (pixels)")
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title("Row sums")

        # textbox top: star, image, rate and observation sections
        ax4 = plt.subplot2grid((4, 4), (1, 3), rowspan=2, fig=fig)
        text = "".join(self.generateStatsTextboxContent(section) for section in range(4))
        self._addStatsTextbox(ax4, text)

        # textbox middle
        if self.debug:
            ax5 = plt.subplot2grid((4, 4), (2, 3), fig=fig)
            self._addStatsTextbox(ax5, self.generateStatsTextboxContent(-1))

        fig.tight_layout()
        # Save before showing: once a blocking show() returns, the figure may
        # already have been closed.
        if self.savePlotAs:
            fig.savefig(self.savePlotAs)
        plt.show()

    def init(self):
        """Hook called at the end of ``__init__``; does nothing here."""
        pass

    def generateStatsTextboxContent(self, section):
        """Build the text for one section of the summary-figure text box.

        Parameters
        ----------
        section : `int`
            Which section to build:

            - ``0``: star stats (centroid, max pixel, Ap25 flux);
            - ``1``: image stats (image median, exposure time);
            - ``2``: rates (max pixel and continuum per second);
            - ``3``: observation info (object, filter, grating, rotation,
              az/el, airmass);
            - ``-1``: debug info (spectrum bbox and good row range).

        Returns
        -------
        text : `str` or `None`
            The newline-joined lines. Sections 0-2 end with a blank-line
            break so that sections can be concatenated. Returns ``None`` for
            any other ``section``.
        """
        sectionBreak = ["", ""]

        if section == 0:
            x, y = self.qfmResult.brightestObjCentroid
            lines = [
                "----- Star stats -----",
                f"Star centroid @ {x:.0f}, {y:.0f}",
                f"Star max pixel = {self.starPeakFlux:,.0f} ADU",
                f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU",
                *sectionBreak,
            ]
        elif section == 1:
            imageMedian = np.median(self.exp.image.array)
            exptime = self.exp.visitInfo.exposureTime
            lines = [
                "------ Image stats ---------",
                f"Image median = {imageMedian:.2f} ADU",
                f"Exposure time = {exptime:.2f} s",
                *sectionBreak,
            ]
        elif section == 2:
            exptime = self.exp.visitInfo.exposureTime
            lines = [
                "------- Rate stats ---------",
                f"Star max pixel = {self.starPeakFlux/exptime:,.0f} ADU/s",
                f"Spectrum continuum = {self.continuumFlux98/exptime:,.1f} ADU/s",
                *sectionBreak,
            ]
        elif section == 3:
            vi = self.exp.visitInfo
            # physicalLabel holds "<filter><FILTER_DELIMITER><grating>"
            filterParts = self.exp.filter.physicalLabel.split(FILTER_DELIMITER)
            filt = filterParts[0]
            grating = filterParts[1]
            azAlt = vi.getBoresightAzAlt()
            lines = [
                "----- Observation info -----",
                f"object = {vi.object}",
                f"filter = {filt}",
                f"grating = {grating}",
                f"rotpa = {vi.getBoresightRotAngle().asDegrees():.1f}",
                f"az = {azAlt[0].asDegrees():.1f}",
                f"el = {azAlt[1].asDegrees():.1f}",
                f"airmass = {vi.getBoresightAirmass():.3f}",
            ]
        elif section == -1:  # special -1 for debug
            lines = [
                "---------- Debug -----------",
                f"spectrum bbox: {self.spectrumbbox}",
                f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}",
            ]
        else:
            return None

        return "\n".join(lines)

    def run(self):
        """Run the quick spectral extraction and draw the summary figure.

        Steps:

        1. Measure the brightest object with ``QuickFrameMeasurementTask``
           and record its rounded centroid and the pixel value there.
        2. Cut out the spectrum region via
           ``ProcessStarTask.calcSpectrumBBox``.
        3. Trace the ridgeline (column of the maximum in each row), its
           values and the row sums.
        4. Choose the good row range with ``calcGoodSpectrumSection``.
        5. Estimate the continuum as the 90th and 98th percentiles of the
           ridgeline values.
        6. Fit each good row (``fit``) and plot (``plot``).

        All intermediate results are stored as attributes on ``self``.
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        roundedCentroid = np.round(self.qfmResult.brightestObjCentroid)
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        # image arrays are indexed [y, x]
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        # NOTE(review): the meaning of the third argument (200) is defined by
        # calcSpectrumBBox; confirm there.
        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(
            self.exp, self.qfmResult.brightestObjCentroid, 200
        )
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        # per row: column of the brightest pixel, and that pixel's value
        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        self.ridgeLineValues = self.spectrumData[
            range(self.spectrumbbox.getHeight()), self.ridgeLineLocations
        ]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        self.goodSpectrumMinY, self.goodSpectrumMaxY = self.calcGoodSpectrumSection()
        self.goodSlice = slice(self.goodSpectrumMinY, self.goodSpectrumMaxY)

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues, minIndex, maxIndex):
        """Return the sigma-clipped median and 2nd-percentile FWHM of a region.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            Per-row FWHM values; may contain NaNs.
        minIndex, maxIndex : `int`
            Slice bounds of the region of ``fwhmValues`` to use.

        Returns
        -------
        medianFwhm : `float`
            Median of the values surviving sigma clipping (NaN if none do).
        bestFocusFwhm : `float`
            2nd percentile of the surviving values (NaN if none do).
        """
        with warnings.catch_warnings():  # to suppress nan warnings, which are fine
            warnings.simplefilter("ignore")
            clipped = sigma_clip(fwhmValues[minIndex:maxIndex])
            # sigma_clip returns a masked array. np.asarray() would silently
            # drop the mask (undoing the clipping), so replace the clipped
            # entries with NaN and let the nan-aware reductions skip them.
            clippedValues = np.ma.filled(clipped, np.nan)
            medianFwhm = np.nanmedian(clippedValues)
            bestFocusFwhm = np.nanpercentile(clippedValues, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(self, fwhmValues, amplitudes, smoothing=1, maxDifferential=4):
        """Find the index range over which the FWHM trace is stable.

        As implemented:

        1. Smooth ``fwhmValues`` with a boxcar of width ``smoothing`` and
           take the first difference.
        2. Keep the indices where that difference is non-negligible and not
           NaN, and take the gaps between consecutive kept indices.
        3. Group equal consecutive gaps into runs. Try runs from longest to
           shortest (ties in order of appearance) and pick the first whose
           mid-run position has an amplitude above the 75th percentile of
           ``amplitudes``; this is meant to exclude the 2nd order. If no
           run qualifies, the first of the shortest runs is used.
        4. Extend the chosen run forwards and backwards over neighbouring
           runs whose gap value is <= ``maxDifferential``.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            Per-row FWHM values (must be non-empty).
        amplitudes : `numpy.ndarray`
            Per-row fitted amplitudes, same length as ``fwhmValues``.
        smoothing : `int`, optional
            Boxcar width used to smooth ``fwhmValues``.
        maxDifferential : `int`, optional
            Largest gap value that the region may be extended across.

        Returns
        -------
        startValue, endValue : `int`
            Start and exclusive end of the stable region, clipped to
            ``[0, len(fwhmValues)]``. If too few usable points exist to form
            any run, the full range ``(0, len(fwhmValues))`` is returned.
        """
        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing) / smoothing, mode="same")
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        # Like ``|diff| > 0``, except that differences too small to change 1.0
        # count as zero. NaN differences are excluded.
        indices = np.where(1 - np.abs(diff) < 1)[0]
        diffIndices = np.diff(indices)

        # [list(g) for k, g in groupby('AAAABBBCCD')] -->[['A', 'A', 'A', 'A'],
        # ... ['B', 'B', 'B'], ['C', 'C'], ['D']]
        indexLists = [list(g) for _, g in groupby(diffIndices)]
        listLengths = [len(lst) for lst in indexLists]

        if not listLengths:
            # No runs can be formed, so there is no basis for choosing a
            # sub-region: use everything.
            return 0, len(fwhmValues)

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)

        # Candidate runs, longest first. The sort is stable, so ties keep
        # their order of appearance and every equal-length run is tried.
        candidates = sorted(range(len(listLengths)), key=lambda i: listLengths[i], reverse=True)
        # If no run passes the amplitude cut, fall back to the first of the
        # shortest runs.
        longestListIndex = listLengths.index(min(listLengths))
        for candidate in candidates:
            # position of the middle of this run
            midPosition = sum(listLengths[:candidate]) + listLengths[candidate] // 2
            if amplitudes[midPosition] > amplitudeThreshold:
                longestListIndex = candidate
                break
        longestListLength = listLengths[longestListIndex]

        startOfLongList = sum(listLengths[:longestListIndex])
        endOfLongList = startOfLongList + longestListLength

        # walk forwards over following runs while their gap value is small
        endValue = endOfLongList
        for lst in indexLists[longestListIndex + 1 :]:
            if lst[0] > maxDifferential:
                break
            endValue += len(lst)

        # ...and backwards over the preceding runs
        startValue = startOfLongList
        for lst in reversed(indexLists[:longestListIndex]):
            if lst[0] > maxDifferential:
                break
            startValue -= len(lst)

        # NOTE(review): these positions count elements of ``diffIndices``, not
        # indices into ``fwhmValues``. The two only coincide while no points
        # have been dropped earlier (e.g. NaNs). Confirm whether mapping
        # through ``indices`` was intended.
        startValue = int(max(0, startValue))
        endValue = int(min(len(fwhmValues), endValue))

        if not self.debug:
            return startValue, endValue

        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, len(fwhmValues))

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], "r", ls="--")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        plt.plot(diffIndices)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue