Coverage for python / lsst / summit / utils / spectrumExaminer.py: 10%
310 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:32 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:32 +0000
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module: only the SpectrumExaminer class is exported.
__all__ = ["SpectrumExaminer"]
24import warnings
25from itertools import groupby
26from typing import Any
28import matplotlib.pyplot as plt
29import numpy as np
30import numpy.typing as npt
31from astropy.stats import sigma_clip
32from matplotlib.offsetbox import AnchoredText
33from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
34from scipy.optimize import curve_fit
36import lsst.afw.display as afwDisplay
37import lsst.afw.image as afwImage
38from lsst.atmospec.processStar import ProcessStarTask
39from lsst.geom import Box2I
40from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER
41from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig
42from lsst.summit.utils.utils import getImageStats
43from lsst.utils.plotting.figures import make_figure
class SpectrumExaminer:
    """Quick-look spectral extraction for single-star dispersed images.

    The brightest source in the exposure is located with
    `QuickFrameMeasurementTask`, a bounding box around its spectrum is taken
    from `ProcessStarTask.calcSpectrumBBox`, and each row of the cleanest part
    of that box is fitted with a Gaussian to track the peak amplitude and the
    FWHM along the spectrum. Everything is drawn into one summary figure.
    See `run` for the step-by-step description.

    Parameters
    ----------
    exp : `lsst.afw.image.Exposure`
        The dispersed exposure to examine.
    display : `lsst.afw.display.Display`, optional
        Display used by `eraseDisplay`, `displaySpectrumBbox` and
        `displayStarLocation`.
    debug : `bool`, optional
        If True, extra diagnostic pyplot figures are produced and a debug
        textbox is added to the summary figure.
    savePlotAs : `str`, optional
        If set, `plot` saves the summary figure to this path.
    **kwargs
        Forwarded to ``super().__init__``. When this class is used on its own
        (its base is `object`), passing any keyword here raises `TypeError`.
    """

    def __init__(
        self,
        exp: afwImage.Exposure,
        display: afwDisplay.Display | None = None,
        debug: bool = False,
        savePlotAs: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs
        self.fig = make_figure(figsize=(10, 10))

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box: Box2I) -> list[list[tuple[int, int]]]:
        """Convert a bbox into the four edge segments of its outline.

        The result is suitable for drawing on an afw display::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : `lsst.geom.Box2I`
            The box to outline. Its ``beginX/endX/beginY/endY`` corners are
            used, so the outline follows whatever convention ``end*`` has
            (NOTE(review): believed to be one past the last pixel — confirm).

        Returns
        -------
        lines : `list` [`list` [`tuple` [`int`, `int`]]]
            Four two-point segments: bottom, left, right and top edges.
        """
        x0 = box.beginX
        x1 = box.endX
        y0 = box.beginY
        y1 = box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self) -> None:
        """Erase the display, if one was supplied; otherwise do nothing."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self) -> None:
        """Draw the spectrum bounding box in red. Requires `run` to have been called."""
        if self.display:
            for line in self.bboxToAwfDisplayLines(self.spectrumbbox):
                self.display.line(line, ctype="red")
        else:
            print("No display set")

    def displayStarLocation(self) -> None:
        """Mark the measured star centroid. Requires `run` to have been called."""
        if self.display:
            self.display.dot("x", *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot("o", *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold: int = 5, windowSize: int = 5) -> tuple[int, int]:
        """Find the row range over which the ridge line is well behaved.

        The ridge line (per-row argmax column, from `run`) is split into
        consecutive windows of ``windowSize`` rows. The standard deviation of
        the ridge-line position is computed in each window, and windows with
        scatter below ``threshold`` are considered good. The returned range
        runs from two windows before the third good window to three windows
        after the third-from-last good window, clipped to the data.

        Parameters
        ----------
        threshold : `int`, optional
            Maximum ridge-line scatter (in columns) for a window to be good.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `tuple` [`int`, `int`]
            Half-open row range ``[minPoint, maxPoint)`` into the spectrum.

        Raises
        ------
        ValueError
            If fewer than three windows are good, so no range can be formed.
        """
        length = len(self.ridgeLineLocations)
        # One window per start row; the last window may be partial. No empty
        # trailing window is created when length is a multiple of windowSize.
        windowStarts = range(0, length, windowSize)
        stddevs = [np.std(self.ridgeLineLocations[start : start + windowSize]) for start in windowStarts]

        goodPoints = np.where(np.asarray(stddevs) < threshold)[0]
        if len(goodPoints) < 3:
            raise ValueError(
                f"Could not find a clean section of the spectrum: only {len(goodPoints)} of "
                f"{len(stddevs)} ridge-line windows have a scatter below {threshold} (need at least 3)."
            )
        # Skip the first/last two good windows to avoid isolated good windows
        # at the ends, then pad outwards by a few windows.
        minPoint = int(max((goodPoints[2] - 2) * windowSize, 0))
        maxPoint = int(min((goodPoints[-3] + 3) * windowSize, length))

        if self.debug:
            plt.plot(windowStarts, stddevs)
            plt.hlines(threshold, 0, length, colors="r", ls="dashed")
            plt.vlines(minPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.vlines(maxPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.title(f"Ridgeline scatter, windowSize={windowSize}")

        return (minPoint, maxPoint)

    def fit(self) -> None:
        """Fit a Gaussian across every row of the good spectrum section.

        Sets ``self.parameters``, an ``(nRows, 3)`` array of
        ``(amplitude, centre, sigma)`` per row. Amplitude and sigma are made
        non-negative. Rows whose fit fails, or whose parameters run away
        (any ``|p| >= 1e7``), are NaN.
        """

        def gauss(
            x: float | npt.NDArray[np.float64], a: float, x0: float, sigma: float
        ) -> float | npt.NDArray[np.float64]:
            return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

        data = self.spectrumData[self.goodSlice]
        # Slice the ridge line identically, so row i of ``data`` is seeded with
        # its own peak position rather than that of row i of the full bbox.
        peakPositions = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # don't subtract the row median or even a percentile - seems bad
        # fitting a const also seems bad - needs some better thought

        parameters = np.zeros((nRows, 3))
        xs = np.arange(nCols)
        initialWidth = 7
        for rowNum, (row, peakPos) in enumerate(zip(data, peakPositions)):
            amplitude = row[peakPos]
            try:
                pars, _ = curve_fit(gauss, xs, row, p0=[amplitude, peakPos, initialWidth], maxfev=100)
            except (RuntimeError, ValueError):
                # RuntimeError: no convergence within maxfev.
                # ValueError: the row contains NaN/inf values.
                pars = [np.nan] * 3
            # Reject runaway fits in either direction; a large negative
            # amplitude/sigma would otherwise become huge after abs() below.
            if not np.all(np.abs(pars) < 1e7):
                pars = [np.nan] * 3
            parameters[rowNum] = pars

        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    def plot(self) -> None:
        """Draw the summary figure into ``self.fig`` and optionally save it.

        The layout is a 4x4 grid.

        Left three columns, one panel per row:

        - the good spectrum section as an image;
        - ridge-line peak values vs fitted amplitudes;
        - per-row FWHM, with its median and best value;
        - row sums.

        Right column:

        - row 0: a histogram of the spectrum-bbox pixels;
        - rows 1-2: the stats textbox;
        - row 3: a debug textbox when ``self.debug``.

        Requires `run` (or `fit`) to have populated the measurements.
        """
        gs = self.fig.add_gridspec(4, 4)
        goodData = self.spectrumData[self.goodSlice]

        # spectrum (transposed so that the spectrum rows run along x)
        ax0 = self.fig.add_subplot(gs[0, 0:3])
        ax0.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False)
        spectrumImage = goodData.T
        vmin = np.nanpercentile(spectrumImage, 1)
        vmax = np.nanpercentile(spectrumImage, 99)
        pos = ax0.imshow(spectrumImage, vmin=vmin, vmax=vmax, origin="lower")
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        self.fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # spectrum histogram (whole spectrum bbox, not just the good section)
        axHist = self.fig.add_subplot(gs[0, 3])
        data = self.spectrumData
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        axHist.hist(data[(data >= histMin) & (data <= histMax)], bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale("log", nonpositive="clip")
        axHist.set_title("Spectrum pixel histogram")
        anchoredText = AnchoredText(
            f"Underflow = {underflow}\nOverflow = {overflow}", loc="upper right", pad=0.5
        )
        axHist.add_artist(anchoredText)

        # peak fluxes
        ax1 = self.fig.add_subplot(gs[1, 0:3])
        ax1.plot(self.ridgeLineValues[self.goodSlice], label="Raw peak value")
        ax1.plot(self.parameters[:, 0], label="Fitted amplitude")
        ax1.axhline(self.continuumFlux98, ls="dashed", color="g")
        ax1.set_ylabel("Peak amplitude (ADU)")
        ax1.set_xlabel("Spectrum position (pixels)")
        ax1.legend(
            title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
            loc="center right",
            framealpha=0.2,
            facecolor="black",
        )
        ax1.set_title("Ridgeline plot")

        # FWHM
        fwhmValues = self.parameters[:, 2] * 2.355  # Gaussian sigma -> FWHM
        amplitudes = self.parameters[:, 0]
        ax2 = self.fig.add_subplot(gs[2, 0:3])
        ax2.plot(fwhmValues, label="FWHM (pix)")
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        ax2.axhline(medianFwhm, ls="dashed", color="k", label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls="dashed", color="r", label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls="dashed", color="k", alpha=0.2)
        ax2.axvline(maxVal, ls="dashed", color="k", alpha=0.2)
        finiteFwhm = fwhmValues[np.isfinite(fwhmValues)]
        # Only set limits when they are well defined; with every fit failed
        # (all NaN) matplotlib would reject NaN limits, so autoscale instead.
        if finiteFwhm.size:
            ymin = max(finiteFwhm.min() - 5, 0)
            ymax = medianFwhm * 2 if np.isfinite(medianFwhm) else 5 * ymin
            if ymax > ymin:
                ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel("FWHM (pixels)")
        ax2.set_xlabel("Spectrum position (pixels)")
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title("Spectrum FWHM")

        # row fluxes
        ax3 = self.fig.add_subplot(gs[3, 0:3])
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel("Total row flux (ADU)")
        ax3.set_xlabel("Spectrum position (pixels)")
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title("Row sums")

        # stats textbox, spanning grid rows 1-2 of the right-hand column
        ax4 = self.fig.add_subplot(gs[1:3, 3])
        text = "".join(self.generateStatsTextboxContent(section) for section in range(4))
        statsText = AnchoredText(
            text,
            loc="center",
            pad=0.5,
            prop=dict(size=10.5, ma="left", backgroundcolor="white", color="black", family="monospace"),
        )
        ax4.add_artist(statsText)
        ax4.axis("off")

        # debug textbox: bottom-right cell, the only one not used above (the
        # stats textbox already covers grid rows 1-2 of this column)
        if self.debug:
            ax5 = self.fig.add_subplot(gs[3, 3])
            debugText = AnchoredText(
                self.generateStatsTextboxContent(-1),
                loc="center",
                pad=0.5,
                prop=dict(size=10.5, ma="left", backgroundcolor="white", color="black", family="monospace"),
            )
            ax5.add_artist(debugText)
            ax5.axis("off")

        self.fig.tight_layout()

        if self.savePlotAs:
            self.fig.savefig(self.savePlotAs)

    def init(self) -> None:
        """Hook called at the end of ``__init__``; does nothing here."""
        pass

    def generateStatsTextboxContent(self, section: int) -> str:
        """Build the text for one section of the summary-figure textbox.

        Parameters
        ----------
        section : `int`
            Which section to build:

            - ``0``: star stats;
            - ``1``: image stats;
            - ``2``: rate stats;
            - ``3``: observation info;
            - ``-1``: debug info.

            Any other value returns an empty string.

        Returns
        -------
        text : `str`
            Newline-joined lines. Sections 0-2 end with a blank-line break so
            that consecutive sections can simply be concatenated.
        """
        if section == 0:
            x, y = self.qfmResult.brightestObjCentroid
            lines = [
                "----- Star stats -----",
                f"Star centroid @ {x:.0f}, {y:.0f}",
                f"Star max pixel = {self.starPeakFlux:,.0f} ADU",
                f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU",
                "",  # section break
                "",
            ]
        elif section == 1:
            imageMedian = np.median(self.exp.image.array)
            lines = [
                "------ Image stats ---------",
                f"Image median = {imageMedian:.2f} ADU",
                f"Exposure time = {self.exp.visitInfo.exposureTime:.2f} s",
                "",  # section break
                "",
            ]
        elif section == 2:
            exptime = self.exp.visitInfo.exposureTime
            lines = [
                "------- Rate stats ---------",
                f"Star max pixel = {self.starPeakFlux / exptime:,.0f} ADU/s",
                f"Spectrum continuum = {self.continuumFlux98 / exptime:,.1f} ADU/s",
                "",  # section break
                "",
            ]
        elif section == 3:
            vi = self.exp.visitInfo
            # physicalLabel is "<filter><FILTER_DELIMITER><grating>"; tolerate
            # a label with no grating part instead of raising IndexError.
            filterParts = self.exp.filter.physicalLabel.split(FILTER_DELIMITER)
            filt = filterParts[0]
            grating = filterParts[1] if len(filterParts) > 1 else "unknown"
            azAlt = vi.getBoresightAzAlt()
            lines = [
                "----- Observation info -----",
                f"object = {vi.object}",
                f"filter = {filt}",
                f"grating = {grating}",
                f"rotpa = {vi.getBoresightRotAngle().asDegrees():.1f}",
                f"az = {azAlt[0].asDegrees():.1f}",
                f"el = {azAlt[1].asDegrees():.1f}",
                f"airmass = {vi.getBoresightAirmass():.3f}",
            ]
        elif section == -1:  # special -1 for debug
            lines = [
                "---------- Debug -----------",
                f"spectrum bbox: {self.spectrumbbox}",
                f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}",
            ]
        else:
            return ""
        return "\n".join(lines)

    def run(self) -> None:
        """Examine the spectrum and draw the summary figure.

        Steps:

        1. Measure the brightest source with `QuickFrameMeasurementTask`, and
           record its rounded centroid and the pixel value there.
        2. Get the spectrum bbox from `ProcessStarTask.calcSpectrumBBox` and
           cut out its pixel data.
        3. For each row, record the ridge line (argmax column and its value)
           and the row sum.
        4. Select the well-behaved row range with `calcGoodSpectrumSection`.
        5. Estimate the continuum as the 90th and 98th percentiles of the
           ridge-line values.
        6. Fit every good row with `fit` and draw everything with `plot`.

        Raises
        ------
        ValueError
            If no star centroid could be measured, or if no clean spectrum
            section is found.
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        centroid = self.qfmResult.brightestObjCentroid
        roundedCentroid = np.round(centroid)
        if not np.all(np.isfinite(roundedCentroid)):
            raise ValueError(f"No usable star centroid was measured (got {centroid}).")
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        # NOTE(review): meaning of the 200 argument is defined by
        # ProcessStarTask.calcSpectrumBBox - confirm before tuning.
        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(self.exp, centroid, 200)
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        rowIndices = np.arange(self.spectrumData.shape[0])
        self.ridgeLineValues = self.spectrumData[rowIndices, self.ridgeLineLocations]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        self.goodSpectrumMinY, self.goodSpectrumMaxY = self.calcGoodSpectrumSection()
        self.goodSlice = slice(self.goodSpectrumMinY, self.goodSpectrumMaxY)

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues: np.ndarray, minIndex: int, maxIndex: int) -> tuple[float, float]:
        """Compute the median and "best" FWHM over ``fwhmValues[minIndex:maxIndex]``.

        The values are sigma-clipped first. The best FWHM is the 2nd
        percentile of the clipped values. NaNs (failed fits, clipped points)
        are ignored, and an empty or all-NaN selection gives NaN results.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            Per-row FWHM values.
        minIndex, maxIndex : `int`
            Half-open index range to use.

        Returns
        -------
        medianFwhm, bestFocusFwhm : `tuple` [`float`, `float`]
        """
        with warnings.catch_warnings():  # to suppress NaN/empty-slice warnings, which are fine
            warnings.simplefilter("ignore")
            clipped = sigma_clip(fwhmValues[minIndex:maxIndex])
            # sigma_clip returns a masked array. np.asarray() would silently
            # drop the mask and keep the clipped outliers, so replace masked
            # entries with NaN and use the NaN-aware reductions instead.
            values = np.ma.filled(clipped.astype(float), np.nan)
            medianFwhm = np.nanmedian(values)
            bestFocusFwhm = np.nanpercentile(values, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(
        self, fwhmValues: np.ndarray, amplitudes: np.ndarray, smoothing: int = 1, maxDifferential: int = 4
    ) -> tuple[int, int]:
        """Find the row range of ``fwhmValues`` used for the FWHM statistics.

        Procedure:

        1. Smooth the FWHM values and keep the rows selected by the
           difference test below (this drops NaN rows).
        2. Group the gaps between kept rows into runs of equal gap size.
        3. Take the longest run whose middle row has an amplitude above the
           75th percentile of ``amplitudes`` (to avoid the 2nd order). If
           none qualifies, take the longest run.
        4. Extend that run outwards over neighbouring runs whose gap size is
           at most ``maxDifferential``.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            Per-row FWHM values.
        amplitudes : `numpy.ndarray`
            Per-row fitted amplitudes, aligned with ``fwhmValues``.
        smoothing : `int`, optional
            Width of the boxcar used to smooth ``fwhmValues``.
        maxDifferential : `int`, optional
            Largest gap (in rows) that the region may be extended across.

        Returns
        -------
        startValue, endValue : `tuple` [`int`, `int`]
            Half-open row range into ``fwhmValues``; ``(0, 0)`` if no run exists.
        """
        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing) / smoothing, mode="same")
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        # NOTE(review): ``1 - |diff| < 1`` is equivalent to ``|diff| > 0``. It
        # keeps every row whose smoothed FWHM changes at all, and drops NaN
        # rows and exact repeats (including the last row, whose appended diff
        # is 0). It does not select "stable" rows; if ``np.abs(diff) < 1`` was
        # intended, confirm before changing, as it alters the selection.
        indices = np.where(1 - np.abs(diff) < 1)[0]
        # gap (in rows) between consecutive kept rows: 1 means adjacent
        diffIndices = np.diff(indices)

        # Runs of equal gap size, e.g. [1, 1, 1, 5, 1, 1] -> [[1, 1, 1], [5], [1, 1]]
        runs = [list(g) for _, g in groupby(diffIndices)]
        runLengths = [len(run) for run in runs]
        # Start position of each run within diffIndices.
        runStarts = []
        position = 0
        for length in runLengths:
            runStarts.append(position)
            position += length

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)

        # Longest first. sorted() is stable, so equal-length runs are tried in
        # left-to-right order instead of re-testing the first one repeatedly.
        candidates = sorted(range(len(runs)), key=runLengths.__getitem__, reverse=True)
        chosen = candidates[0] if candidates else None
        for runIndex in candidates:
            midpoint = runStarts[runIndex] + runLengths[runIndex] // 2
            # map the diffIndices position back to a row of fwhmValues/amplitudes
            if amplitudes[indices[midpoint]] > amplitudeThreshold:
                chosen = runIndex
                break

        if chosen is None:
            coreStart = coreEnd = startValue = endValue = 0
        else:
            firstRun = lastRun = chosen
            while lastRun + 1 < len(runs) and runs[lastRun + 1][0] <= maxDifferential:
                lastRun += 1
            while firstRun > 0 and runs[firstRun - 1][0] <= maxDifferential:
                firstRun -= 1
            # A run covering diffIndices[s:s + L] spans rows indices[s]..indices[s + L]
            # inclusive; convert to half-open row ranges.
            coreStart = int(indices[runStarts[chosen]])
            coreEnd = int(indices[runStarts[chosen] + runLengths[chosen]]) + 1
            startValue = int(indices[runStarts[firstRun]])
            endValue = int(indices[runStarts[lastRun] + runLengths[lastRun]]) + 1

        startValue = int(max(0, startValue))
        endValue = int(min(len(fwhmValues), endValue))

        if not self.debug:
            return startValue, endValue

        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, len(fwhmValues))

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], "r", ls="--")

        plt.vlines(coreStart, 0, 50, "g")
        plt.vlines(coreEnd, 0, 50, "g")

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        # plot each gap at the row it starts from, so x matches the plot above
        plt.plot(indices[:-1], diffIndices)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")

        plt.vlines(coreStart, 0, 50, "g")
        plt.vlines(coreEnd, 0, 50, "g")
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue