Coverage for bin/compare.py : 71%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/Users/square/j/ws/release/tarball/59a01628fc/build/conda/miniconda3-4.7.12/envs/lsst-scipipe-1a1d771/bin/python # noqa
2from __future__ import division
3from __future__ import print_function
5import argparse
6import os
7import sys
8import numpy as np
# Column layout of the whitespace-separated results files, in file order.
#
# Keep this in sync with the columns written by export-results, until this
# script and the expected-results files are changed to work directly on the
# binary table outputs.
#
# Flag columns are read as 5-character strings because export-results writes
# them as text; every non-float column is only ever compared for equality, so
# the exact representation does not matter.
DTYPE = np.dtype([
    ("id", np.int64),
    ("coord_ra", float),
    ("coord_dec", float),
    ("flags_negative", str, 5),
    ("base_SdssCentroid_flag", str, 5),
    ("base_PixelFlags_flag_edge", str, 5),
    ("base_PixelFlags_flag_interpolated", str, 5),
    ("base_PixelFlags_flag_interpolatedCenter", str, 5),
    ("base_PixelFlags_flag_saturated", str, 5),
    ("base_PixelFlags_flag_saturatedCenter", str, 5),
    ("base_SdssCentroid_x", float),
    ("base_SdssCentroid_y", float),
    ("base_SdssCentroid_xErr", float),
    ("base_SdssCentroid_yErr", float),
    ("base_SdssShape_xx", float),
    ("base_SdssShape_xy", float),
    ("base_SdssShape_yy", float),
    ("base_SdssShape_xxErr", float),
    ("base_SdssShape_xyErr", float),
    ("base_SdssShape_yyErr", float),
    ("base_SdssShape_flag", str, 5),
    ("base_GaussianFlux_instFlux", float),
    ("base_GaussianFlux_instFluxErr", float),
    ("base_PsfFlux_instFlux", float),
    ("base_PsfFlux_instFluxErr", float),
    ("base_CircularApertureFlux_2_instFlux", float),
    ("base_CircularApertureFlux_2_instFluxErr", float),
    ("base_ClassificationExtendedness_value", float),
])
def get_array(filename):
    """
    Load a whitespace-separated results file into a structured array.

    Lines beginning with ``#`` (such as the column-name header) are skipped by
    ``np.loadtxt``'s default comment handling.

    @param filename Path to the data file.
    @return One-dimensional structured array with dtype ``DTYPE``.
    """
    with open(filename, 'r') as f:
        # ndmin=1 keeps a single-row file as a 1-element array rather than a
        # 0-d record, so callers can always index and compare row by row.
        return np.loadtxt(f, dtype=DTYPE, ndmin=1)
def difference(arr1, arr2):
    """
    Compute the absolute and relative differences of numpy arrays arr1 & arr2.

    The relative difference R between numbers n1 and n2 is defined as per
    numdiff (http://www.nongnu.org/numdiff/numdiff.html):
    * R = 0 if n1 and n2 are equal,
    * R = Inf if n2 differs from n1 and at least one of them is zero,
    * R = |n1 - n2| / min(|n1|, |n2|) if n1 and n2 are both non-zero and n2
      differs from n1.

    NaN handling: two NaNs are treated as equal (difference 0); a NaN paired
    with a non-NaN gives an infinite difference.

    @param arr1 First array of numbers.
    @param arr2 Second array of numbers (same shape as, or broadcastable with,
                arr1).
    @return Tuple (absDiff, relDiff). absDiff is |n1 - n2|, except that it is
            forced to Inf where the rules above call for an infinite
            difference and to 0 for equal values/paired NaNs; relDiff is R.
    """
    arr1 = np.asarray(arr1, dtype=float)
    arr2 = np.asarray(arr2, dtype=float)

    # inf - inf yields NaN (with a warning); those cases are fixed up below.
    with np.errstate(invalid="ignore"):
        absDiff = np.abs(arr1 - arr2)

    # Equal values (including equal infinities) have no difference.
    absDiff = np.where(arr1 == arr2, 0.0, absDiff)

    # If there is a difference between 0 and something else, the result is
    # infinite.
    absDiff = np.where((absDiff != 0) & ((arr1 == 0) | (arr2 == 0)), np.inf, absDiff)

    nan1, nan2 = np.isnan(arr1), np.isnan(arr2)
    # If both inputs are nan, the result is 0.
    absDiff = np.where(nan1 & nan2, 0.0, absDiff)
    # If one input is nan, the result is infinite.
    absDiff = np.where(nan1 ^ nan2, np.inf, absDiff)

    # Divide by the smaller magnitude, per numdiff. Taking abs() before min()
    # matters for mixed-sign pairs: min(-5, 1) = -5, but min(|-5|, |1|) = 1.
    # A zero or NaN divisor means either there is no difference or absDiff is
    # already Inf, so substitute 1 to avoid 0/0.
    divisor = np.minimum(np.abs(arr1), np.abs(arr2))
    divisor = np.where((divisor == 0) | np.isnan(divisor), 1.0, divisor)

    with np.errstate(invalid="ignore"):
        relDiff = absDiff / divisor
    # An infinite absolute difference is always an infinite relative
    # difference (e.g. +inf vs -inf would otherwise give inf/inf = NaN, which
    # silently passes any tolerance check).
    relDiff = np.where(np.isinf(absDiff), np.inf, relDiff)
    return absDiff, relDiff
