Coverage for python/lsst/summit/utils/spectrumExaminer.py: 10%

305 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-17 08:53 +0000

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Names exported by ``from ... import *``: only the examiner class.
__all__ = [
    "SpectrumExaminer",
]

23 

24import warnings 

25from itertools import groupby 

26from typing import Any 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from astropy.stats import sigma_clip 

31from matplotlib.offsetbox import AnchoredText 

32from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

33from scipy.optimize import curve_fit 

34 

35import lsst.afw.display as afwDisplay 

36import lsst.afw.image as afwImage 

37from lsst.atmospec.processStar import ProcessStarTask 

38from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER 

39from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig 

40from lsst.summit.utils.utils import getImageStats 

41 

42 

class SpectrumExaminer:
    """Task for the QUICK spectral extraction of single-star dispersed images.

    For a full description of how this task works, see the run() method.
    """

    # Gaussian sigma -> FWHM conversion factor, 2 * sqrt(2 * ln(2)) ~= 2.3548,
    # rounded to three decimals.
    _SIGMA_TO_FWHM = 2.355

    def __init__(
        self,
        exp: afwImage.Exposure,
        display: afwDisplay.Display | None = None,
        debug: bool | None = False,
        savePlotAs: str | None = None,
        **kwargs: Any,
    ):
        """Set up the measurement tasks for a single exposure.

        Parameters
        ----------
        exp : `lsst.afw.image.Exposure`
            The dispersed single-star exposure to examine.
        display : `lsst.afw.display.Display`, optional
            Display used by the ``display*``/``eraseDisplay`` helpers. If
            `None`, those helpers do nothing or just print a message.
        debug : `bool`, optional
            If True, make extra diagnostic plots and a debug text box.
        savePlotAs : `str`, optional
            If set, the summary figure made by `plot` is saved to this path.
        **kwargs
            Forwarded to ``super().__init__``.
        """
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box) -> list[list[tuple[int, int]]]:
        """Convert a box into the four line segments that outline it.

        The result can be drawn with::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : bounding box
            Any object with ``beginX``, ``endX``, ``beginY`` and ``endY``.

        Returns
        -------
        lines : `list` [`list` [`tuple` [`int`, `int`]]]
            The ``y=beginY``, ``x=beginX``, ``x=endX`` and ``y=endY`` edges,
            each as a pair of ``(x, y)`` points.
        """
        x0, x1 = box.beginX, box.endX
        y0, y1 = box.beginY, box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self) -> None:
        """Erase the display, if one was supplied."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self) -> None:
        """Outline the spectrum bounding box in red on the display."""
        if self.display:
            for line in self.bboxToAwfDisplayLines(self.spectrumbbox):
                self.display.line(line, ctype="red")
        else:
            print("No display set")

    def displayStarLocation(self) -> None:
        """Mark the brightest object's centroid with an x and an o."""
        if self.display:
            self.display.dot("x", *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot("o", *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold: int = 5, windowSize: int = 5) -> tuple[int, int]:
        """Find the section of the spectrum where the ridgeline is stable.

        The ridgeline positions are split into consecutive windows of
        ``windowSize`` rows, and the standard deviation of the position is
        computed in each window. Windows with scatter below ``threshold`` are
        good. The section starts two windows before the third good window and
        ends three windows after the third-from-last good window. It is
        clamped to the available rows.

        Parameters
        ----------
        threshold : `int`, optional
            Maximum ridgeline scatter (pixels) for a window to count as good.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `tuple` [`int`, `int`]
            Start (inclusive) and end (exclusive) row of the good section.

        Raises
        ------
        IndexError
            Raised if fewer than three windows are good.
        """
        length = len(self.ridgeLineLocations)
        # Only non-empty windows. A trailing empty window (when length is a
        # multiple of windowSize) would give a NaN std plus a RuntimeWarning,
        # and it could never be "good" anyway.
        windowStarts = np.arange(0, length, windowSize)
        stddevs = np.array(
            [np.std(self.ridgeLineLocations[start : start + windowSize]) for start in windowStarts]
        )

        goodPoints = np.where(stddevs < threshold)[0]
        minPoint = max((goodPoints[2] - 2) * windowSize, 0)
        maxPoint = min((goodPoints[-3] + 3) * windowSize, length)

        if self.debug:
            plt.plot(windowStarts, stddevs)
            plt.hlines(threshold, 0, length, colors="r", ls="dashed")
            plt.vlines(minPoint, 0, np.max(stddevs) + 10, colors="k", ls="dashed")
            plt.vlines(maxPoint, 0, np.max(stddevs) + 10, colors="k", ls="dashed")
            plt.title(f"Ridgeline scatter, windowSize={windowSize}")

        return (minPoint, maxPoint)

    def fit(self) -> None:
        """Fit a Gaussian across every row of the good spectrum section.

        Each row is fitted with ``a * exp(-(x - x0)**2 / (2 * sigma**2))``.
        The fit starts from the ridgeline peak of that same row and a width
        of 7 pixels.

        Sets ``self.parameters``, an ``(nRows, 3)`` array of
        ``(amplitude, centre, sigma)``, one row per row of
        ``self.spectrumData[self.goodSlice]``. Rows are NaN where:

        - the fit fails;
        - the data are not finite;
        - any fitted value has magnitude ``>= 1e7`` (a runaway fit).

        Amplitude and sigma are stored as absolute values.
        """

        def gauss(x, a, x0, sigma):
            return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

        data = self.spectrumData[self.goodSlice]
        # The ridgeline was computed for every row of spectrumData, so apply
        # the same slice to keep the initial guesses aligned with `data`.
        peakPositions = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # don't subtract the row median or even a percentile - seems bad
        # fitting a const also seems bad - needs some better thought

        parameters = np.full((nRows, 3), np.nan)
        xs = np.arange(nCols)
        initialWidth = 7
        for rowNum, (row, peakPos) in enumerate(zip(data, peakPositions)):
            try:
                pars, _ = curve_fit(gauss, xs, row, [row[peakPos], peakPos, initialWidth], maxfev=100)
            except (RuntimeError, ValueError):
                # RuntimeError: no convergence; ValueError: non-finite data
                continue
            # Compare magnitudes: amplitude and sigma are abs()'d below, so a
            # huge negative value would otherwise slip through as huge.
            if np.all(np.abs(pars) < 1e7):
                parameters[rowNum] = pars

        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    @staticmethod
    def _addTextBox(ax, text: str) -> None:
        """Draw ``text`` in a monospace box at the centre of ``ax`` and hide the axes."""
        textBox = AnchoredText(
            text,
            loc="center",
            pad=0.5,
            prop=dict(size=10.5, ma="left", backgroundcolor="white", color="black", family="monospace"),
        )
        ax.add_artist(textBox)
        ax.axis("off")

    def plot(self) -> None:
        """Make the summary figure for the examined spectrum.

        Panels (on a 4x4 grid):

        - the good section of the spectrum image, with a colorbar;
        - a histogram of all spectrum pixels, with under/overflow counts;
        - the raw ridgeline peak values and the fitted amplitudes;
        - the fitted FWHM along the spectrum, with the median and best FWHM
          over the stable region;
        - the row sums;
        - a statistics text box (plus a debug text box if ``self.debug``).

        The figure is saved to ``self.savePlotAs`` if set, then shown.
        """
        fig = plt.figure(figsize=(10, 10))

        # spectrum
        ax0 = plt.subplot2grid((4, 4), (0, 0), colspan=3)
        ax0.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False)
        goodData = self.spectrumData[self.goodSlice].T
        vmin = np.percentile(goodData, 1)
        vmax = np.percentile(goodData, 99)
        pos = ax0.imshow(goodData, vmin=vmin, vmax=vmax, origin="lower")
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # spectrum histogram
        axHist = plt.subplot2grid((4, 4), (0, 3))
        data = self.spectrumData
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        # boolean indexing already returns a flat array
        axHist.hist(data[(data >= histMin) & (data <= histMax)], bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale("log", nonpositive="clip")
        axHist.set_title("Spectrum pixel histogram")
        text = f"Underflow = {underflow}"
        text += f"\nOverflow = {overflow}"
        anchored_text = AnchoredText(text, loc=1, pad=0.5)
        axHist.add_artist(anchored_text)

        # peak fluxes
        ax1 = plt.subplot2grid((4, 4), (1, 0), colspan=3)
        ax1.plot(self.ridgeLineValues[self.goodSlice], label="Raw peak value")
        ax1.plot(self.parameters[:, 0], label="Fitted amplitude")
        ax1.axhline(self.continuumFlux98, ls="dashed", color="g")
        ax1.set_ylabel("Peak amplitude (ADU)")
        ax1.set_xlabel("Spectrum position (pixels)")
        ax1.legend(
            title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
            loc="center right",
            framealpha=0.2,
            facecolor="black",
        )
        ax1.set_title("Ridgeline plot")

        # FWHM
        ax2 = plt.subplot2grid((4, 4), (2, 0), colspan=3)
        fwhmValues = self.parameters[:, 2] * self._SIGMA_TO_FWHM
        amplitudes = self.parameters[:, 0]
        ax2.plot(fwhmValues, label="FWHM (pix)")
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        ax2.axhline(medianFwhm, ls="dashed", color="k", label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls="dashed", color="r", label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls="dashed", color="k", alpha=0.2)
        ax2.axvline(maxVal, ls="dashed", color="k", alpha=0.2)
        finiteFwhm = fwhmValues[np.isfinite(fwhmValues)]
        ymin = max(finiteFwhm.min() - 5, 0) if finiteFwhm.size else 0
        ymax = medianFwhm * 2 if not np.isnan(medianFwhm) else 5 * ymin
        # NaN or identical limits would make matplotlib raise or warn, so
        # keep autoscaling in that case.
        if ymax > ymin:
            ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel("FWHM (pixels)")
        ax2.set_xlabel("Spectrum position (pixels)")
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title("Spectrum FWHM")

        # row fluxes
        ax3 = plt.subplot2grid((4, 4), (3, 0), colspan=3)
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel("Total row flux (ADU)")
        ax3.set_xlabel("Spectrum position (pixels)")
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title("Row sums")

        # stats textbox, occupying grid rows 1-2 of the right-hand column
        axStats = plt.subplot2grid((4, 4), (1, 3), rowspan=2)
        statsText = "".join(self.generateStatsTextboxContent(section) for section in range(4))
        self._addTextBox(axStats, statsText)

        # debug textbox in the otherwise-empty bottom-right cell, so it does
        # not overlap the stats box above
        if self.debug:
            axDebug = plt.subplot2grid((4, 4), (3, 3))
            self._addTextBox(axDebug, self.generateStatsTextboxContent(-1))

        plt.tight_layout()
        # Save before show(): with some backends show() blocks and then
        # destroys the figure, so a later savefig writes a blank image.
        if self.savePlotAs:
            fig.savefig(self.savePlotAs)
        plt.show()

    def init(self):
        """Hook for extra initialisation, called at the end of ``__init__``.

        Does nothing here.
        """
        pass

    def generateStatsTextboxContent(self, section: int) -> str:
        """Return the text for one section of the statistics text box.

        Parameters
        ----------
        section : `int`
            The section to generate:

            - ``0``: star stats;
            - ``1``: image stats;
            - ``2``: rate stats;
            - ``3``: observation info;
            - ``-1``: debug info.

            Sections 0-2 end with a blank line so that sections can be
            concatenated directly.

        Returns
        -------
        text : `str`
            The newline-joined lines of the section, or ``""`` for any other
            section number.
        """
        x, y = self.qfmResult.brightestObjCentroid

        vi = self.exp.visitInfo
        exptime = vi.exposureTime

        # The physical label is split on FILTER_DELIMITER into filter and
        # grating. Without a grating part, report "unknown" rather than
        # raising IndexError mid-plot.
        labelParts = self.exp.filter.physicalLabel.split(FILTER_DELIMITER)
        filt = labelParts[0]
        grating = labelParts[1] if len(labelParts) > 1 else "unknown"

        airmass = vi.getBoresightAirmass()
        rotangle = vi.getBoresightRotAngle().asDegrees()

        azAlt = vi.getBoresightAzAlt()
        az = azAlt[0].asDegrees()
        el = azAlt[1].asDegrees()

        obj = vi.object

        sectionBreak = ["", ""]

        if section == 0:
            lines = [
                "----- Star stats -----",
                f"Star centroid @ {x:.0f}, {y:.0f}",
                f"Star max pixel = {self.starPeakFlux:,.0f} ADU",
                f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU",
                *sectionBreak,
            ]
        elif section == 1:
            imageMedian = np.median(self.exp.image.array)
            lines = [
                "------ Image stats ---------",
                f"Image median = {imageMedian:.2f} ADU",
                f"Exposure time = {exptime:.2f} s",
                *sectionBreak,
            ]
        elif section == 2:
            lines = [
                "------- Rate stats ---------",
                f"Star max pixel = {self.starPeakFlux/exptime:,.0f} ADU/s",
                f"Spectrum continuum = {self.continuumFlux98/exptime:,.1f} ADU/s",
                *sectionBreak,
            ]
        elif section == 3:
            lines = [
                "----- Observation info -----",
                f"object = {obj}",
                f"filter = {filt}",
                f"grating = {grating}",
                f"rotpa = {rotangle:.1f}",
                f"az = {az:.1f}",
                f"el = {el:.1f}",
                f"airmass = {airmass:.3f}",
            ]
        elif section == -1:  # special -1 for debug
            lines = [
                "---------- Debug -----------",
                f"spectrum bbox: {self.spectrumbbox}",
                f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}",
            ]
        else:
            return ""

        return "\n".join(lines)

    def run(self) -> None:
        """Examine the spectrum and make the summary plot.

        Steps:

        1. Run `QuickFrameMeasurementTask` to find the brightest object, and
           record its rounded centroid and the pixel value there.
        2. Cut out the spectrum region using
           ``ProcessStarTask.calcSpectrumBBox``.
        3. Trace the ridgeline, i.e. the column of the maximum pixel in each
           row of the cutout. Keep those peak values and the row sums.
        4. Find the section with a stable ridgeline
           (`calcGoodSpectrumSection`).
        5. Estimate the continuum flux as the 90th (emission stars) and 98th
           (most stars) percentiles of the ridgeline values.
        6. Fit a Gaussian across each good row (`fit`) and plot (`plot`).
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        centroid = self.qfmResult.brightestObjCentroid
        roundedCentroid = np.round(centroid)
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        # NOTE(review): meaning of the third argument (200) is defined by
        # ProcessStarTask.calcSpectrumBBox - confirm there.
        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(self.exp, centroid, 200)
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        rowNumbers = np.arange(self.spectrumbbox.getHeight())
        self.ridgeLineValues = self.spectrumData[rowNumbers, self.ridgeLineLocations]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        self.goodSpectrumMinY, self.goodSpectrumMaxY = self.calcGoodSpectrumSection()
        self.goodSlice = slice(self.goodSpectrumMinY, self.goodSpectrumMaxY)

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues: np.ndarray, minIndex: int, maxIndex: int) -> tuple[float, float]:
        """Return the sigma-clipped median and "best" FWHM over a range.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            FWHM values; may contain NaNs.
        minIndex, maxIndex : `int`
            Slice bounds of the region to use.

        Returns
        -------
        medianFwhm : `float`
            NaN-ignoring median of the clipped values.
        bestFocusFwhm : `float`
            NaN-ignoring 2nd percentile of the clipped values.
        """
        with warnings.catch_warnings():  # to suppress nan warnings, which are fine
            warnings.simplefilter("ignore")
            clipped = sigma_clip(fwhmValues[minIndex:maxIndex])
            # sigma_clip returns a masked array. np.asarray would silently
            # drop the mask and bring the clipped points back, so turn the
            # masked entries into NaN for the nan-aware reductions instead.
            clippedValues = np.ma.asarray(clipped, dtype=float).filled(np.nan)
            medianFwhm = np.nanmedian(clippedValues)
            bestFocusFwhm = np.nanpercentile(clippedValues, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(
        self, fwhmValues: np.ndarray, amplitudes: np.ndarray, smoothing: int = 1, maxDifferential: int = 4
    ) -> tuple[int, int]:
        """Find the range of positions over which the fitted FWHM is usable.

        The procedure is:

        - Smooth the FWHM with a boxcar of width ``smoothing``.
        - Keep the positions where the smoothed FWHM is finite and changes,
          and take the gaps between consecutive kept positions.
        - Group consecutive equal gaps into runs. A run of 1s is a
          contiguous stretch of usable positions.
        - Try the runs longest first, and take the first whose midpoint
          amplitude is above the 75th percentile (to exclude the 2nd order).
          If none qualifies, use the longest run.
        - Walk outwards from both ends of that run over neighbouring runs
          whose gap is no larger than ``maxDifferential``.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            FWHM per position; may contain NaNs.
        amplitudes : `numpy.ndarray`
            Fitted amplitude per position, same length as ``fwhmValues``.
        smoothing : `int`, optional
            Boxcar width for smoothing the FWHM.
        maxDifferential : `int`, optional
            Largest gap (in positions) that may be walked over when
            extending the chosen run.

        Returns
        -------
        startValue, endValue : `tuple` [`int`, `int`]
            Start (inclusive) and end (exclusive) positions in
            ``fwhmValues``. ``(0, len(fwhmValues))`` if there are fewer
            than two usable positions.
        """
        nValues = len(fwhmValues)
        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing) / smoothing, mode="same")
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        # Positions where the smoothed FWHM changes. NaN differences compare
        # False and are dropped.
        indices = np.where(1 - np.abs(diff) < 1)[0]
        # diffIndices[k] is the gap between positions indices[k] and indices[k + 1]
        diffIndices = np.diff(indices)

        # [list(g) for k, g in groupby('AAAABBBCCD')] -->[['A', 'A', 'A', 'A'],
        # ... ['B', 'B', 'B'], ['C', 'C'], ['D']]
        indexLists = [list(g) for _, g in groupby(diffIndices)]
        if not indexLists:
            # fewer than two usable positions: nothing to choose between
            return 0, nValues
        listLengths = [len(lst) for lst in indexLists]
        # index into diffIndices at which each run begins
        runStarts = np.cumsum([0] + listLengths[:-1])

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)
        # Longest first. sorted() is stable, so equal-length runs keep their
        # order and each of them gets tried.
        candidates = sorted(range(len(indexLists)), key=lambda i: listLengths[i], reverse=True)
        chosen = candidates[0]
        for runIndex in candidates:
            midPosition = indices[runStarts[runIndex] + listLengths[runIndex] // 2]
            if amplitudes[midPosition] > amplitudeThreshold:
                chosen = runIndex
                break

        # Walk in gap-index space (indices into diffIndices) and only convert
        # to positions in fwhmValues via `indices` at the end. Unusable
        # (e.g. NaN) positions therefore do not shift the returned range.
        runStartK = int(runStarts[chosen])
        runEndK = runStartK + listLengths[chosen]

        endK = runEndK
        for lst in indexLists[chosen + 1 :]:
            if lst[0] > maxDifferential:
                break
            endK += len(lst)

        startK = runStartK
        for lst in reversed(indexLists[:chosen]):
            if lst[0] > maxDifferential:
                break
            startK -= len(lst)

        startValue = int(indices[startK])
        endValue = int(indices[endK])

        if not self.debug:
            return startValue, endValue

        startOfLongList = int(indices[runStartK])
        endOfLongList = int(indices[runEndK])
        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, nValues)

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], "r", ls="--")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        # plot each gap at the position where it starts, to share the x-axis above
        plt.plot(indices[:-1], diffIndices)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue