Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from builtins import zip 

2import numpy as np 

3from .baseMetric import BaseMetric 

4 

5__all__ = ['TransientMetric'] 

6 

7class TransientMetric(BaseMetric): 

8 """ 

9 Calculate what fraction of the transients would be detected. Best paired with a spatial slicer. 

10 We are assuming simple light curves with no color evolution. 

11 

12 Parameters 

13 ---------- 

14 transDuration : float, optional 

15 How long the transient lasts (days). Default 10. 

16 peakTime : float, optional 

17 How long it takes to reach the peak magnitude (days). Default 5. 

18 riseSlope : float, optional 

19 Slope of the light curve before peak time (mags/day). 

20 This should be negative since mags are backwards (magnitudes decrease towards brighter fluxes). 

21 Default 0. 

22 declineSlope : float, optional 

23 Slope of the light curve after peak time (mags/day). 

24 This should be positive since mags are backwards. Default 0. 

25 uPeak : float, optional 

26 Peak magnitude in u band. Default 20. 

27 gPeak : float, optional 

28 Peak magnitude in g band. Default 20. 

29 rPeak : float, optional 

30 Peak magnitude in r band. Default 20. 

31 iPeak : float, optional 

32 Peak magnitude in i band. Default 20. 

33 zPeak : float, optional 

34 Peak magnitude in z band. Default 20. 

35 yPeak : float, optional 

36 Peak magnitude in y band. Default 20. 

37 surveyDuration : float, optional 

38 Length of survey (years). 

39 Default 10. 

40 surveyStart : float, optional 

41 MJD for the survey start date. 

42 Default None (uses the time of the first observation). 

43 detectM5Plus : float, optional 

44 An observation will be used if the light curve magnitude is brighter than m5+detectM5Plus. 

45 Default 0. 

46 nPrePeak : int, optional 

47 Number of observations (in any filter(s)) to demand before peakTime, 

48 before saying a transient has been detected. 

49 Default 0. 

50 nPerLC : int, optional 

51 Number of sections of the light curve that must be sampled above the detectM5Plus theshold 

52 (in a single filter) for the light curve to be counted. 

53 For example, setting nPerLC = 2 means a light curve is only considered detected if there 

54 is at least 1 observation in the first half of the LC, and at least one in the second half of the LC. 

55 nPerLC = 4 means each quarter of the light curve must be detected to count. 

56 Default 1. 

57 nFilters : int, optional 

58 Number of filters that need to be observed for an object to be counted as detected. 

59 Default 1. 

60 nPhaseCheck : int, optional 

61 Sets the number of phases that should be checked. 

62 One can imagine pathological cadences where many objects pass the detection criteria, 

63 but would not if the observations were offset by a phase-shift. 

64 Default 1. 

65 countMethod : {'full' 'partialLC'}, defaults to 'full' 

66 Sets the method of counting max number of transients. if 'full', the 

67 only full light curves that fit the survey duration are counted. If 

68 'partialLC', then the max number of possible transients is taken to be 

69 the integer floor 

70 """ 

71 def __init__(self, metricName='TransientDetectMetric', mjdCol='observationStartMJD', 

72 m5Col='fiveSigmaDepth', filterCol='filter', 

73 transDuration=10., peakTime=5., riseSlope=0., declineSlope=0., 

74 surveyDuration=10., surveyStart=None, detectM5Plus=0., 

75 uPeak=20, gPeak=20, rPeak=20, iPeak=20, zPeak=20, yPeak=20, 

76 nPrePeak=0, nPerLC=1, nFilters=1, nPhaseCheck=1, countMethod='full', 

77 **kwargs): 

78 self.mjdCol = mjdCol 

79 self.m5Col = m5Col 

80 self.filterCol = filterCol 

81 super(TransientMetric, self).__init__(col=[self.mjdCol, self.m5Col, self.filterCol], 

82 units='Fraction Detected', 

83 metricName=metricName, **kwargs) 

84 self.peaks = {'u': uPeak, 'g': gPeak, 'r': rPeak, 'i': iPeak, 'z': zPeak, 'y': yPeak} 

85 self.transDuration = transDuration 

86 self.peakTime = peakTime 

87 self.riseSlope = riseSlope 

88 self.declineSlope = declineSlope 

89 self.surveyDuration = surveyDuration 

90 self.surveyStart = surveyStart 

91 self.detectM5Plus = detectM5Plus 

92 self.nPrePeak = nPrePeak 

93 self.nPerLC = nPerLC 

94 self.nFilters = nFilters 

95 self.nPhaseCheck = nPhaseCheck 

96 self.countMethod = countMethod 

97 

98 def lightCurve(self, time, filters): 

99 """ 

100 Calculate the magnitude of the object at each time, in each filter. 

101 

102 Parameters 

103 ---------- 

104 time : numpy.ndarray 

105 The times of the observations. 

106 filters : numpy.ndarray 

107 The filters of the observations. 

108 

109 Returns 

110 ------- 

111 numpy.ndarray 

112 The magnitudes of the object at each time, in each filter. 

113 """ 

114 lcMags = np.zeros(time.size, dtype=float) 

115 rise = np.where(time <= self.peakTime) 

116 lcMags[rise] += self.riseSlope * time[rise] - self.riseSlope * self.peakTime 

117 decline = np.where(time > self.peakTime) 

118 lcMags[decline] += self.declineSlope * (time[decline] - self.peakTime) 

119 for key in self.peaks: 

120 fMatch = np.where(filters == key) 

121 lcMags[fMatch] += self.peaks[key] 

122 return lcMags 

123 

124 def run(self, dataSlice, slicePoint=None): 

125 """" 

126 Calculate the detectability of a transient with the specified lightcurve. 

127 

128 Parameters 

129 ---------- 

130 dataSlice : numpy.array 

131 Numpy structured array containing the data related to the visits provided by the slicer. 

132 slicePoint : dict, optional 

133 Dictionary containing information about the slicepoint currently active in the slicer. 

134 

135 Returns 

136 ------- 

137 float 

138 The total number of transients that could be detected. 

139 """ 

140 # Total number of transients that could go off back-to-back 

141 if self.countMethod == 'partialLC': 

142 _nTransMax = np.ceil(self.surveyDuration / (self.transDuration / 365.25)) 

143 else: 

144 _nTransMax = np.floor(self.surveyDuration / (self.transDuration / 365.25)) 

145 tshifts = np.arange(self.nPhaseCheck) * self.transDuration / float(self.nPhaseCheck) 

146 nDetected = 0 

147 nTransMax = 0 

148 for tshift in tshifts: 

149 # Compute the total number of back-to-back transients are possible to detect 

150 # given the survey duration and the transient duration. 

151 nTransMax += _nTransMax 

152 if tshift != 0: 

153 nTransMax -= 1 

154 if self.surveyStart is None: 

155 surveyStart = dataSlice[self.mjdCol].min() 

156 time = (dataSlice[self.mjdCol] - surveyStart + tshift) % self.transDuration 

157 

158 # Which lightcurve does each point belong to 

159 lcNumber = np.floor((dataSlice[self.mjdCol] - surveyStart) / self.transDuration) 

160 

161 lcMags = self.lightCurve(time, dataSlice[self.filterCol]) 

162 

163 # How many criteria needs to be passed 

164 detectThresh = 0 

165 

166 # Flag points that are above the SNR limit 

167 detected = np.zeros(dataSlice.size, dtype=int) 

168 detected[np.where(lcMags < dataSlice[self.m5Col] + self.detectM5Plus)] += 1 

169 detectThresh += 1 

170 

171 # If we demand points on the rise 

172 if self.nPrePeak > 0: 

173 detectThresh += 1 

174 ord = np.argsort(dataSlice[self.mjdCol]) 

175 dataSlice = dataSlice[ord] 

176 detected = detected[ord] 

177 lcNumber = lcNumber[ord] 

178 time = time[ord] 

179 ulcNumber = np.unique(lcNumber) 

180 left = np.searchsorted(lcNumber, ulcNumber) 

181 right = np.searchsorted(lcNumber, ulcNumber, side='right') 

182 # Note here I'm using np.searchsorted to basically do a 'group by' 

183 # might be clearer to use scipy.ndimage.measurements.find_objects or pandas, but 

184 # this numpy function is known for being efficient. 

185 for le, ri in zip(left, right): 

186 # Number of points where there are a detection 

187 good = np.where(time[le:ri] < self.peakTime) 

188 nd = np.sum(detected[le:ri][good]) 

189 if nd >= self.nPrePeak: 

190 detected[le:ri] += 1 

191 

192 # Check if we need multiple points per light curve or multiple filters 

193 if (self.nPerLC > 1) | (self.nFilters > 1): 

194 # make sure things are sorted by time 

195 ord = np.argsort(dataSlice[self.mjdCol]) 

196 dataSlice = dataSlice[ord] 

197 detected = detected[ord] 

198 lcNumber = lcNumber[ord] 

199 time = time[ord] 

200 ulcNumber = np.unique(lcNumber) 

201 left = np.searchsorted(lcNumber, ulcNumber) 

202 right = np.searchsorted(lcNumber, ulcNumber, side='right') 

203 detectThresh += self.nFilters 

204 

205 for le, ri in zip(left, right): 

206 points = np.where(detected[le:ri] > 0) 

207 ufilters = np.unique(dataSlice[self.filterCol][le:ri][points]) 

208 phaseSections = np.floor(time[le:ri][points] / self.transDuration * self.nPerLC) 

209 for filtName in ufilters: 

210 good = np.where(dataSlice[self.filterCol][le:ri][points] == filtName) 

211 if np.size(np.unique(phaseSections[good])) >= self.nPerLC: 

212 detected[le:ri] += 1 

213 

214 # Find the unique number of light curves that passed the required number of conditions 

215 nDetected += np.size(np.unique(lcNumber[np.where(detected >= detectThresh)])) 

216 

217 # Rather than keeping a single "detected" variable, maybe make a mask for each criteria, then 

218 # reduce functions like: reduce_singleDetect, reduce_NDetect, reduce_PerLC, reduce_perFilter. 

219 # The way I'm running now it would speed things up. 

220 

221 return float(nDetected) / nTransMax