Coverage for tests/test_unitNormMap.py: 14%

60 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-06-13 02:50 -0700

1import unittest 

2 

3import numpy as np 

4from numpy.testing import assert_allclose, assert_equal 

5 

6import astshim as ast 

7from astshim.test import MappingTestCase 

8 

9 

class TestUnitNormMap(MappingTestCase):
    """Tests for ``ast.UnitNormMap``.

    Covers construction, attributes, applyForward output, round-trip and
    persistence checks, and simplification when combined with ShiftMap,
    WinMap and another UnitNormMap.
    """

    def test_UnitNormMapBasics(self):
        """Test basics of UnitNormMap including applyForward.

        For 1, 2 and 3 input axes this checks the reported attributes
        (className, nIn, nOut, isLinear) and runs the shared
        MappingTestCase checks (simplify, copy, round trip, persistence).
        It then verifies that applyForward returns:

        - in the first ``nIn`` output axes, the unit vector pointing from
          ``center`` to each input point;
        - in the last output axis, the distance of each point from
          ``center``.

        A point exactly at the center maps to ``[NaN, ..., NaN, 0]``.
        Constructing a UnitNormMap with no axes must raise.
        """
        # `full_` variables contain data for 3 axes; the variables without
        # the `full_` prefix are a subset containing the number of axes
        # being tested. Each column of `full_indata` is one point, and the
        # first point is placed exactly at the center.
        full_center = np.array([-1, 1, 2], dtype=float)
        full_indata = np.array([
            [full_center[0], 1.0, 2.0, -6.0, 30.0, 1.0],
            [full_center[1], 3.0, 99.0, -5.0, 21.0, 0.0],
            [full_center[2], -5.0, 3.0, -7.0, 37.0, 0.0],
        ], dtype=float)
        for nin in (1, 2, 3):
            # subTest reports which axis count failed and still runs the
            # remaining axis counts instead of stopping at the first failure.
            with self.subTest(nin=nin):
                center = full_center[0:nin]
                indata = full_indata[0:nin]
                unitnormmap = ast.UnitNormMap(center)
                self.assertEqual(unitnormmap.className, "UnitNormMap")
                self.assertEqual(unitnormmap.nIn, nin)
                # one extra output axis holds the norm
                self.assertEqual(unitnormmap.nOut, nin + 1)
                self.assertFalse(unitnormmap.isLinear)

                self.checkBasicSimplify(unitnormmap)
                self.checkCopy(unitnormmap)

                self.checkRoundTrip(unitnormmap, indata)
                self.checkMappingPersistence(unitnormmap, indata)

                outdata = unitnormmap.applyForward(indata)
                norm = outdata[-1]

                # the first input point is at the center, so the expected
                # output is [NaN, NaN, ..., NaN, 0]
                # (assert_equal treats NaN as equal to NaN)
                pred_out_at_center = [np.nan] * nin + [0]
                assert_equal(outdata[:, 0], pred_out_at_center)

                # the last output axis is the Euclidean distance from center
                relative_indata = (indata.T - center).T
                pred_norm = np.linalg.norm(relative_indata, axis=0)
                assert_allclose(norm, pred_norm)

                # scaling the unit vectors by the norm must recover the
                # offsets from the center
                pred_relative_indata = outdata[0:nin] * norm
                # the first input point is at the center, so the output is
                # [NaN, NaN, ..., NaN, 0] (as checked above), and NaN * 0 is
                # NaN; the expected value after scaling by the norm is 0s, so
                # patch that column explicitly.
                pred_relative_indata[:, 0] = [0] * nin
                assert_allclose(relative_indata, pred_relative_indata)

        # UnitNormMap must have at least one input
        with self.assertRaises(Exception):
            ast.UnitNormMap([])

    def test_UnitNormMapSimplify(self):
        """Test advanced simplification of UnitNormMap.

        Basic simplification is tested elsewhere.

        ShiftMap + UnitNormMap(forward) = UnitNormMap(forward)
        UnitNormMap(inverted) + ShiftMap = UnitNormMap(inverted)
        UnitNormMap(forward) + non-equal UnitNormMap(inverted) = ShiftMap
            (WinMap for AST versions before 9.1.3)
        unit-scale WinMap + UnitNormMap(forward) = UnitNormMap(forward)
        UnitNormMap(inverted) + unit-scale WinMap = UnitNormMap(inverted)
        A WinMap whose scale is not 1 is not absorbed, so the simplified
        result remains a SeriesMap in both orders.

        For every case, the simplified mapping must keep nIn/nOut and give
        the same forward results as the unsimplified compound mapping.
        """
        center1 = [2, -1, 0]
        center2 = [-1, 6, 4]
        shift = [3, 7, -9]
        # Each column is one test point. There are 4 rows (axes) because an
        # inverted 3-axis UnitNormMap has 4 inputs; rows are sliced down
        # to each compound map's nIn below.
        testpoints = np.array([
            [1.0, 2.0, -6.0, 30.0, 1.0],
            [3.0, 99.0, -5.0, 21.0, 0.0],
            [-5.0, 3.0, -7.0, 37.0, 0.0],
            [7.0, -23.0, -3.0, 45.0, 0.0],
        ], dtype=float)
        unm1 = ast.UnitNormMap(center1)
        unm1inv = unm1.inverted()
        unm2 = ast.UnitNormMap(center2)
        unm2inv = unm2.inverted()
        shiftmap = ast.ShiftMap(shift)
        # maps [0, shift] -> [1, 1 + shift]: scale 1, i.e. a pure shift
        winmap_unitscale = ast.WinMap(
            np.zeros(3), shift, np.ones(3), np.ones(3) + shift)
        # maps [0, shift] -> [1, 2 + shift]: scale differs from 1
        winmap_notunitscale = ast.WinMap(
            np.zeros(3), shift, np.ones(3), np.ones(3) * 2 + shift)

        if ast.astVersion() >= 9001003:
            expected_map = "ShiftMap"  # ShiftMap is ShiftMap in 9.1.3
        else:
            expected_map = "WinMap"  # ShiftMap gets simplified to WinMap

        for map1, map2, pred_simplified_class_name in (
            (unm1, unm2inv, expected_map),
            (shiftmap, unm1, "UnitNormMap"),
            (winmap_unitscale, unm1, "UnitNormMap"),
            (winmap_notunitscale, unm1, "SeriesMap"),
            (unm1inv, shiftmap, "UnitNormMap"),
            (unm1inv, winmap_unitscale, "UnitNormMap"),
            (unm1inv, winmap_notunitscale, "SeriesMap"),
        ):
            # subTest identifies the failing pair and keeps checking the rest
            with self.subTest(map1=map1.className, map2=map2.className,
                              expected=pred_simplified_class_name):
                cmpmap = map1.then(map2)
                self.assertEqual(map1.nIn, cmpmap.nIn)
                self.assertEqual(map2.nOut, cmpmap.nOut)
                cmpmap_simp = cmpmap.simplified()
                self.assertEqual(cmpmap_simp.className,
                                 pred_simplified_class_name)
                self.assertEqual(cmpmap.nIn, cmpmap_simp.nIn)
                self.assertEqual(cmpmap.nOut, cmpmap_simp.nOut)
                testptview = np.array(testpoints[0:cmpmap.nIn])
                assert_allclose(cmpmap.applyForward(testptview),
                                cmpmap_simp.applyForward(testptview))

115 

116 

if __name__ == "__main__":
    # Executed only when this file is run directly as a script (not when
    # imported by a test runner): hand control to unittest's CLI runner.
    unittest.main()