Coverage for python/lsst/summit/utils/spectrumExaminer.py: 10%
305 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 05:00 -0700
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public names exported by ``from ... import *`` for this module.
__all__ = ["SpectrumExaminer"]
24import warnings
25from itertools import groupby
26from typing import Any
28import matplotlib.pyplot as plt
29import numpy as np
30from astropy.stats import sigma_clip
31from matplotlib.offsetbox import AnchoredText
32from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
33from scipy.optimize import curve_fit
35import lsst.afw.display as afwDisplay
36import lsst.afw.image as afwImage
37from lsst.atmospec.processStar import ProcessStarTask
38from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER
39from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig
40from lsst.summit.utils.utils import getImageStats
class SpectrumExaminer:
    """Task for the QUICK spectral extraction of single-star dispersed images.

    For a full description of how this task works, see the run() method.

    Parameters
    ----------
    exp : `lsst.afw.image.Exposure`
        The dispersed-image exposure to examine.
    display : `lsst.afw.display.Display`, optional
        Display used by ``eraseDisplay`` and the ``display*`` helpers. If
        ``None`` those helpers do nothing (or print "No display set").
    debug : `bool`, optional
        If True, produce extra diagnostic plots and a debug text box.
    savePlotAs : `str`, optional
        If set, the summary figure made by ``plot()`` is saved to this path.
    **kwargs
        Forwarded to ``super().__init__``.
    """

    # Gaussian sigma -> FWHM conversion factor, 2 * sqrt(2 * ln 2) ~= 2.3548,
    # rounded to the value this class has always used.
    _SIGMA_TO_FWHM = 2.355

    def __init__(
        self,
        exp: afwImage.Exposure,
        display: afwDisplay.Display | None = None,
        debug: bool | None = False,
        savePlotAs: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        # NOTE(review): the meaning of this 400 value is defined by
        # ProcessStarTask's config; kept unchanged here.
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box) -> list[list[tuple[int, int]]]:
        """Convert a bbox into the four line segments that outline it.

        The result can be drawn with::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : bbox-like
            Any object exposing ``beginX``, ``endX``, ``beginY`` and ``endY``
            (presumably an `lsst.geom.Box2I`).

        Returns
        -------
        lines : `list` [`list` [`tuple` [`int`, `int`]]]
            Bottom, left, right and top edges, each as a pair of (x, y) points.
        """
        x0 = box.beginX
        x1 = box.endX
        y0 = box.beginY
        y1 = box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self) -> None:
        """Erase the display, if one is set."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self) -> None:
        """Outline the spectrum bbox in red on the display.

        Requires ``run()`` to have been called (it sets ``spectrumbbox``).
        Prints a message instead if no display is set.
        """
        if self.display:
            for line in self.bboxToAwfDisplayLines(self.spectrumbbox):
                self.display.line(line, ctype="red")
        else:
            print("No display set")

    def displayStarLocation(self) -> None:
        """Mark the brightest object's centroid with an 'x' and an 'o'.

        Requires ``run()`` to have been called (it sets ``qfmResult``).
        Prints a message instead if no display is set.
        """
        if self.display:
            self.display.dot("x", *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot("o", *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold: int = 5, windowSize: int = 5) -> tuple[int, int]:
        """Find the row range over which the spectrum ridgeline is stable.

        The ridgeline (column of each row's maximum) is split into windows of
        ``windowSize`` rows and the standard deviation of each window is
        computed. Windows with a scatter below ``threshold`` are "good". The
        returned range starts two windows before the third good window and
        ends three windows after the third-from-last good window, clipped to
        the ridgeline length.

        Parameters
        ----------
        threshold : `int`, optional
            Maximum per-window standard deviation (pixels) for a good window.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `int`
            Start (inclusive) and end (exclusive) row indices of the section.

        Raises
        ------
        IndexError
            Raised if fewer than three windows are good.
        """
        ridge = self.ridgeLineLocations
        length = len(ridge)
        chunks = length // windowSize
        stddevs = []
        for i in range(chunks + 1):
            window = ridge[i * windowSize : (i + 1) * windowSize]
            # The final window is empty when length is a multiple of
            # windowSize; record NaN directly rather than letting np.std warn.
            stddevs.append(np.std(window) if len(window) else np.nan)

        goodPoints = np.where(np.asarray(stddevs) < threshold)[0]
        if len(goodPoints) < 3:
            raise IndexError(
                f"Only {len(goodPoints)} ridgeline windows have scatter < {threshold}; at least 3 are needed"
            )
        minPoint = int(max((goodPoints[2] - 2) * windowSize, 0))
        maxPoint = int(min((goodPoints[-3] + 3) * windowSize, length))

        if self.debug:
            plt.plot(range(0, length + 1, windowSize), stddevs)
            plt.hlines(threshold, 0, length, colors="r", ls="dashed")
            plt.vlines(minPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.vlines(maxPoint, 0, max(stddevs) + 10, colors="k", ls="dashed")
            plt.title(f"Ridgeline scatter, windowSize={windowSize}")

        return (minPoint, maxPoint)

    def fit(self, initialSigma: float = 7, maxfev: int = 100) -> None:
        """Fit a Gaussian across each row of the good spectrum section.

        Each row of ``spectrumData[goodSlice]`` is fitted with
        ``a * exp(-(x - x0)**2 / (2 * sigma**2))``. The initial guess is that
        row's ridgeline peak position and value, and ``initialSigma``.

        Parameters
        ----------
        initialSigma : `float`, optional
            Initial guess for the Gaussian sigma, in pixels.
        maxfev : `int`, optional
            Maximum number of function evaluations passed to ``curve_fit``.

        Notes
        -----
        Sets ``self.parameters``, an array of shape (nRows, 3) with columns
        (|amplitude|, centre, |sigma|). A row is NaN if its fit failed or any
        fitted parameter was >= 1e7.
        """

        def gauss(x, a, x0, sigma):
            return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2))

        data = self.spectrumData[self.goodSlice]
        # Slice the ridgeline with the same slice as the data so that each
        # row is paired with its own peak position.
        peakPositions = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # don't subtract the row median or even a percentile - seems bad
        # fitting a const also seems bad - needs some better thought

        parameters = np.zeros((nRows, 3))
        xs = np.arange(nCols)
        for rowNum, (row, peakPos) in enumerate(zip(data, peakPositions)):
            amplitude = row[peakPos]
            try:
                pars, _ = curve_fit(gauss, xs, row, [amplitude, peakPos, initialSigma], maxfev=maxfev)
            except RuntimeError:  # curve_fit raises this when it fails to converge
                pars = [np.nan] * 3
            # Treat absurdly large parameters as a failed fit.
            if not np.all([p < 1e7 for p in pars]):
                pars = [np.nan] * 3
            parameters[rowNum] = pars

        # The Gaussian is symmetric in the signs of a and sigma as fitted here,
        # so report magnitudes.
        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    def plot(self) -> None:
        """Draw the summary figure for the extracted spectrum.

        The figure has these panels:

        - the good section of the spectrum with a colorbar;
        - a histogram of all spectrum cut-out pixels;
        - ridgeline peak values against fitted amplitudes;
        - the fitted FWHM along the spectrum, with its stable region and
          median/best values;
        - the row sums;
        - a stats text box, plus a debug text box when ``self.debug`` is set.

        The figure is saved to ``self.savePlotAs`` (if set) and then shown.
        Requires the attributes set by ``run()`` and ``fit()``.
        """
        # Compute the stable FWHM region before creating the figure. In debug
        # mode getStableFwhmRegion() creates and shows figures of its own,
        # which would otherwise become the "current" figure part-way through
        # building this one.
        fwhmValues = self.parameters[:, 2] * self._SIGMA_TO_FWHM
        amplitudes = self.parameters[:, 0]
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        fig = plt.figure(figsize=(10, 10))
        grid = (4, 4)
        textProps = dict(size=10.5, ma="left", backgroundcolor="white", color="black", family="monospace")

        # spectrum
        ax0 = plt.subplot2grid(grid, (0, 0), colspan=3, fig=fig)
        ax0.tick_params(axis="x", top=True, bottom=False, labeltop=True, labelbottom=False)
        spectrumImage = self.spectrumData[self.goodSlice].T
        vmin = np.percentile(spectrumImage, 1)
        vmax = np.percentile(spectrumImage, 99)
        pos = ax0.imshow(spectrumImage, vmin=vmin, vmax=vmax, origin="lower")
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # spectrum histogram (whole cut-out, not just the good section)
        axHist = plt.subplot2grid(grid, (0, 3), fig=fig)
        data = self.spectrumData
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        axHist.hist(data[(data >= histMin) & (data <= histMax)].flatten(), bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale("log", nonpositive="clip")
        axHist.set_title("Spectrum pixel histogram")
        axHist.add_artist(AnchoredText(f"Underflow = {underflow}\nOverflow = {overflow}", loc=1, pad=0.5))

        # peak fluxes
        ax1 = plt.subplot2grid(grid, (1, 0), colspan=3, fig=fig)
        ax1.plot(self.ridgeLineValues[self.goodSlice], label="Raw peak value")
        ax1.plot(amplitudes, label="Fitted amplitude")
        ax1.axhline(self.continuumFlux98, ls="dashed", color="g")
        ax1.set_ylabel("Peak amplitude (ADU)")
        ax1.set_xlabel("Spectrum position (pixels)")
        ax1.legend(
            title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
            loc="center right",
            framealpha=0.2,
            facecolor="black",
        )
        ax1.set_title("Ridgeline plot")

        # FWHM
        ax2 = plt.subplot2grid(grid, (2, 0), colspan=3, fig=fig)
        ax2.plot(fwhmValues, label="FWHM (pix)")
        ax2.axhline(medianFwhm, ls="dashed", color="k", label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls="dashed", color="r", label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls="dashed", color="k", alpha=0.2)
        ax2.axvline(maxVal, ls="dashed", color="k", alpha=0.2)
        ymin = max(np.nanmin(fwhmValues) - 5, 0)
        ymax = medianFwhm * 2 if not np.isnan(medianFwhm) else 5 * ymin
        ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel("FWHM (pixels)")
        ax2.set_xlabel("Spectrum position (pixels)")
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title("Spectrum FWHM")

        # row fluxes
        ax3 = plt.subplot2grid(grid, (3, 0), colspan=3, fig=fig)
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel("Total row flux (ADU)")
        ax3.set_xlabel("Spectrum position (pixels)")
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title("Row sums")

        # textbox top: star, image, rate and observation sections
        ax4 = plt.subplot2grid(grid, (1, 3), rowspan=2, fig=fig)
        text = "".join(self.generateStatsTextboxContent(section) for section in range(4))
        ax4.add_artist(AnchoredText(text, loc="center", pad=0.5, prop=dict(textProps)))
        ax4.axis("off")

        # textbox middle (debug only)
        if self.debug:
            ax5 = plt.subplot2grid(grid, (2, 3), fig=fig)
            text = self.generateStatsTextboxContent(-1)
            ax5.add_artist(AnchoredText(text, loc="center", pad=0.5, prop=dict(textProps)))
            ax5.axis("off")

        fig.tight_layout()
        # Save before showing: some backends close the figure on show().
        if self.savePlotAs:
            fig.savefig(self.savePlotAs)
        plt.show()

    def init(self):
        """Called at the end of ``__init__``; currently a no-op.

        Presumably an extension point for subclasses.
        """
        pass

    def generateStatsTextboxContent(self, section: int) -> str:
        """Build the text for one section of the stats text box.

        Parameters
        ----------
        section : `int`
            Which section to build:

            - 0: star stats
            - 1: image stats
            - 2: rate stats
            - 3: observation info
            - -1: debug info

        Returns
        -------
        text : `str`
            Newline-joined lines for the section. Sections 0-2 end with a
            blank-line section break. Any other ``section`` gives "".

        Notes
        -----
        Requires the attributes set by ``run()``.
        """
        vi = self.exp.visitInfo
        lines = []

        if section == 0:
            x, y = self.qfmResult.brightestObjCentroid
            lines.append("----- Star stats -----")
            lines.append(f"Star centroid @ {x:.0f}, {y:.0f}")
            lines.append(f"Star max pixel = {self.starPeakFlux:,.0f} ADU")
            lines.append(f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU")
            lines.extend(["", ""])  # section break
        elif section == 1:
            imageMedian = np.median(self.exp.image.array)
            lines.append("------ Image stats ---------")
            lines.append(f"Image median = {imageMedian:.2f} ADU")
            lines.append(f"Exposure time = {vi.exposureTime:.2f} s")
            lines.extend(["", ""])  # section break
        elif section == 2:
            exptime = vi.exposureTime
            lines.append("------- Rate stats ---------")
            lines.append(f"Star max pixel = {self.starPeakFlux/exptime:,.0f} ADU/s")
            lines.append(f"Spectrum continuum = {self.continuumFlux98/exptime:,.1f} ADU/s")
            lines.extend(["", ""])  # section break
        elif section == 3:
            # The physical label is "<filter><FILTER_DELIMITER><grating>". A
            # label without the delimiter shows an empty grating rather than
            # raising.
            filterParts = self.exp.filter.physicalLabel.split(FILTER_DELIMITER)
            filt = filterParts[0]
            grating = filterParts[1] if len(filterParts) > 1 else ""
            azAlt = vi.getBoresightAzAlt()
            lines.append("----- Observation info -----")
            lines.append(f"object = {vi.object}")
            lines.append(f"filter = {filt}")
            lines.append(f"grating = {grating}")
            lines.append(f"rotpa = {vi.getBoresightRotAngle().asDegrees():.1f}")
            lines.append(f"az = {azAlt[0].asDegrees():.1f}")
            lines.append(f"el = {azAlt[1].asDegrees():.1f}")
            lines.append(f"airmass = {vi.getBoresightAirmass():.3f}")
        elif section == -1:  # special -1 for debug
            lines.append("---------- Debug -----------")
            lines.append(f"spectrum bbox: {self.spectrumbbox}")
            lines.append(f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}")
        else:
            return ""

        return "\n".join(lines)

    def run(self) -> None:
        """Run the quick spectrum examination and draw the summary plot.

        Steps:

        1. Measure the frame with QuickFrameMeasurementTask and record the
           brightest object's rounded centroid and the pixel value there.
        2. Get the spectrum bbox from ``ProcessStarTask.calcSpectrumBBox``
           and cut the spectrum pixels out of the image.
        3. For every row of the cut-out, record the column of the maximum
           (the ridgeline), the value there, and the row sum.
        4. Choose the good row range with ``calcGoodSpectrumSection()``.
        5. Estimate the continuum flux as the 90th and 98th percentiles of the
           ridgeline values.
        6. Fit each good row with ``fit()`` and draw the figure with ``plot()``.
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        roundedCentroid = np.round(self.qfmResult.brightestObjCentroid)
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(
            self.exp, self.qfmResult.brightestObjCentroid, 200
        )
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        self.ridgeLineValues = self.spectrumData[
            np.arange(self.spectrumbbox.getHeight()), self.ridgeLineLocations
        ]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        self.goodSpectrumMinY, self.goodSpectrumMaxY = self.calcGoodSpectrumSection()
        self.goodSlice = slice(self.goodSpectrumMinY, self.goodSpectrumMaxY)

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues: np.ndarray, minIndex: int, maxIndex: int) -> tuple[float, float]:
        """Sigma-clip a section of FWHM values and summarise it.

        Parameters
        ----------
        fwhmValues : `np.ndarray`
            FWHM values; NaNs are allowed.
        minIndex, maxIndex : `int`
            Slice bounds of the section to use.

        Returns
        -------
        medianFwhm : `float`
            NaN-ignoring median of the sigma-clipped values.
        bestFocusFwhm : `float`
            NaN-ignoring 2nd percentile of the sigma-clipped values.
        """
        with warnings.catch_warnings():  # to suppress NaN warnings, which are fine
            warnings.simplefilter("ignore")
            clipped = sigma_clip(np.asarray(fwhmValues[minIndex:maxIndex], dtype=float))
            # sigma_clip returns a masked array whose mask marks the rejected
            # points. np.asarray() would silently drop that mask and undo the
            # clipping, so fill masked points with NaN and use the nan-aware
            # reductions instead.
            values = np.ma.filled(clipped, np.nan)
            medianFwhm = np.nanmedian(values)
            bestFocusFwhm = np.nanpercentile(values, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(
        self, fwhmValues: np.ndarray, amplitudes: np.ndarray, smoothing: int = 1, maxDifferential: int = 4
    ) -> tuple[int, int]:
        """Find the index range of fwhmValues to treat as the stable region.

        Algorithm:

        1. Smooth the FWHM values and differentiate them.
        2. Split the gaps between retained points into runs of equal values.
        3. Visit runs from longest to shortest and pick the first whose
           midpoint has an amplitude in the top 25%, to exclude 2nd order.
           If none qualifies, use the longest run.
        4. Extend the chosen run outwards over neighbouring runs whose value
           is at most ``maxDifferential``.

        Parameters
        ----------
        fwhmValues : `np.ndarray`
            FWHM along the spectrum.
        amplitudes : `np.ndarray`
            Fitted amplitudes along the spectrum, same length as fwhmValues.
        smoothing : `int`, optional
            Box-car smoothing width applied before differentiating.
        maxDifferential : `int`, optional
            Largest run value that may be absorbed when extending the region.

        Returns
        -------
        startValue, endValue : `int`
            Start and end indices of the region, clipped to
            ``[0, len(fwhmValues)]``. The full range is returned if no runs
            can be formed.
        """
        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing) / smoothing, mode="same")
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        # NOTE(review): ``1 - |diff| < 1`` is equivalent to ``|diff| > 0``. It
        # keeps every point whose difference is non-zero and non-NaN. If a
        # "small change" test was intended this is probably meant to be
        # ``np.abs(diff) < 1`` — confirm before changing.
        indices = np.where(1 - np.abs(diff) < 1)[0]
        diffIndices = np.diff(indices)

        # Group equal consecutive gap values into runs, e.g.
        # [list(g) for k, g in groupby('AAAABBBCCD')] -> [['A', 'A', 'A', 'A'],
        # ['B', 'B', 'B'], ['C', 'C'], ['D']]
        indexLists = [list(g) for _, g in groupby(diffIndices)]
        listLengths = [len(lst) for lst in indexLists]
        if not listLengths:
            # Too few usable points to form any run: use the whole range.
            return 0, len(fwhmValues)
        runStarts = np.cumsum([0] + listLengths[:-1])

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)

        # Longest runs first; sorted() is stable even with reverse=True, so
        # runs of equal length keep their original order and each is visited
        # exactly once.
        runOrder = sorted(range(len(listLengths)), key=lambda i: listLengths[i], reverse=True)
        longestListIndex = runOrder[0]  # fallback when no run passes the amplitude cut
        for runIndex in runOrder:
            midPosition = int(runStarts[runIndex]) + listLengths[runIndex] // 2  # mid-run value
            if amplitudes[midPosition] > amplitudeThreshold:
                longestListIndex = runIndex
                break
        longestListLength = listLengths[longestListIndex]

        startOfLongList = int(runStarts[longestListIndex])
        endOfLongList = startOfLongList + longestListLength

        # Walk out from each end of the chosen run over small bumps.
        endValue = endOfLongList
        for lst in indexLists[longestListIndex + 1 :]:
            if lst[0] > maxDifferential:
                break
            endValue += len(lst)

        startValue = startOfLongList
        # Only the runs before the chosen one. A ``[i - 1 :: -1]`` slice would
        # wrap round to the end of the list when i == 0.
        for lst in reversed(indexLists[:longestListIndex]):
            if lst[0] > maxDifferential:
                break
            startValue -= len(lst)

        startValue = int(max(0, startValue))
        endValue = int(min(len(fwhmValues), endValue))

        if not self.debug:
            return startValue, endValue

        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, len(fwhmValues))

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], "r", ls="--")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        plt.plot(diffIndices)
        plt.vlines(startValue, 0, 50, "r")
        plt.vlines(endValue, 0, 50, "r")

        plt.vlines(startOfLongList, 0, 50, "g")
        plt.vlines(endOfLongList, 0, 50, "g")
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue