Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1#!/Users/square/j/ws/release/tarball/59a01628fc/build/conda/miniconda3-4.7.12/envs/lsst-scipipe-1a1d771/bin/python # noqa 

2from __future__ import division 

3from __future__ import print_function 

4 

5import argparse 

6import os 

7import sys 

8import numpy as np 

9 

# Layout of the text tables being compared, one entry per column in file order.
# Keep this in sync with the columns written by export-results, until this
# script and the expected-results files are changed to read the binary table
# outputs directly.
# Flag columns are read as strings because that is how export-results writes
# them; this is harmless, since every non-float column is only compared for
# equality.
_FLAG_TYPE = "U5"      # 5-character unicode string, same as (str, 5)
_FLOAT_TYPE = float    # double precision
DTYPE = np.dtype([
    ("id", np.int64),
    ("coord_ra", _FLOAT_TYPE),
    ("coord_dec", _FLOAT_TYPE),
    ("flags_negative", _FLAG_TYPE),
    ("base_SdssCentroid_flag", _FLAG_TYPE),
    ("base_PixelFlags_flag_edge", _FLAG_TYPE),
    ("base_PixelFlags_flag_interpolated", _FLAG_TYPE),
    ("base_PixelFlags_flag_interpolatedCenter", _FLAG_TYPE),
    ("base_PixelFlags_flag_saturated", _FLAG_TYPE),
    ("base_PixelFlags_flag_saturatedCenter", _FLAG_TYPE),
    ("base_SdssCentroid_x", _FLOAT_TYPE),
    ("base_SdssCentroid_y", _FLOAT_TYPE),
    ("base_SdssCentroid_xErr", _FLOAT_TYPE),
    ("base_SdssCentroid_yErr", _FLOAT_TYPE),
    ("base_SdssShape_xx", _FLOAT_TYPE),
    ("base_SdssShape_xy", _FLOAT_TYPE),
    ("base_SdssShape_yy", _FLOAT_TYPE),
    ("base_SdssShape_xxErr", _FLOAT_TYPE),
    ("base_SdssShape_xyErr", _FLOAT_TYPE),
    ("base_SdssShape_yyErr", _FLOAT_TYPE),
    ("base_SdssShape_flag", _FLAG_TYPE),
    ("base_GaussianFlux_instFlux", _FLOAT_TYPE),
    ("base_GaussianFlux_instFluxErr", _FLOAT_TYPE),
    ("base_PsfFlux_instFlux", _FLOAT_TYPE),
    ("base_PsfFlux_instFluxErr", _FLOAT_TYPE),
    ("base_CircularApertureFlux_2_instFlux", _FLOAT_TYPE),
    ("base_CircularApertureFlux_2_instFluxErr", _FLOAT_TYPE),
    ("base_ClassificationExtendedness_value", _FLOAT_TYPE),
])

47 

48 

def get_array(filename):
    """
    Read a whitespace-delimited text table into a structured array laid out by ``DTYPE``.

    Lines starting with '#' (such as the column header) are skipped by
    ``np.loadtxt``'s default comment handling.

    @param filename Path to the text table.
    @return Structured array of the data rows (0-d if the file holds a single
            row, per ``np.loadtxt`` defaults).
    """
    with open(filename, 'r') as stream:
        return np.loadtxt(stream, dtype=DTYPE)

53 

54 

def difference(arr1, arr2):
    """
    Compute the absolute and relative differences of numpy arrays arr1 & arr2.

    The absolute difference A between numbers n1 and n2 is |n1 - n2|, except
    that A = 0 if n1 and n2 are equal (two nans, or two equal infinities, count
    as equal) and A = Inf if exactly one of them is nan.

    The relative difference R is defined as per numdiff
    (http://www.nongnu.org/numdiff/numdiff.html):
    * R = 0 if n1 and n2 are equal,
    * R = Inf if n2 differs from n1 and at least one of them is zero (or nan),
    * R = A / min(|n1|, |n2|) if n1 and n2 are both non zero and n2 differs from n1.

    @param arr1 First array-like of floats.
    @param arr2 Second array-like of floats, broadcast-compatible with arr1.
    @return Tuple (A, R) of float arrays.
    """
    arr1 = np.asarray(arr1, dtype=float)
    arr2 = np.asarray(arr2, dtype=float)
    nan1 = np.isnan(arr1)
    nan2 = np.isnan(arr2)

    # Pairs with no difference at all; forced to (0, 0) at the end.
    same = (arr1 == arr2) | (nan1 & nan2)

    with np.errstate(invalid="ignore", divide="ignore"):
        # inf - inf gives nan here; such pairs are either in ``same`` or
        # handled by the nan -> inf replacement below.
        absDiff = np.abs(arr1 - arr2)
        # A nan compared with a number has no meaningful distance.
        absDiff = np.where(nan1 | nan2, np.inf, absDiff)

        # numdiff divides by the smaller *magnitude*, not the smaller signed
        # value: -100 vs 1 must give 101, not 101/100.
        minAbs = np.minimum(np.abs(arr1), np.abs(arr2))
        # x/0 -> inf covers the "one of them is zero" rule; 0/0, inf/inf and
        # anything involving nan give nan, replaced by inf next.
        relDiff = absDiff / minAbs
    relDiff = np.where(np.isnan(relDiff), np.inf, relDiff)

    # The zero/nonzero rule only makes R infinite; A stays |n1 - n2| so that
    # e.g. 1e-15 vs 0 can still pass an absolute tolerance.
    absDiff = np.where(same, 0.0, absDiff)
    relDiff = np.where(same, 0.0, relDiff)
    return absDiff, relDiff

