Coverage for python/lsst/sims/maf/metrics/transientMetrics.py : 7%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
from builtins import zip

import numpy as np

from .baseMetric import BaseMetric

# Public API of this module.
__all__ = ['TransientMetric']
class TransientMetric(BaseMetric):
    """
    Calculate what fraction of the transients would be detected. Best paired with a spatial slicer.
    We are assuming simple light curves with no color evolution.

    Parameters
    ----------
    transDuration : float, optional
        How long the transient lasts (days). Default 10.
    peakTime : float, optional
        How long it takes to reach the peak magnitude (days). Default 5.
    riseSlope : float, optional
        Slope of the light curve before peak time (mags/day).
        This should be negative since mags are backwards (magnitudes decrease towards brighter fluxes).
        Default 0.
    declineSlope : float, optional
        Slope of the light curve after peak time (mags/day).
        This should be positive since mags are backwards. Default 0.
    uPeak : float, optional
        Peak magnitude in u band. Default 20.
    gPeak : float, optional
        Peak magnitude in g band. Default 20.
    rPeak : float, optional
        Peak magnitude in r band. Default 20.
    iPeak : float, optional
        Peak magnitude in i band. Default 20.
    zPeak : float, optional
        Peak magnitude in z band. Default 20.
    yPeak : float, optional
        Peak magnitude in y band. Default 20.
    surveyDuration : float, optional
        Length of survey (years).
        Default 10.
    surveyStart : float, optional
        MJD for the survey start date.
        Default None (uses the time of the first observation).
    detectM5Plus : float, optional
        An observation will be used if the light curve magnitude is brighter than m5+detectM5Plus.
        Default 0.
    nPrePeak : int, optional
        Number of observations (in any filter(s)) to demand before peakTime,
        before saying a transient has been detected.
        Default 0.
    nPerLC : int, optional
        Number of sections of the light curve that must be sampled above the detectM5Plus threshold
        (in a single filter) for the light curve to be counted.
        For example, setting nPerLC = 2 means a light curve is only considered detected if there
        is at least 1 observation in the first half of the LC, and at least one in the second half of the LC.
        nPerLC = 4 means each quarter of the light curve must be detected to count.
        Default 1.
    nFilters : int, optional
        Number of filters that need to be observed for an object to be counted as detected.
        Default 1.
    nPhaseCheck : int, optional
        Sets the number of phases that should be checked.
        One can imagine pathological cadences where many objects pass the detection criteria,
        but would not if the observations were offset by a phase-shift.
        Default 1.
    countMethod : {'full', 'partialLC'}, optional
        Sets the method of counting the max number of transients. If 'full', only
        full light curves that fit within the survey duration are counted. If
        'partialLC', the max number of possible transients is taken to be the
        integer ceiling of the number of transient durations (including a final
        partial one) that fit in the survey duration. Default 'full'.
    """
    def __init__(self, metricName='TransientDetectMetric', mjdCol='observationStartMJD',
                 m5Col='fiveSigmaDepth', filterCol='filter',
                 transDuration=10., peakTime=5., riseSlope=0., declineSlope=0.,
                 surveyDuration=10., surveyStart=None, detectM5Plus=0.,
                 uPeak=20, gPeak=20, rPeak=20, iPeak=20, zPeak=20, yPeak=20,
                 nPrePeak=0, nPerLC=1, nFilters=1, nPhaseCheck=1, countMethod='full',
                 **kwargs):
        # Names of the opsim columns read from the dataSlice in run().
        self.mjdCol = mjdCol
        self.m5Col = m5Col
        self.filterCol = filterCol
        super(TransientMetric, self).__init__(col=[self.mjdCol, self.m5Col, self.filterCol],
                                              units='Fraction Detected',
                                              metricName=metricName, **kwargs)
        # Peak magnitude of the transient in each filter.
        self.peaks = {'u': uPeak, 'g': gPeak, 'r': rPeak, 'i': iPeak, 'z': zPeak, 'y': yPeak}
        # Light-curve shape parameters (days / mags-per-day).
        self.transDuration = transDuration
        self.peakTime = peakTime
        self.riseSlope = riseSlope
        self.declineSlope = declineSlope
        # Survey window and detection criteria.
        self.surveyDuration = surveyDuration
        self.surveyStart = surveyStart
        self.detectM5Plus = detectM5Plus
        self.nPrePeak = nPrePeak
        self.nPerLC = nPerLC
        self.nFilters = nFilters
        self.nPhaseCheck = nPhaseCheck
        self.countMethod = countMethod
98 def lightCurve(self, time, filters):
99 """
100 Calculate the magnitude of the object at each time, in each filter.
102 Parameters
103 ----------
104 time : numpy.ndarray
105 The times of the observations.
106 filters : numpy.ndarray
107 The filters of the observations.
109 Returns
110 -------
111 numpy.ndarray
112 The magnitudes of the object at each time, in each filter.
113 """
114 lcMags = np.zeros(time.size, dtype=float)
115 rise = np.where(time <= self.peakTime)
116 lcMags[rise] += self.riseSlope * time[rise] - self.riseSlope * self.peakTime
117 decline = np.where(time > self.peakTime)
118 lcMags[decline] += self.declineSlope * (time[decline] - self.peakTime)
119 for key in self.peaks:
120 fMatch = np.where(filters == key)
121 lcMags[fMatch] += self.peaks[key]
122 return lcMags
124 def run(self, dataSlice, slicePoint=None):
125 """"
126 Calculate the detectability of a transient with the specified lightcurve.
128 Parameters
129 ----------
130 dataSlice : numpy.array
131 Numpy structured array containing the data related to the visits provided by the slicer.
132 slicePoint : dict, optional
133 Dictionary containing information about the slicepoint currently active in the slicer.
135 Returns
136 -------
137 float
138 The total number of transients that could be detected.
139 """
140 # Total number of transients that could go off back-to-back
141 if self.countMethod == 'partialLC':
142 _nTransMax = np.ceil(self.surveyDuration / (self.transDuration / 365.25))
143 else:
144 _nTransMax = np.floor(self.surveyDuration / (self.transDuration / 365.25))
145 tshifts = np.arange(self.nPhaseCheck) * self.transDuration / float(self.nPhaseCheck)
146 nDetected = 0
147 nTransMax = 0
148 for tshift in tshifts:
149 # Compute the total number of back-to-back transients are possible to detect
150 # given the survey duration and the transient duration.
151 nTransMax += _nTransMax
152 if tshift != 0:
153 nTransMax -= 1
154 if self.surveyStart is None:
155 surveyStart = dataSlice[self.mjdCol].min()
156 time = (dataSlice[self.mjdCol] - surveyStart + tshift) % self.transDuration
158 # Which lightcurve does each point belong to
159 lcNumber = np.floor((dataSlice[self.mjdCol] - surveyStart) / self.transDuration)
161 lcMags = self.lightCurve(time, dataSlice[self.filterCol])
163 # How many criteria needs to be passed
164 detectThresh = 0
166 # Flag points that are above the SNR limit
167 detected = np.zeros(dataSlice.size, dtype=int)
168 detected[np.where(lcMags < dataSlice[self.m5Col] + self.detectM5Plus)] += 1
169 detectThresh += 1
171 # If we demand points on the rise
172 if self.nPrePeak > 0:
173 detectThresh += 1
174 ord = np.argsort(dataSlice[self.mjdCol])
175 dataSlice = dataSlice[ord]
176 detected = detected[ord]
177 lcNumber = lcNumber[ord]
178 time = time[ord]
179 ulcNumber = np.unique(lcNumber)
180 left = np.searchsorted(lcNumber, ulcNumber)
181 right = np.searchsorted(lcNumber, ulcNumber, side='right')
182 # Note here I'm using np.searchsorted to basically do a 'group by'
183 # might be clearer to use scipy.ndimage.measurements.find_objects or pandas, but
184 # this numpy function is known for being efficient.
185 for le, ri in zip(left, right):
186 # Number of points where there are a detection
187 good = np.where(time[le:ri] < self.peakTime)
188 nd = np.sum(detected[le:ri][good])
189 if nd >= self.nPrePeak:
190 detected[le:ri] += 1
192 # Check if we need multiple points per light curve or multiple filters
193 if (self.nPerLC > 1) | (self.nFilters > 1):
194 # make sure things are sorted by time
195 ord = np.argsort(dataSlice[self.mjdCol])
196 dataSlice = dataSlice[ord]
197 detected = detected[ord]
198 lcNumber = lcNumber[ord]
199 time = time[ord]
200 ulcNumber = np.unique(lcNumber)
201 left = np.searchsorted(lcNumber, ulcNumber)
202 right = np.searchsorted(lcNumber, ulcNumber, side='right')
203 detectThresh += self.nFilters
205 for le, ri in zip(left, right):
206 points = np.where(detected[le:ri] > 0)
207 ufilters = np.unique(dataSlice[self.filterCol][le:ri][points])
208 phaseSections = np.floor(time[le:ri][points] / self.transDuration * self.nPerLC)
209 for filtName in ufilters:
210 good = np.where(dataSlice[self.filterCol][le:ri][points] == filtName)
211 if np.size(np.unique(phaseSections[good])) >= self.nPerLC:
212 detected[le:ri] += 1
214 # Find the unique number of light curves that passed the required number of conditions
215 nDetected += np.size(np.unique(lcNumber[np.where(detected >= detectThresh)]))
217 # Rather than keeping a single "detected" variable, maybe make a mask for each criteria, then
218 # reduce functions like: reduce_singleDetect, reduce_NDetect, reduce_PerLC, reduce_perFilter.
219 # The way I'm running now it would speed things up.
221 return float(nDetected) / nTransMax