Coverage for python / lsst / summit / utils / spectrumExaminer.py: 10%

310 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:32 +0000

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names exported by a star-import of this module.
__all__ = ["SpectrumExaminer"]

23 

24import warnings 

25from itertools import groupby 

26from typing import Any 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30import numpy.typing as npt 

31from astropy.stats import sigma_clip 

32from matplotlib.offsetbox import AnchoredText 

33from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

34from scipy.optimize import curve_fit 

35 

36import lsst.afw.display as afwDisplay 

37import lsst.afw.image as afwImage 

38from lsst.atmospec.processStar import ProcessStarTask 

39from lsst.geom import Box2I 

40from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER 

41from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig 

42from lsst.summit.utils.utils import getImageStats 

43from lsst.utils.plotting.figures import make_figure 

44 

45 

class SpectrumExaminer:
    """Quick-look spectral extraction and diagnostics for single-star
    dispersed images.

    The spectrum region is located relative to the brightest source in the
    exposure. A Gaussian is fitted across each row of the well-behaved part
    of the spectrum, and a multi-panel diagnostic figure is produced. See
    `run` for the sequence of steps.

    Parameters
    ----------
    exp : `lsst.afw.image.Exposure`
        The dispersed exposure to examine.
    display : `lsst.afw.display.Display`, optional
        Display used by the ``display*``/``eraseDisplay`` helpers. If `None`,
        the drawing helpers print a message instead.
    debug : `bool`, optional
        If `True`, produce extra diagnostic plots and a debug text box.
    savePlotAs : `str`, optional
        If set, `plot` saves the diagnostic figure to this path.
    **kwargs
        Forwarded to ``super().__init__()``. The class derives directly from
        `object`, so passing anything here raises `TypeError`.
    """

    def __init__(
        self,
        exp: afwImage.Exposure,
        display: afwDisplay.Display | None = None,
        debug: bool = False,
        savePlotAs: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs
        # The figure is created up front and populated later by plot().
        self.fig = make_figure(figsize=(10, 10))

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box: Box2I) -> list[list[tuple[int, int]]]:
        """Convert a bounding box into the four line segments outlining it.

        The result can be drawn with::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : `lsst.geom.Box2I`
            The box to outline.

        Returns
        -------
        lines : `list` [`list` [`tuple` [`int`, `int`]]]
            The edges at ``y=beginY``, ``x=beginX``, ``x=endX`` and
            ``y=endY``. Each edge is a pair of ``(x, y)`` end points.
        """
        x0, x1 = box.beginX, box.endX
        y0, y1 = box.beginY, box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self) -> None:
        """Erase the attached display, if there is one."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self) -> None:
        """Outline the spectrum bounding box in red on the display.

        Requires `run` to have set ``self.spectrumbbox``.
        """
        if self.display:
            lines = self.bboxToAwfDisplayLines(self.spectrumbbox)
            for line in lines:
                self.display.line(line, ctype="red")
        else:
            print("No display set")

    def displayStarLocation(self) -> None:
        """Mark the measured star centroid on the display with an x and an o.

        Requires `run` to have set ``self.qfmResult``.
        """
        if self.display:
            self.display.dot("x", *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot("o", *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold: int = 5, windowSize: int = 5) -> tuple[int, int]:
        """Find the row range over which the ridgeline is well behaved.

        The ridgeline column positions are split into consecutive windows of
        ``windowSize`` rows, and the standard deviation of each window is
        computed. Windows whose standard deviation is below ``threshold``
        are good.

        The returned range starts two windows before the third good window.
        It ends two windows after the end of the third-from-last good
        window. Both ends are clipped to the spectrum length.

        Parameters
        ----------
        threshold : `int`, optional
            Maximum per-window scatter (in pixels) for a window to be good.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `int`
            Start (inclusive) and end (exclusive) row of the good section.

        Raises
        ------
        IndexError
            Raised if fewer than three windows pass the threshold.
        """
        length = len(self.ridgeLineLocations)
        nChunks = length // windowSize
        # nChunks + 1 windows. The last one holds the remainder rows; it is
        # empty (and gives NaN, which never passes) when length is an exact
        # multiple of windowSize.
        stddevs = [
            np.std(self.ridgeLineLocations[i * windowSize : (i + 1) * windowSize]) for i in range(nChunks + 1)
        ]

        goodPoints = np.where(np.asarray(stddevs) < threshold)[0]
        minPoint = max((goodPoints[2] - 2) * windowSize, 0)
        maxPoint = min((goodPoints[-3] + 3) * windowSize, length)

        if self.debug:
            plt.plot(range(0, length + 1, windowSize), stddevs)
            plt.hlines(threshold, 0, length, colors="r", ls="dashed")
            plt.vlines(minPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.vlines(maxPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.title(f"Ridgeline scatter, windowSize={windowSize}")

        return (minPoint, maxPoint)

    def fit(self) -> None:
        """Fit a Gaussian across each row of the good spectrum section.

        Each row of ``self.spectrumData[self.goodSlice]`` is fitted with
        ``a * exp(-(x - x0)**2 / (2 * sigma**2))`` using
        `scipy.optimize.curve_fit`. The fit starts from that row's ridgeline
        column and value, with an initial width of 7 pixels.

        A row gets NaN for all three parameters if its fit fails to converge
        within 100 evaluations. The same applies if any fitted parameter has
        magnitude ``>= 1e7``.

        The result is stored in ``self.parameters``, an ``(nRows, 3)`` array
        of ``(amplitude, centre, sigma)``. Amplitude and sigma are made
        non-negative.
        """

        def gauss(
            x: float | npt.NDArray[np.float64], a: float, x0: float, sigma: float
        ) -> float | npt.NDArray[np.float64]:
            return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

        data = self.spectrumData[self.goodSlice]
        # Must be sliced the same way as ``data`` so that row i of each refers
        # to the same image row.
        ridgeLocations = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # Don't subtract the row median or even a percentile; that seems bad.
        # Fitting a constant also seems bad. This needs some better thought.

        parameters = np.zeros((nRows, 3))
        xs = np.arange(nCols)
        initialWidth = 7
        for rowNum, (row, peakPos) in enumerate(zip(data, ridgeLocations)):
            amplitude = row[peakPos]
            try:
                pars, _ = curve_fit(gauss, xs, row, [amplitude, peakPos, initialWidth], maxfev=100)
            except RuntimeError:  # raised when maxfev is reached without converging
                pars = np.full(3, np.nan)
            # Reject absurd fits in either sign. NaN comparisons are False, so
            # NaNs stay NaN.
            if not np.all(np.abs(pars) < 1e7):
                pars = np.full(3, np.nan)
            parameters[rowNum] = pars

        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    @staticmethod
    def _makeStatsText(text: str) -> AnchoredText:
        """Build the centred, monospace, white-backed text box used for stats."""
        return AnchoredText(
            text,
            loc="center",
            pad=0.5,
            prop=dict(size=10.5, ma="left", backgroundcolor="white", color="black", family="monospace"),
        )

    def plot(self) -> None:
        """Populate ``self.fig`` with the diagnostic panels.

        Requires `run` to have computed the spectrum data and `fit` to have
        set ``self.parameters``. The figure is a 4x4 grid containing:

        - the spectrum image;
        - a pixel histogram;
        - the ridgeline amplitudes;
        - the fitted FWHM;
        - the row sums;
        - a summary-statistics text box, plus a debug text box when
          ``self.debug`` is set.

        If ``self.savePlotAs`` is set, the figure is saved to that path.
        """
        gs = self.fig.add_gridspec(4, 4)

        # Spectrum image, transposed so that spectrum position (rows) runs
        # along x.
        ax0 = self.fig.add_subplot(gs[0, 0:3])
        ax0.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False)
        d = self.spectrumData[self.goodSlice].T
        vmin = np.percentile(d, 1)
        vmax = np.percentile(d, 99)
        pos = ax0.imshow(d, vmin=vmin, vmax=vmax, origin="lower")
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        self.fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # Histogram of all spectrum-box pixels, reporting out-of-range counts.
        axHist = self.fig.add_subplot(gs[0, 3])
        data = self.spectrumData
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        axHist.hist(data[(data >= histMin) & (data <= histMax)], bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale("log", nonpositive="clip")
        axHist.set_title("Spectrum pixel histogram")
        text = f"Underflow = {underflow}\nOverflow = {overflow}"
        axHist.add_artist(AnchoredText(text, loc="upper right", pad=0.5))

        # Raw ridgeline peak values vs fitted Gaussian amplitudes.
        ax1 = self.fig.add_subplot(gs[1, 0:3])
        ax1.plot(self.ridgeLineValues[self.goodSlice], label="Raw peak value")
        ax1.plot(self.parameters[:, 0], label="Fitted amplitude")
        ax1.axhline(self.continuumFlux98, ls="dashed", color="g")
        ax1.set_ylabel("Peak amplitude (ADU)")
        ax1.set_xlabel("Spectrum position (pixels)")
        ax1.legend(
            title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
            loc="center right",
            framealpha=0.2,
            facecolor="black",
        )
        ax1.set_title("Ridgeline plot")

        # FWHM. 2.355 ~= 2*sqrt(2 ln 2) converts Gaussian sigma to FWHM.
        ax2 = self.fig.add_subplot(gs[2, 0:3])
        fwhmValues = self.parameters[:, 2] * 2.355
        amplitudes = self.parameters[:, 0]
        ax2.plot(fwhmValues, label="FWHM (pix)")
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        ax2.axhline(medianFwhm, ls="dashed", color="k", label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls="dashed", color="r", label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls="dashed", color="k", alpha=0.2)
        ax2.axvline(maxVal, ls="dashed", color="k", alpha=0.2)
        ymin = max(np.nanmin(fwhmValues) - 5, 0)
        ymax = medianFwhm * 2 if not np.isnan(medianFwhm) else 5 * ymin
        # Only apply limits matplotlib can use. When all fits are NaN, or the
        # fallback gives an empty range, leave the axis autoscaled. Otherwise
        # set_ylim would raise or warn.
        if np.isfinite(ymin) and np.isfinite(ymax) and ymax > ymin:
            ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel("FWHM (pixels)")
        ax2.set_xlabel("Spectrum position (pixels)")
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title("Spectrum FWHM")

        # Row sums.
        ax3 = self.fig.add_subplot(gs[3, 0:3])
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel("Total row flux (ADU)")
        ax3.set_xlabel("Spectrum position (pixels)")
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title("Row sums")

        # Summary-statistics text box, spanning grid rows 1-2 of the last column.
        ax4 = self.fig.add_subplot(gs[1:3, 3])
        statsText = "".join(self.generateStatsTextboxContent(section) for section in range(4))
        ax4.add_artist(self._makeStatsText(statsText))
        ax4.axis("off")

        # Debug text box. It goes in the free bottom-right cell, because the
        # stats box above already occupies grid rows 1-2 of this column.
        if self.debug:
            ax5 = self.fig.add_subplot(gs[3, 3])
            ax5.add_artist(self._makeStatsText(self.generateStatsTextboxContent(-1)))
            ax5.axis("off")

        self.fig.tight_layout()

        if self.savePlotAs:
            self.fig.savefig(self.savePlotAs)

    def init(self) -> None:
        """Hook called at the end of ``__init__``; currently does nothing."""
        pass

    def generateStatsTextboxContent(self, section: int) -> str:
        """Build the text for one section of the statistics text box.

        Parameters
        ----------
        section : `int`
            The section to build:

            - ``0``: star stats;
            - ``1``: image stats;
            - ``2``: count-rate stats;
            - ``3``: observation info;
            - ``-1``: debug info.

        Returns
        -------
        text : `str`
            Newline-joined lines for the section. Sections 0-2 end with a
            blank-line separator so they can be concatenated directly. Any
            other ``section`` value gives an empty string.
        """
        vi = self.exp.visitInfo
        sectionBreak = ["", ""]

        if section == 0:
            x, y = self.qfmResult.brightestObjCentroid
            lines = [
                "----- Star stats -----",
                f"Star centroid @ {x:.0f}, {y:.0f}",
                f"Star max pixel = {self.starPeakFlux:,.0f} ADU",
                f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU",
                *sectionBreak,
            ]
        elif section == 1:
            imageMedian = np.median(self.exp.image.array)
            lines = [
                "------ Image stats ---------",
                f"Image median = {imageMedian:.2f} ADU",
                f"Exposure time = {vi.exposureTime:.2f} s",
                *sectionBreak,
            ]
        elif section == 2:
            exptime = vi.exposureTime
            lines = [
                "------- Rate stats ---------",
                f"Star max pixel = {self.starPeakFlux / exptime:,.0f} ADU/s",
                f"Spectrum continuum = {self.continuumFlux98 / exptime:,.1f} ADU/s",
                *sectionBreak,
            ]
        elif section == 3:
            # The physical label is split on FILTER_DELIMITER into filter and
            # grating parts. Tolerate a label with no grating part instead of
            # raising IndexError.
            labelParts = self.exp.filter.physicalLabel.split(FILTER_DELIMITER)
            filt = labelParts[0]
            grating = labelParts[1] if len(labelParts) > 1 else ""
            azAlt = vi.getBoresightAzAlt()
            lines = [
                "----- Observation info -----",
                f"object = {vi.object}",
                f"filter = {filt}",
                f"grating = {grating}",
                f"rotpa = {vi.getBoresightRotAngle().asDegrees():.1f}",
                f"az = {azAlt[0].asDegrees():.1f}",
                f"el = {azAlt[1].asDegrees():.1f}",
                f"airmass = {vi.getBoresightAirmass():.3f}",
            ]
        elif section == -1:  # special -1 for debug
            lines = [
                "---------- Debug -----------",
                f"spectrum bbox: {self.spectrumbbox}",
                f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}",
            ]
        else:
            return ""

        return "\n".join(lines)

    def run(self) -> None:
        """Measure the star, extract the spectrum and make the diagnostic plot.

        Steps:

        1. Run `QuickFrameMeasurementTask` on the exposure. Record the
           rounded centroid of the brightest object and the pixel value there.
        2. Get the spectrum bounding box from
           ``ProcessStarTask.calcSpectrumBBox`` and cut out the spectrum
           image.
        3. Compute the ridgeline: the per-row argmax column, the values at
           those columns, and the per-row sums.
        4. Restrict to the well-behaved rows with `calcGoodSpectrumSection`.
        5. Estimate the continuum flux as the 90th and 98th percentiles of
           the ridgeline values.
        6. Fit each good row with a Gaussian (`fit`) and draw the figure
           (`plot`).
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        centroid = self.qfmResult.brightestObjCentroid
        roundedCentroid = np.round(centroid)
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        # Image arrays are indexed [y, x].
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        # NOTE(review): the meaning of the 200 argument is defined by
        # ProcessStarTask.calcSpectrumBBox. Confirm against its documentation.
        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(self.exp, centroid, 200)
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        rowIndices = np.arange(self.spectrumbbox.getHeight())
        self.ridgeLineValues = self.spectrumData[rowIndices, self.ridgeLineLocations]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        self.goodSpectrumMinY, self.goodSpectrumMaxY = self.calcGoodSpectrumSection()
        self.goodSlice = slice(self.goodSpectrumMinY, self.goodSpectrumMaxY)

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues: np.ndarray, minIndex: int, maxIndex: int) -> tuple[float, float]:
        """Sigma-clipped median and 2nd-percentile FWHM over a range of rows.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            Per-row FWHM values. May contain NaNs.
        minIndex, maxIndex : `int`
            Slice bounds (``maxIndex`` exclusive) of the rows to use.

        Returns
        -------
        medianFwhm : `float`
            Median of the sigma-clipped values. NaN if none remain.
        bestFocusFwhm : `float`
            2nd percentile of the sigma-clipped values. NaN if none remain.
        """
        with warnings.catch_warnings():  # to suppress NaN/empty-slice warnings, which are fine
            warnings.simplefilter("ignore")
            clipped = sigma_clip(fwhmValues[minIndex:maxIndex])
            # sigma_clip returns a MaskedArray. np.asarray() would silently
            # drop the mask and undo the clipping. Instead, fill the clipped
            # and invalid entries with NaN so the nan-aware reducers skip them.
            clippedValues = np.ma.filled(clipped.astype(float), np.nan)
            medianFwhm = np.nanmedian(clippedValues)
            bestFocusFwhm = np.nanpercentile(clippedValues, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(
        self, fwhmValues: np.ndarray, amplitudes: np.ndarray, smoothing: int = 1, maxDifferential: int = 4
    ) -> tuple[int, int]:
        """Find the range of spectrum rows over which the fitted FWHM is usable.

        Algorithm:

        1. Smooth ``fwhmValues`` with a boxcar of width ``smoothing`` and
           take the forward difference.
        2. Keep the indices where ``1 - |diff| < 1``, and take the gaps
           between consecutive kept indices.
        3. Group equal consecutive gaps into runs. Try the runs longest first
           (ties in left-to-right order). Choose the first run whose middle
           position has an amplitude above the 75th percentile of
           ``amplitudes``; this is meant to exclude the second order. If no
           run qualifies, use the longest run.
        4. Extend the chosen run outwards over neighbouring runs whose gap is
           ``<= maxDifferential``.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            Per-row FWHM values.
        amplitudes : `numpy.ndarray`
            Per-row fitted amplitudes, the same length as ``fwhmValues``.
        smoothing : `int`, optional
            Boxcar width used in step 1.
        maxDifferential : `int`, optional
            Largest gap that step 4 is allowed to walk over.

        Returns
        -------
        startValue, endValue : `int`
            Start (inclusive) and end (exclusive) of the region, clipped to
            ``[0, len(fwhmValues)]``. An empty input gives ``(0, 0)``.

        Notes
        -----
        If ``self.debug`` is set, two pyplot figures are also shown.

        NOTE(review): run positions are counted in the index space of the gap
        array from step 2, not of ``fwhmValues``. The two coincide only when
        step 2 keeps every index; NaN FWHMs are dropped there and shift them.
        Confirm whether the positions should be mapped back through
        ``indices``.
        """
        if len(fwhmValues) == 0:
            return 0, 0

        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing) / smoothing, mode="same")
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        # NOTE(review): ``1 - |diff| < 1`` is effectively ``|diff| > 0`` (up
        # to float rounding). It keeps every point whose smoothed FWHM changed
        # by a finite non-zero amount, and drops NaNs. Kept as-is to preserve
        # the current selection.
        indices = np.where(1 - np.abs(diff) < 1)[0]
        diffIndices = np.diff(indices)

        # [list(g) for k, g in groupby('AAAABBBCCD')] -->
        # [['A', 'A', 'A', 'A'], ['B', 'B', 'B'], ['C', 'C'], ['D']]
        indexLists = [list(g) for _, g in groupby(diffIndices)]
        listLengths = [len(lst) for lst in indexLists]
        # listStarts[i] is the position of run i; listStarts[0] == 0 even if
        # there are no runs.
        listStarts = np.cumsum([0, *listLengths])

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)

        # sorted() is stable, so runs with equal lengths keep their
        # left-to-right order. Unlike list.index(), this visits every tied
        # run rather than re-testing the first one.
        candidates = sorted(range(len(listLengths)), key=listLengths.__getitem__, reverse=True)
        chosenIndex = candidates[0] if candidates else 0  # fallback: the longest run
        for runIndex in candidates:
            midRunPosition = int(listStarts[runIndex]) + listLengths[runIndex] // 2  # mid-run value
            if amplitudes[midRunPosition] > amplitudeThreshold:
                chosenIndex = runIndex
                break

        longestListLength = listLengths[chosenIndex] if listLengths else 0
        startOfLongList = int(listStarts[chosenIndex])
        endOfLongList = startOfLongList + longestListLength

        # Walk outwards from each end over runs with small gaps.
        endValue = endOfLongList
        for lst in indexLists[chosenIndex + 1 :]:
            if lst[0] > maxDifferential:
                break
            endValue += len(lst)

        startValue = startOfLongList
        for lst in reversed(indexLists[:chosenIndex]):
            if lst[0] > maxDifferential:
                break
            startValue -= len(lst)

        startValue = int(max(0, startValue))
        endValue = int(min(len(fwhmValues), endValue))

        if not self.debug:
            return startValue, endValue

        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, len(fwhmValues))

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], "r", ls="--")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        plt.plot(diffIndices)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue