Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from __future__ import print_function 

2from builtins import map 

3from builtins import object 

4# Base class for metrics - defines methods which must be implemented. 

5# If a metric calculates a vector or list at each gridpoint, then there 

6# should be additional 'reduce_*' functions defined, to convert the vector 

7# into scalar (and thus plottable) values at each gridpoint. 

8# The philosophy behind keeping the vector instead of the scalar at each gridpoint 

9# is that these vectors may be expensive to compute; by keeping/writing the full 

10# vector we permit multiple 'reduce' functions to be executed on the same data. 

11 

12import numpy as np 

13import inspect 

14from lsst.sims.maf.stackers.getColInfo import ColInfo 

15from future.utils import with_metaclass 

16import warnings 

17 

# Public names exported by ``from <this module> import *``.
__all__ = [
    'MetricRegistry',
    'BaseMetric',
]

19 

20 

class MetricRegistry(type):
    """
    Meta class for metrics, to build a registry of metric classes.

    Every class created with this metaclass is recorded in a ``registry``
    dictionary. The dictionary is created on the first class of a hierarchy
    and inherited (and therefore shared) by all of its subclasses.

    Registry keys:
      * classes whose module name starts with ``lsst.sims.maf.metrics`` are
        keyed by their bare class name;
      * classes from a dotted module ``pkg.mod`` are keyed as ``pkg.<name>``;
      * classes from a top-level module ``mod`` are keyed as ``mod.<name>``.
    The names 'BaseMetric' and 'SimpleScalarMetric' (after prefixing) are
    never registered.
    """
    def __init__(cls, name, bases, namespace):
        super(MetricRegistry, cls).__init__(name, bases, namespace)
        # Only the first class in a hierarchy creates the registry; subclasses
        # find it through normal attribute inheritance and share it.
        if not hasattr(cls, 'registry'):
            cls.registry = {}
        # Use cls.__module__ directly: inspect.getmodule() returns None when
        # the module is not present in sys.modules (e.g. classes created
        # dynamically), which made the old `.__name__` access crash.
        modname = cls.__module__
        if modname.startswith('lsst.sims.maf.metrics'):
            modname = ''
        else:
            parts = modname.split('.')
            if len(parts) > 1:
                # Drop the final module component, keep the package path.
                modname = '.'.join(parts[:-1]) + '.'
            else:
                modname = modname + '.'
        metricname = modname + name
        if metricname in cls.registry:
            warnings.warn('Redefining metric %s! (there are >1 metrics with the same name)' % (metricname))
        # Abstract base classes are not meant to be looked up as metrics.
        if metricname not in ['BaseMetric', 'SimpleScalarMetric']:
            cls.registry[metricname] = cls

    def getClass(cls, metricname):
        """Return the registered metric class for ``metricname``.

        Raises
        ------
        KeyError
            If no metric with that name has been registered.
        """
        return cls.registry[metricname]

    def help(cls, doc=False):
        """Print every registered metric name, in sorted order.

        Parameters
        ----------
        doc : bool
            If True, print a header line plus the class docstring for each
            metric instead of just its name.
        """
        for metricname in sorted(cls.registry):
            if doc:
                print('---- ', metricname, ' ----')
                print(inspect.getdoc(cls.registry[metricname]))
            else:
                print(metricname)

    def help_metric(cls, metricname):
        """Print the name, docstring and constructor signature of one metric.

        Raises
        ------
        KeyError
            If no metric with that name has been registered.
        """
        metric = cls.registry[metricname]
        print(metricname)
        print(inspect.getdoc(metric))
        signature = inspect.signature(metric)
        print(' Metric __init__ keyword args and defaults: ')
        print(signature)

60 

61 

class ColRegistry(object):
    """
    ColRegistry tracks the columns needed for all metric objects.

    Attributes
    ----------
    colSet : set
        All unique (non-None) columns required by the registered metrics.
    dbSet : set
        The subset of colSet whose data source equals ColInfo's default data
        source (i.e. the columns which come from the database).
    stackerDict : dict
        Maps each remaining column to the data source reported for it by
        ColInfo.getDataSource (per the original docs, the stacker class).
    """
    colInfo = ColInfo()

    def __init__(self):
        self.colSet = set()
        self.dbSet = set()
        self.stackerDict = {}

    def addCols(self, colArray):
        """Add the columns in colArray into the ColRegistry.

        Add the columns in colArray into the ColRegistry set (self.colSet) and
        identify their source, using ColInfo (lsst.sims.maf.stackers.getColInfo).
        Database columns go into self.dbSet; all others are recorded in
        self.stackerDict together with their source.

        Parameters
        ----------
        colArray : iterable
            Columns used in a metric. None entries are ignored.
        """
        for col in colArray:
            if col is None:
                continue
            self.colSet.add(col)
            source = self.colInfo.getDataSource(col)
            if source == self.colInfo.defaultDataSource:
                self.dbSet.add(col)
            else:
                # Keep the first source recorded for a column.
                self.stackerDict.setdefault(col, source)

97 

98 

class BaseMetric(with_metaclass(MetricRegistry, object)):
    """
    Base class for the metrics.
    Sets up some basic functionality for the MAF framework: after __init__ every metric will
    record the columns (and stackers) it requires into the column registry, and the metricName,
    metricDtype, and units for the metric will be set.

    Parameters
    ----------
    col : str or list
        Names of the data columns that the metric will use.
        The columns required for each metric is tracked in the ColRegistry, and used to retrieve data
        from the opsim database. Can be a single string or a list.
    metricName : str
        Name to use for the metric (optional - if not set, will be derived
        from the class name, minus its first 'Metric', and the column names).
    maps : list of lsst.sims.maf.maps objects
        The maps that the metric will need (passed from the slicer).
    units : str
        The units for the value returned by the metric (optional - if not set,
        will be derived from the ColInfo).
    metricDtype : str
        The type of value returned by the metric - 'int', 'float', 'object'.
        If not set, it is 'object' when the metric defines any reduce* methods,
        otherwise 'float'.
    badval : float
        The value indicating "bad" values calculated by the metric.
    maskVal : optional
        Stored as self.maskVal only when not None; otherwise the attribute is
        not set by this constructor.
    """
    colRegistry = ColRegistry()
    colInfo = ColInfo()

    def __init__(self, col=None, metricName=None, maps=None, units=None,
                 metricDtype=None, badval=-666, maskVal=None):
        # Turn col into an (at least) 1-d numpy array so we can always iterate
        # over the columns. np.array(..., copy=False) raises ValueError under
        # NumPy >= 2 whenever a copy is unavoidable (a str, list or None), so
        # convert with asarray (copies only if needed) and promote 0-d to 1-d.
        self.colNameArr = np.atleast_1d(np.asarray(col))
        # To support simple metrics operating on a single column, set self.colname.
        if len(self.colNameArr) == 1:
            self.colname = self.colNameArr[0]
        # Add the columns to the (class-level, shared) colRegistry.
        self.colRegistry.addCols(self.colNameArr)
        # Set the maps that are needed.
        self.maps = [] if maps is None else maps
        # Value to return if the metric can't be computed.
        self.badval = badval
        if maskVal is not None:
            self.maskVal = maskVal
        # Save a unique name for the metric; if none provided, construct one
        # from the class name and the data columns.
        if metricName is None:
            metricName = (self.__class__.__name__.replace('Metric', '', 1) + ' ' +
                          ', '.join(str(colName) for colName in self.colNameArr))
        self.name = metricName
        # Set up dictionary of reduce functions (may be empty).
        self.reduceFuncs = {}
        self.reduceOrder = {}
        members = inspect.getmembers(self, predicate=inspect.ismethod)
        for i, (methodName, method) in enumerate(members):
            if methodName.startswith('reduce'):
                reducename = methodName.replace('reduce', '', 1)
                self.reduceFuncs[reducename] = method
                # i counts every method (getmembers sorts by name), so the
                # values are not contiguous but preserve relative order.
                self.reduceOrder[reducename] = i
        # Identify type of metric return value.
        if metricDtype is not None:
            self.metricDtype = metricDtype
        elif self.reduceFuncs:
            self.metricDtype = 'object'
        else:
            self.metricDtype = 'float'
        # Set physical units, for plotting purposes.
        if units is None:
            units = ' '.join([self.colInfo.getUnits(colName) for colName in self.colNameArr])
            # Collapse an all-blank result to the empty string.
            if not units.replace(' ', ''):
                units = ''
        self.units = units
        # Add the ability to set a comment
        # (that could be propagated automatically to a benchmark's display caption).
        self.comment = None
        # Default to only return one metric value per slice.
        self.shape = 1

    def run(self, dataSlice, slicePoint=None):
        """Calculate metric values.

        Parameters
        ----------
        dataSlice : numpy.NDarray
            Values passed to metric by the slicer, which the metric will use to calculate
            metric values at each slicePoint.
        slicePoint : Dict
            Dictionary of slicePoint metadata passed to each metric.
            E.g. the ra/dec of the healpix pixel or opsim fieldId.

        Returns
        -------
        int, float or object
            The metric value at each slicePoint.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError('Please implement your metric calculation.')