def compareWithNumPy(filename, reference, tolerance):
    """
    Compare a generated data file to a reference using NumPy.

    The comparison is successful if:
    * The files contain the same data columns (both by name and by data type)
      and the same number of rows;
    * Each numeric value in the input is either:
        a) Within ``tolerance`` of the corresponding value in the reference, or
        b) The relative difference with the reference (as computed by
           ``difference``) is within ``tolerance``;
    * Flags recorded in the input and the reference are identical.

    Every failing column is reported on stdout before returning.

    @param filename Path to input data file.
    @param reference Path to reference file.
    @param tolerance Tolerance.
    @return True if the files match, False otherwise.
    """
    # The first line of each file holds the '#'-prefixed column names.
    with open(filename, 'r') as data, open(reference, 'r') as ref:
        data_columns = data.readline().strip('#').split()
        ref_columns = ref.readline().strip('#').split()

    # atleast_1d guards against single-row files loading as 0-d records.
    table1 = np.atleast_1d(get_array(filename))
    table2 = np.atleast_1d(get_array(reference))

    if (table1.dtype != table2.dtype) or (data_columns != ref_columns):
        print("Files do not contain the same columns.")
        return False
    if table1.shape != table2.shape:
        # Element-wise comparison of mismatched lengths would either raise or
        # broadcast silently, so report it explicitly.
        print("Files do not contain the same number of rows (%d vs %d)."
              % (table1.size, table2.size))
        return False

    valid = True
    for name in table1.dtype.names:
        col1, col2 = table1[name], table2[name]
        if table1.dtype[name] == np.dtype(float):
            absDiff, relDiff = difference(col1, col2)
            # A value fails only if BOTH its absolute and relative differences
            # exceed the tolerance.
            for pos in np.flatnonzero((relDiff > tolerance) & (absDiff > tolerance)):
                valid = False
                print("Failed (absolute difference %g, relative difference %g over tolerance %g) "
                      "in column %s." % (absDiff[pos], relDiff[pos], tolerance, name))
        else:
            # Non-float columns (ids, flags) must match exactly.
            mismatched = col1 != col2
            if mismatched.any():
                nDiff = np.count_nonzero(mismatched)
                nTotal = len(col1)
                print("Failed (%s of %s flags do not match) in column %s." % (str(nDiff), str(nTotal), name))
                valid = False
    return valid
def determineFlavor():
    """
    Describe the local platform as an EUPS-style 'flavor' string.

    Mirrors the corresponding EUPS logic without depending on EUPS itself.

    @return One of "Linux64", "Linux", "DarwinX86" or "Darwin".
    @raise RuntimeError if the operating system is neither Linux nor Darwin.
    """
    info = os.uname()
    sysname, machine = info[0], info[4]

    if sysname == "Linux":
        return "Linux64" if machine.endswith("64") else "Linux"
    if sysname == "Darwin":
        return "DarwinX86" if machine in ("x86_64", "i686") else "Darwin"
    raise RuntimeError("Unknown flavor: (%s, %s)" % (sysname, machine))
def extantFile(filename):
    """
    argparse ``type=`` callback that accepts only existing regular files.

    @param filename Path supplied on the command line.
    @return ``filename`` unchanged, if it names an existing regular file.
    @raise argparse.ArgumentTypeError otherwise.
    """
    if os.path.isfile(filename):
        return filename
    raise argparse.ArgumentTypeError(filename + " is not a file.")
def referenceFilename(checkFilename):
    """
    Attempt to guess the filename to compare our input against.

    The guess is ``<root>/expected/<flavor>/<basename of checkFilename>``,
    where ``<root>`` is the parent of the directory containing this script
    and ``<flavor>`` comes from ``determineFlavor()``.

    @param checkFilename Path to the file being checked; only its basename
                         is used.
    @return Path to the existing reference file.
    @raise ValueError if the guessed reference file does not exist.
    """
    # Resolve __file__ first. If the script is run by a relative path from
    # its own directory, dirname(__file__) is '' and its "parent" would
    # collapse to the current directory, so the wrong expected/ tree would be
    # searched.
    scriptDir = os.path.dirname(os.path.abspath(__file__))
    root = os.path.dirname(scriptDir)
    guess = os.path.join(root, "expected", determineFlavor(),
                         os.path.basename(checkFilename))
    if not os.path.isfile(guess):
        raise ValueError("Cannot find reference data (looked for %s)." % (guess,))
    return guess
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('filename', type=extantFile, help="Input data file.")
    parser.add_argument('--tolerance', default=1e-10, type=float,
                        help="Tolerance for errors. "
                        "The test will fail if both the relative and absolute errors exceed the tolerance.")
    parser.add_argument('--reference', type=extantFile, help="Reference data for comparison.")
    args = parser.parse_args()

    # Without an explicit --reference, fall back to the per-flavor
    # expected-results file that matches the input's basename.
    reference = args.reference or referenceFilename(args.filename)

    # A non-zero exit status signals a mismatch to the calling test harness.
    if not compareWithNumPy(args.filename, reference, args.tolerance):
        sys.exit(1)
    print("Ok.")