Coverage for python/lsst/summit/utils/spectrumExaminer.py: 8%
302 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-18 03:34 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-18 03:34 -0800
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Names exported by ``from lsst.summit.utils.spectrumExaminer import *``.
__all__ = ["SpectrumExaminer"]
24import numpy as np
25import matplotlib.pyplot as plt
26from matplotlib.offsetbox import AnchoredText
27from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
28from scipy.optimize import curve_fit
29from itertools import groupby
30from astropy.stats import sigma_clip
31import warnings
33from lsst.atmospec.processStar import ProcessStarTask
34from lsst.pipe.tasks.quickFrameMeasurement import QuickFrameMeasurementTask, QuickFrameMeasurementTaskConfig
36from lsst.obs.lsst.translators.lsst import FILTER_DELIMITER
37from lsst.summit.utils.utils import getImageStats
class SpectrumExaminer():
    """Task for the QUICK spectral extraction of single-star dispersed images.

    For a full description of how this task works, see the run() method.

    Parameters
    ----------
    exp : exposure
        The exposure to examine. Presumably an ``lsst.afw.image.Exposure``;
        it must provide ``image``, ``visitInfo`` and ``filter``.
    display : display object, optional
        Object with ``erase``, ``line`` and ``dot`` methods (e.g. an
        afwDisplay) used by the ``display*`` helpers, or `None`.
    debug : `bool`, optional
        If True, draw extra diagnostic plots and a debug textbox.
    savePlotAs : `str`, optional
        If set, the summary figure made by ``plot()`` is saved to this path.
    **kwargs
        Forwarded to ``super().__init__()``. This class derives directly
        from `object`, so this must be empty.
    """

    def __init__(self, exp, display=None, debug=False, savePlotAs=None, **kwargs):
        super().__init__(**kwargs)
        self.exp = exp
        self.display = display
        self.debug = debug
        self.savePlotAs = savePlotAs

        qfmTaskConfig = QuickFrameMeasurementTaskConfig()
        self.qfmTask = QuickFrameMeasurementTask(config=qfmTaskConfig)

        pstConfig = ProcessStarTask.ConfigClass()
        pstConfig.offsetFromMainStar = 400
        self.processStarTask = ProcessStarTask(config=pstConfig)

        self.imStats = getImageStats(exp)

        self.init()

    @staticmethod
    def bboxToAwfDisplayLines(box):
        """Convert a bbox into the four line segments outlining it.

        The returned lines can be drawn with::

            for line in lines:
                display.line(line, ctype='red')

        Parameters
        ----------
        box : bbox-like
            Object with ``beginX``, ``endX``, ``beginY`` and ``endY``
            attributes (presumably an ``lsst.geom.Box2I``).

        Returns
        -------
        lines : `list` [`list` [`tuple`]]
            Four ``[(x, y), (x, y)]`` segments: the ``beginY`` edge, the
            ``beginX`` edge, the ``endX`` edge and the ``endY`` edge.
        """
        x0 = box.beginX
        x1 = box.endX
        y0 = box.beginY
        y1 = box.endY
        return [[(x0, y0), (x1, y0)], [(x0, y0), (x0, y1)], [(x1, y0), (x1, y1)], [(x0, y1), (x1, y1)]]

    def eraseDisplay(self):
        """Erase the display, if one is set."""
        if self.display:
            self.display.erase()

    def displaySpectrumBbox(self):
        """Draw the spectrum bbox on the display in red."""
        if self.display:
            lines = self.bboxToAwfDisplayLines(self.spectrumbbox)
            for line in lines:
                self.display.line(line, ctype='red')
        else:
            print("No display set")

    def displayStarLocation(self):
        """Mark the brightest object's centroid on the display."""
        if self.display:
            self.display.dot('x', *self.qfmResult.brightestObjCentroid, size=50)
            self.display.dot('o', *self.qfmResult.brightestObjCentroid, size=50)
        else:
            print("No display set")

    def calcGoodSpectrumSection(self, threshold=5, windowSize=5):
        """Find the row range over which the spectrum's ridge line is stable.

        The ridge line (``self.ridgeLineLocations``) is cut into consecutive
        windows of ``windowSize`` entries, and the standard deviation of each
        window is computed. Windows whose scatter is below ``threshold`` are
        "good". The returned range starts two windows before the third good
        window. It ends three window-lengths after the start of the
        third-to-last good window. Both ends are clipped to the ridge-line
        length.

        Parameters
        ----------
        threshold : `float`, optional
            Maximum per-window standard deviation of the ridge-line position,
            in pixels, for a window to count as good.
        windowSize : `int`, optional
            Number of rows per window.

        Returns
        -------
        minPoint, maxPoint : `tuple` [`int`, `int`]
            Start (inclusive) and end (exclusive) row of the good section.

        Raises
        ------
        IndexError
            Raised if fewer than three windows are good.
        """
        length = len(self.ridgeLineLocations)
        chunks = length // windowSize
        stddevs = []
        for i in range(chunks+1):
            stddevs.append(np.std(self.ridgeLineLocations[i*windowSize:(i+1)*windowSize]))

        goodPoints = np.where(np.asarray(stddevs) < threshold)[0]
        minPoint = (goodPoints[2] - 2) * windowSize
        maxPoint = (goodPoints[-3] + 3) * windowSize
        minPoint = max(minPoint, 0)
        maxPoint = min(maxPoint, length)
        if self.debug:
            plt.plot(range(0, length+1, windowSize), stddevs)
            plt.hlines(threshold, 0, length, colors='r', ls='dashed')
            plt.vlines(minPoint, 0, max(stddevs)+10, colors='k', ls='dashed')
            plt.vlines(maxPoint, 0, max(stddevs)+10, colors='k', ls='dashed')
            plt.title(f'Ridgeline scatter, windowSize={windowSize}')

        return (minPoint, maxPoint)

    def fit(self):
        """Fit a Gaussian across every row of the good part of the spectrum.

        Each row of ``self.spectrumData[self.goodSlice]`` is fitted with
        ``a * exp(-(x - x0)**2 / (2 * sigma**2))``. The fit is seeded with that
        row's ridge-line position and the pixel value there, and an initial
        sigma of 7 pixels. A row gets NaN parameters if its fit fails to
        converge, if the row cannot be fitted (e.g. it contains NaN/inf), or
        if any fitted parameter is not below 1e7.

        Sets ``self.parameters``, an ``(nRows, 3)`` array of
        ``(amplitude, centre, sigma)`` per row. Amplitude and sigma are made
        non-negative.
        """
        def gauss(x, a, x0, sigma):
            return a*np.exp(-(x-x0)**2/(2*sigma**2))

        data = self.spectrumData[self.goodSlice]
        # The ridge line was computed for the full spectrum. Slice it the same
        # way as ``data`` so that each row is seeded with its own peak
        # position rather than one from ``goodSlice.start`` rows earlier.
        ridgeLineLocations = self.ridgeLineLocations[self.goodSlice]
        nRows, nCols = data.shape
        # don't subtract the row median or even a percentile - seems bad
        # fitting a const also seems bad - needs some better thought

        parameters = np.zeros((nRows, 3))
        xs = np.arange(nCols)
        width = 7  # initial guess for sigma, in pixels
        for rowNum, (row, peakPos) in enumerate(zip(data, ridgeLineLocations)):
            amplitude = row[peakPos]
            try:
                pars, _ = curve_fit(gauss, xs, row, [amplitude, peakPos, width], maxfev=100)
            except (RuntimeError, ValueError):
                # RuntimeError: no convergence within maxfev.
                # ValueError: the row can't be fitted (e.g. contains NaN/inf).
                pars = [np.nan] * 3
            if not np.all([p < 1e7 for p in pars]):
                pars = [np.nan] * 3
            parameters[rowNum] = pars

        parameters[:, 0] = np.abs(parameters[:, 0])
        parameters[:, 2] = np.abs(parameters[:, 2])
        self.parameters = parameters

    def plot(self):
        """Draw the summary figure and save it if ``savePlotAs`` is set.

        The figure uses a 4x4 grid:
        - the good section of the spectrum image, with a colorbar;
        - a histogram of all spectrum pixels, with under/overflow counts;
        - raw ridge-line values and fitted amplitudes, plus the 98th
          percentile continuum level;
        - the fitted FWHM (``2.355 * sigma``), with the stable region and
          the median and best FWHM within it;
        - row sums;
        - a statistics textbox, plus a debug textbox when ``self.debug``.

        The figure is saved before ``plt.show()``, so the saved file does not
        depend on the backend's behaviour after the window is closed.
        """
        fig = plt.figure(figsize=(10, 10))

        # spectrum
        ax0 = plt.subplot2grid((4, 4), (0, 0), colspan=3)
        ax0.tick_params(axis='x', top=True, bottom=False, labeltop=True, labelbottom=False)
        goodSpectrum = self.spectrumData[self.goodSlice].T
        vmin = np.percentile(goodSpectrum, 1)
        vmax = np.percentile(goodSpectrum, 99)
        pos = ax0.imshow(goodSpectrum, vmin=vmin, vmax=vmax, origin='lower')
        div = make_axes_locatable(ax0)
        cax = div.append_axes("bottom", size="7%", pad="8%")
        fig.colorbar(pos, cax=cax, orientation="horizontal", label="Counts")

        # spectrum histogram
        axHist = plt.subplot2grid((4, 4), (0, 3))
        data = self.spectrumData
        histMax = np.nanpercentile(data, 99.99)
        histMin = np.nanpercentile(data, 0.001)
        axHist.hist(data[(data >= histMin) & (data <= histMax)].flatten(), bins=100)
        underflow = np.count_nonzero(data < histMin)
        overflow = np.count_nonzero(data > histMax)
        axHist.set_yscale('log', nonpositive='clip')
        axHist.set_title('Spectrum pixel histogram')
        text = f"Underflow = {underflow}"
        text += f"\nOverflow = {overflow}"
        anchored_text = AnchoredText(text, loc=1, pad=0.5)
        axHist.add_artist(anchored_text)

        # peak fluxes
        ax1 = plt.subplot2grid((4, 4), (1, 0), colspan=3)
        ax1.plot(self.ridgeLineValues[self.goodSlice], label='Raw peak value')
        ax1.plot(self.parameters[:, 0], label='Fitted amplitude')
        ax1.axhline(self.continuumFlux98, ls='dashed', color='g')
        ax1.set_ylabel('Peak amplitude (ADU)')
        ax1.set_xlabel('Spectrum position (pixels)')
        ax1.legend(title=f"Continuum flux = {self.continuumFlux98:.0f} ADU",
                   loc="center right", framealpha=0.2, facecolor="black")
        ax1.set_title('Ridgeline plot')

        # FWHM
        ax2 = plt.subplot2grid((4, 4), (2, 0), colspan=3)
        fwhmValues = self.parameters[:, 2]*2.355  # 2.355 ~= 2*sqrt(2 ln 2): sigma -> FWHM
        amplitudes = self.parameters[:, 0]
        ax2.plot(fwhmValues, label="FWHM (pix)")
        minVal, maxVal = self.getStableFwhmRegion(fwhmValues, amplitudes)
        medianFwhm, bestFwhm = self.getMedianAndBestFwhm(fwhmValues, minVal, maxVal)

        ax2.axhline(medianFwhm, ls='dashed', color='k',
                    label=f"Median FWHM = {medianFwhm:.1f} pix")
        ax2.axhline(bestFwhm, ls='dashed', color='r',
                    label=f"Best FWHM = {bestFwhm:.1f} pix")
        ax2.axvline(minVal, ls='dashed', color='k', alpha=0.2)
        ax2.axvline(maxVal, ls='dashed', color='k', alpha=0.2)
        ymin = max(np.nanmin(fwhmValues)-5, 0)
        if not np.isnan(medianFwhm):
            ymax = medianFwhm*2
        else:
            ymax = 5*ymin
        ax2.set_ylim(ymin, ymax)
        ax2.set_ylabel('FWHM (pixels)')
        ax2.set_xlabel('Spectrum position (pixels)')
        ax2.legend(loc="upper right", framealpha=0.2, facecolor="black")
        ax2.set_title('Spectrum FWHM')

        # row fluxes
        ax3 = plt.subplot2grid((4, 4), (3, 0), colspan=3)
        ax3.plot(self.rowSums[self.goodSlice], label="Sum across row")
        ax3.set_ylabel('Total row flux (ADU)')
        ax3.set_xlabel('Spectrum position (pixels)')
        ax3.legend(framealpha=0.2, facecolor="black")
        ax3.set_title('Row sums')

        # textbox top
        ax4 = plt.subplot2grid((4, 4), (1, 3), rowspan=2)
        text = self.generateStatsTextboxContent(0)
        text += self.generateStatsTextboxContent(1)
        text += self.generateStatsTextboxContent(2)
        text += self.generateStatsTextboxContent(3)
        stats_text = AnchoredText(text, loc="center", pad=0.5,
                                  prop=dict(size=10.5, ma="left", backgroundcolor="white",
                                            color="black", family='monospace'))
        ax4.add_artist(stats_text)
        ax4.axis('off')

        # textbox middle
        # NOTE(review): this cell (2, 3) overlaps ax4, which spans rows 1-2.
        if self.debug:
            ax5 = plt.subplot2grid((4, 4), (2, 3))
            text = self.generateStatsTextboxContent(-1)
            stats_text = AnchoredText(text, loc="center", pad=0.5,
                                      prop=dict(size=10.5, ma="left", backgroundcolor="white",
                                                color="black", family='monospace'))
            ax5.add_artist(stats_text)
            ax5.axis('off')

        plt.tight_layout()
        # Save before showing: some backends tear the figure down in show().
        if self.savePlotAs:
            fig.savefig(self.savePlotAs)
        plt.show()

    def init(self):
        """Hook called at the end of ``__init__``; currently does nothing."""
        pass

    def generateStatsTextboxContent(self, section):
        """Build the text for one section of the statistics textbox.

        Parameters
        ----------
        section : `int`
            Which section to build: ``0`` star stats, ``1`` image stats,
            ``2`` rate stats, ``3`` observation info, ``-1`` debug info.

        Returns
        -------
        text : `str`
            Newline-joined lines for the section. Sections 0-2 end with
            ``'\\n\\n'``, so concatenated sections are separated by a blank
            line.

        Raises
        ------
        ValueError
            Raised if ``section`` is not one of the values above.
        """
        if section not in (0, 1, 2, 3, -1):
            raise ValueError(f"Unknown stats textbox section {section!r}; "
                             "expected one of 0, 1, 2, 3 or -1")

        x, y = self.qfmResult.brightestObjCentroid

        vi = self.exp.visitInfo
        exptime = vi.exposureTime

        # physicalLabel is "<filter><FILTER_DELIMITER><grating>"
        filterParts = self.exp.filter.physicalLabel.split(FILTER_DELIMITER)
        filt = filterParts[0]
        grating = filterParts[1]

        airmass = vi.getBoresightAirmass()
        rotangle = vi.getBoresightRotAngle().asDegrees()

        azAlt = vi.getBoresightAzAlt()
        az = azAlt[0].asDegrees()
        el = azAlt[1].asDegrees()

        obj = vi.object

        lines = []

        if section == 0:
            lines.append("----- Star stats -----")
            lines.append(f"Star centroid @ {x:.0f}, {y:.0f}")
            lines.append(f"Star max pixel = {self.starPeakFlux:,.0f} ADU")
            lines.append(f"Star Ap25 flux = {self.qfmResult.brightestObjApFlux25:,.0f} ADU")
            lines.extend(["", ""])  # section break
            return '\n'.join(lines)

        if section == 1:
            lines.append("------ Image stats ---------")
            imageMedian = np.median(self.exp.image.array)
            lines.append(f"Image median = {imageMedian:.2f} ADU")
            lines.append(f"Exposure time = {exptime:.2f} s")
            lines.extend(["", ""])  # section break
            return '\n'.join(lines)

        if section == 2:
            lines.append("------- Rate stats ---------")
            lines.append(f"Star max pixel = {self.starPeakFlux/exptime:,.0f} ADU/s")
            lines.append(f"Spectrum continuum = {self.continuumFlux98/exptime:,.1f} ADU/s")
            lines.extend(["", ""])  # section break
            return '\n'.join(lines)

        if section == 3:
            lines.append("----- Observation info -----")
            lines.append(f"object = {obj}")
            lines.append(f"filter = {filt}")
            lines.append(f"grating = {grating}")
            lines.append(f"rotpa = {rotangle:.1f}")

            lines.append(f"az = {az:.1f}")
            lines.append(f"el = {el:.1f}")
            lines.append(f"airmass = {airmass:.3f}")
            return '\n'.join(lines)

        # section == -1: debug information
        lines.append("---------- Debug -----------")
        lines.append(f"spectrum bbox: {self.spectrumbbox}")
        lines.append(f"Good range = {self.goodSpectrumMinY},{self.goodSpectrumMaxY}")
        return '\n'.join(lines)

    def run(self):
        """Measure the star, extract its spectrum, fit it and plot a summary.

        Steps:

        1. Run QuickFrameMeasurementTask to find the brightest object. Record
           its rounded integer centroid and the pixel value there
           (``starPeakFlux``).
        2. Compute the spectrum bbox with
           ``ProcessStarTask.calcSpectrumBBox`` (third argument 200,
           presumably the bbox width in pixels — confirm) and cut out the
           spectrum pixels (``spectrumData``).
        3. For each row, record the column and value of its maximum pixel
           (the "ridge line") and the row sum.
        4. Choose the stable section of the ridge line with
           ``calcGoodSpectrumSection`` and store it as ``goodSlice``.
        5. Estimate the continuum as the 90th (for emission stars) and 98th
           (for most stars) percentiles of the ridge-line values.
        6. Fit each good row with ``fit`` and draw the summary with ``plot``.
        """
        self.qfmResult = self.qfmTask.run(self.exp)
        roundedCentroid = np.round(self.qfmResult.brightestObjCentroid)
        self.intCentroidX = int(roundedCentroid[0])
        self.intCentroidY = int(roundedCentroid[1])
        self.starPeakFlux = self.exp.image.array[self.intCentroidY, self.intCentroidX]

        self.spectrumbbox = self.processStarTask.calcSpectrumBBox(self.exp,
                                                                  self.qfmResult.brightestObjCentroid,
                                                                  200)
        self.spectrumData = self.exp.image[self.spectrumbbox].array

        self.ridgeLineLocations = np.argmax(self.spectrumData, axis=1)
        self.ridgeLineValues = self.spectrumData[range(self.spectrumbbox.getHeight()),
                                                 self.ridgeLineLocations]
        self.rowSums = np.sum(self.spectrumData, axis=1)

        coords = self.calcGoodSpectrumSection()
        self.goodSpectrumMinY = coords[0]
        self.goodSpectrumMaxY = coords[1]
        self.goodSlice = slice(coords[0], coords[1])

        self.continuumFlux90 = np.percentile(self.ridgeLineValues, 90)  # for emission stars
        self.continuumFlux98 = np.percentile(self.ridgeLineValues, 98)  # for most stars

        self.fit()
        self.plot()

    @staticmethod
    def getMedianAndBestFwhm(fwhmValues, minIndex, maxIndex):
        """Return the sigma-clipped median and 2nd-percentile FWHM of a range.

        Parameters
        ----------
        fwhmValues : array-like
            FWHM values; may contain NaN.
        minIndex, maxIndex : `int`
            Slice bounds (``[minIndex:maxIndex]``) of the region to use.

        Returns
        -------
        medianFwhm : `float`
            NaN-aware median of the sigma-clipped values.
        bestFocusFwhm : `float`
            NaN-aware 2nd percentile of the sigma-clipped values.
        """
        with warnings.catch_warnings():  # to suppress nan warnings, which are fine
            warnings.simplefilter("ignore")
            section = np.asarray(fwhmValues[minIndex:maxIndex], dtype=float)
            clippedValues = sigma_clip(section)
            # sigma_clip returns a masked array, which doesn't play nicely with
            # np.nan<median/percentile>. Fill the masked (clipped or invalid)
            # entries with NaN so the nan-aware functions skip them. A plain
            # np.asarray() would silently drop the mask and undo the clipping.
            clippedValues = np.ma.filled(clippedValues, np.nan)
            medianFwhm = np.nanmedian(clippedValues)
            bestFocusFwhm = np.nanpercentile(clippedValues, 2)
        return medianFwhm, bestFocusFwhm

    def getStableFwhmRegion(self, fwhmValues, amplitudes, smoothing=1, maxDifferential=4):
        """Find the index range over which the fitted FWHM is stable.

        Outline:

        - smooth ``fwhmValues`` with a boxcar of length ``smoothing``;
        - take the forward difference and keep the indices where it is
          non-zero;
        - group consecutive equal gaps between those indices into runs;
        - consider runs longest first (ties in position order), and pick the
          first whose mid-run amplitude is above the 75th percentile of
          ``amplitudes``, to exclude the 2nd order. If none qualifies, fall
          back to the first of the shortest runs;
        - extend the chosen run outwards over neighbouring runs whose gap
          value is at most ``maxDifferential``.

        Parameters
        ----------
        fwhmValues : `numpy.ndarray`
            FWHM per spectrum row.
        amplitudes : `numpy.ndarray`
            Fitted amplitude per spectrum row.
        smoothing : `int`, optional
            Boxcar length used to smooth ``fwhmValues``.
        maxDifferential : `int`, optional
            Largest gap value a neighbouring run may have and still be
            absorbed into the region.

        Returns
        -------
        startValue, endValue : `tuple` [`int`, `int`]
            Region bounds, clipped to ``[0, len(fwhmValues)]``. The full range
            is returned if no runs can be formed.
        """
        smoothFwhm = np.convolve(fwhmValues, np.ones(smoothing)/smoothing, mode='same')
        diff = np.diff(smoothFwhm, append=smoothFwhm[-1])

        indices = np.where(1-np.abs(diff) < 1)[0]
        diffIndices = np.diff(indices)

        # [list(g) for k, g in groupby('AAAABBBCCD')] -->[['A', 'A', 'A', 'A'],
        # ... ['B', 'B', 'B'], ['C', 'C'], ['D']]
        indexLists = [list(g) for k, g in groupby(diffIndices)]
        listLengths = [len(lst) for lst in indexLists]

        if not listLengths:
            # No runs at all (e.g. a perfectly flat FWHM curve): nothing to
            # choose between, so treat the whole range as stable.
            return 0, len(fwhmValues)

        amplitudeThreshold = np.nanpercentile(amplitudes, 75)
        # Run indices, longest first. The sort is stable, so equal-length runs
        # keep position order and each tied run is actually considered.
        candidateIndices = sorted(range(len(listLengths)), key=lambda i: listLengths[i], reverse=True)
        for runIndex in candidateIndices:
            midRunPosition = int(np.sum(listLengths[0:runIndex]))
            midRunPosition += int(listLengths[runIndex]/2)  # we want the mid-run value
            # NOTE(review): midRunPosition counts entries of diffIndices but is
            # used directly as an index into amplitudes; confirm intent.
            if amplitudes[midRunPosition] > amplitudeThreshold:
                longestListIndex = runIndex
                break
        else:
            # No run passes the amplitude cut: keep the historical fallback of
            # the first of the shortest runs.
            longestListIndex = listLengths.index(min(listLengths))
        longestListLength = listLengths[longestListIndex]

        startOfLongList = np.sum(listLengths[0:longestListIndex])
        endOfLongList = startOfLongList + longestListLength

        endValue = endOfLongList
        for lst in indexLists[longestListIndex+1:]:
            value = lst[0]
            if value > maxDifferential:
                break
            endValue += len(lst)

        startValue = startOfLongList
        # Walk backwards over the runs before the chosen one only. Slicing
        # first avoids wrapping to the end of the list when the index is 0.
        for lst in reversed(indexLists[:longestListIndex]):
            value = lst[0]
            if value > maxDifferential:
                break
            startValue -= len(lst)

        startValue = int(max(0, startValue))
        endValue = int(min(len(fwhmValues), endValue))

        if not self.debug:
            return startValue, endValue

        medianFwhm, bestFocusFwhm = self.getMedianAndBestFwhm(fwhmValues, startValue, endValue)
        xlim = (-20, len(fwhmValues))

        plt.figure(figsize=(10, 6))
        plt.plot(fwhmValues)
        plt.vlines(startValue, 0, 50, 'r')
        plt.vlines(endValue, 0, 50, 'r')
        plt.hlines(medianFwhm, xlim[0], xlim[1])
        plt.hlines(bestFocusFwhm, xlim[0], xlim[1], 'r', ls='--')

        plt.vlines(startOfLongList, 0, 50, 'g')
        plt.vlines(endOfLongList, 0, 50, 'g')

        plt.ylim(0, 200)
        plt.xlim(xlim)
        plt.show()

        plt.figure(figsize=(10, 6))
        plt.plot(diffIndices)
        plt.vlines(startValue, 0, 50, 'r')
        plt.vlines(endValue, 0, 50, 'r')

        plt.vlines(startOfLongList, 0, 50, 'g')
        plt.vlines(endOfLongList, 0, 50, 'g')
        plt.ylim(0, 30)
        plt.xlim(xlim)
        plt.show()
        return startValue, endValue