85 

86 

def compareWithNumPy(filename, reference, tolerance):
    """
    Compare a generated data file to a reference using NumPy.

    The comparison is successful if:
    * The files contain the same data columns (both by name and by data type)
      and the same number of rows;
    * Each numeric value in the input either:
      a) is within ``tolerance`` of the corresponding value in the reference, or
      b) has a relative difference with the reference (as computed by
         ``difference``) within ``tolerance``;
    * Flags recorded in the input and the reference are identical.

    Every failing check is reported on stdout.

    @param filename Path to input data file.
    @param reference Path to reference file.
    @param tolerance Tolerance.
    @return True if the files match, False otherwise.
    """
    # The first line of each file is a '#'-prefixed header naming the columns.
    with open(filename, 'r') as data, open(reference, 'r') as ref:
        data_columns = data.readline().strip('#').split()
        ref_columns = ref.readline().strip('#').split()
    # Check headers before parsing: loading a file whose columns do not match
    # DTYPE makes np.loadtxt raise instead of reporting a clean mismatch.
    if data_columns != ref_columns:
        print("Files do not contain the same columns.")
        return False

    # atleast_1d: np.loadtxt returns a 0-d array for a single-row file, which
    # supports neither len() nor positional indexing.
    table1 = np.atleast_1d(get_array(filename))
    table2 = np.atleast_1d(get_array(reference))
    if table1.dtype != table2.dtype:
        print("Files do not contain the same columns.")
        return False
    if table1.shape != table2.shape:
        # The element-wise comparisons below require equal lengths.
        print("Files do not contain the same number of rows (%d vs %d)."
              % (table1.size, table2.size))
        return False

    valid = True
    for name in table1.dtype.names:
        col1, col2 = table1[name], table2[name]
        if table1.dtype[name] == np.dtype(float):
            absDiff, relDiff = difference(col1, col2)
            # A value fails only if it is out of tolerance both ways.
            for pos in np.flatnonzero((relDiff > tolerance) & (absDiff > tolerance)):
                valid = False
                print("Failed (absolute difference %g, relative difference %g over tolerance %g) "
                      "in column %s." % (absDiff[pos], relDiff[pos], tolerance, name))
        else:
            nDiff = np.count_nonzero(col1 != col2)
            if nDiff:
                print("Failed (%s of %s flags do not match) in column %s."
                      % (str(nDiff), str(col1.size), name))
                valid = False
    return valid

126 

127 

def determineFlavor():
    """
    Return a string naming the 'flavor' of the local system.

    Mirrors the flavor logic in EUPS without depending on EUPS:
    "Linux64"/"Linux" on Linux (by whether the machine name ends in "64"),
    "DarwinX86"/"Darwin" on macOS (by whether the machine is x86_64/i686).

    @raise RuntimeError for any other operating system.
    """
    info = os.uname()
    system, arch = info[0], info[4]   # sysname and machine fields
    if system == "Linux":
        return "Linux64" if arch.endswith("64") else "Linux"
    if system == "Darwin":
        return "DarwinX86" if arch in ("x86_64", "i686") else "Darwin"
    raise RuntimeError("Unknown flavor: (%s, %s)" % (system, arch))

148 

149 

def extantFile(filename):
    """
    argparse ``type=`` callback: accept ``filename`` only if it is an existing regular file.

    @return ``filename`` unchanged.
    @raise argparse.ArgumentTypeError if no regular file exists at that path.
    """
    if os.path.isfile(filename):
        return filename
    raise argparse.ArgumentTypeError(filename + " is not a file.")

157 

158 

def referenceFilename(checkFilename):
    """
    Attempt to guess the filename to compare our input against.

    The guess is ``<parent of this script's directory>/expected/<flavor>/<name>``,
    where ``<flavor>`` comes from ``determineFlavor()`` and ``<name>`` is the
    basename of ``checkFilename``.

    @param checkFilename Path of the data file being checked; only its basename is used.
    @return Path to the existing reference file.
    @raise ValueError if no file exists at the guessed location.
    """
    # Resolve __file__ first: when the script is launched by a bare relative
    # name (e.g. "./compare"), dirname(__file__) is empty or ".", and taking its
    # parent would otherwise fall back to the current working directory.
    scriptDir = os.path.dirname(os.path.abspath(__file__))
    guess = os.path.join(os.path.dirname(scriptDir), "expected",
                         determineFlavor(), os.path.basename(checkFilename))
    if not os.path.isfile(guess):
        raise ValueError("Cannot find reference data (looked for %s)." % (guess,))
    return guess

169 

170 

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=extantFile, help="Input data file.")
    parser.add_argument('--tolerance', default=1e-10, type=float,
                        help="Tolerance for errors. "
                             "The test will fail if both the relative and absolute errors exceed the tolerance.")
    parser.add_argument('--reference', type=extantFile, help="Reference data for comparison.")
    args = parser.parse_args()

    # Without an explicit --reference, look it up in the expected-results tree.
    reference = args.reference or referenceFilename(args.filename)

    # Failures are printed by compareWithNumPy; signal them via the exit status.
    if not compareWithNumPy(args.filename, reference, args.tolerance):
        sys.exit(1)
    print("Ok.")