Coverage for tests/test_unitNormMap.py: 17%
Shortcuts on this page:
r, m, x, p — toggle line displays
j, k — next/previous highlighted chunk
0 (zero) — top of page
1 (one) — first highlighted chunk
1import unittest
3import numpy as np
4from numpy.testing import assert_allclose, assert_equal
6import astshim as ast
7from astshim.test import MappingTestCase
class TestUnitNormMap(MappingTestCase):
    """Tests for astshim's UnitNormMap.

    A UnitNormMap built from an ``n``-axis center has ``n`` inputs and
    ``n + 1`` outputs: the first ``n`` outputs are the components of the
    unit vector from the center to the input point, and the last output is
    the distance (norm) from the center to the input point.
    """

    def test_UnitNormMapBasics(self):
        """Test basics of UnitNormMap including applyForward.

        For 1, 2 and 3 input axes this checks the class name, axis counts,
        non-linearity, simplification, copying, round-tripping and
        persistence, then verifies applyForward numerically against NumPy.
        Finally it checks that an empty center is rejected.
        """
        # `full_` variables contain data for 3 axes; the variables without
        # the `full_` prefix are a subset containing the number of axes
        # being tested.
        full_center = np.array([-1, 1, 2], dtype=float)
        # Each column is one input point; each row is one axis.
        # The first column is the center itself.
        full_indata = np.array([
            [full_center[0], 1.0, 2.0, -6.0, 30.0, 1.0],
            [full_center[1], 3.0, 99.0, -5.0, 21.0, 0.0],
            [full_center[2], -5.0, 3.0, -7.0, 37.0, 0.0],
        ], dtype=float)
        for nin in (1, 2, 3):
            # subTest identifies which axis count failed and lets the
            # remaining axis counts still run.
            with self.subTest(nin=nin):
                center = full_center[0:nin]
                indata = full_indata[0:nin]
                unitnormmap = ast.UnitNormMap(center)
                self.assertEqual(unitnormmap.className, "UnitNormMap")
                self.assertEqual(unitnormmap.nIn, nin)
                self.assertEqual(unitnormmap.nOut, nin + 1)
                self.assertFalse(unitnormmap.isLinear)

                self.checkBasicSimplify(unitnormmap)
                self.checkCopy(unitnormmap)

                self.checkRoundTrip(unitnormmap, indata)
                self.checkMappingPersistence(unitnormmap, indata)

                outdata = unitnormmap.applyForward(indata)
                # The last output axis is the distance from the center.
                norm = outdata[-1]

                # The first input point is at the center, so the expected
                # output is [NaN, NaN, ..., NaN, 0]; assert_equal treats
                # NaNs in matching positions as equal.
                pred_out_at_center = [np.nan] * nin + [0]
                assert_equal(outdata[:, 0], pred_out_at_center)

                relative_indata = (indata.T - center).T
                pred_norm = np.linalg.norm(relative_indata, axis=0)
                assert_allclose(norm, pred_norm)

                # Scaling the unit vector by the norm must recover the
                # offset of each point from the center.
                pred_relative_indata = outdata[0:nin] * norm
                # The first input point is at the center, so the output is
                # [NaN, NaN, ..., NaN, 0] (as checked above), but the
                # expected value after scaling by the norm is 0s, so...
                pred_relative_indata[:, 0] = [0] * nin
                assert_allclose(relative_indata, pred_relative_indata)

        # UnitNormMap must have at least one input
        with self.assertRaises(Exception):
            ast.UnitNormMap([])

    def test_UnitNormMapSimplify(self):
        """Test advanced simplification of UnitNormMap

        Basic simplification is tested elsewhere.

        ShiftMap + UnitNormMap(forward) = UnitNormMap(forward)
        UnitNormMap(inverted) + ShiftMap = UnitNormMap(inverted)
        UnitNormMap(forward) + non-equal UnitNormMap(inverted) = ShiftMap
            (a WinMap for AST versions before 9.1.3)

        A WinMap with unit scale simplifies like a ShiftMap; a WinMap with
        non-unit scale does not simplify and the result is a SeriesMap.
        Each simplified mapping must match the unsimplified one in axis
        counts and in applyForward output.
        """
        center1 = [2, -1, 0]
        center2 = [-1, 6, 4]
        shift = [3, 7, -9]
        # An array of points (one per column), each of 4 axes, the max
        # we'll need (an inverted 3-axis UnitNormMap has 4 inputs).
        testpoints = np.array([
            [1.0, 2.0, -6.0, 30.0, 1.0],
            [3.0, 99.0, -5.0, 21.0, 0.0],
            [-5.0, 3.0, -7.0, 37.0, 0.0],
            [7.0, -23.0, -3.0, 45.0, 0.0],
        ], dtype=float)
        unm1 = ast.UnitNormMap(center1)
        unm1inv = unm1.inverted()
        unm2 = ast.UnitNormMap(center2)
        unm2inv = unm2.inverted()
        shiftmap = ast.ShiftMap(shift)
        # Maps [0, 1] to [shift, 1 + shift] on each axis: a pure shift.
        winmap_unitscale = ast.WinMap(
            np.zeros(3), shift, np.ones(3), np.ones(3) + shift)
        # Maps [0, 1] to [shift, 2 + shift] on each axis: shift plus scale.
        winmap_notunitscale = ast.WinMap(
            np.zeros(3), shift, np.ones(3), np.ones(3) * 2 + shift)

        # astVersion() encodes 9.1.3 as 9001003.
        if ast.astVersion() >= 9001003:
            expected_map = "ShiftMap"  # ShiftMap is ShiftMap in 9.1.3
        else:
            expected_map = "WinMap"  # ShiftMap gets simplified to WinMap

        cases = (
            (unm1, unm2inv, expected_map),
            (shiftmap, unm1, "UnitNormMap"),
            (winmap_unitscale, unm1, "UnitNormMap"),
            (winmap_notunitscale, unm1, "SeriesMap"),
            (unm1inv, shiftmap, "UnitNormMap"),
            (unm1inv, winmap_unitscale, "UnitNormMap"),
            (unm1inv, winmap_notunitscale, "SeriesMap"),
        )
        for case_index, (map1, map2, pred_simplified_class_name) in enumerate(cases):
            # subTest identifies the failing pair and keeps checking the rest.
            with self.subTest(case=case_index, map1=map1.className,
                              map2=map2.className,
                              expected=pred_simplified_class_name):
                cmpmap = map1.then(map2)
                self.assertEqual(map1.nIn, cmpmap.nIn)
                self.assertEqual(map2.nOut, cmpmap.nOut)
                cmpmap_simp = cmpmap.simplified()
                self.assertEqual(cmpmap_simp.className,
                                 pred_simplified_class_name)
                self.assertEqual(cmpmap.nIn, cmpmap_simp.nIn)
                self.assertEqual(cmpmap.nOut, cmpmap_simp.nOut)
                # Copy of the first nIn axes of every test point.
                testptview = np.array(testpoints[0:cmpmap.nIn])
                assert_allclose(cmpmap.applyForward(testptview),
                                cmpmap_simp.applyForward(testptview))
# Run the tests when this module is executed directly.
if __name__ == "__main__":
    unittest.main()