Coverage for tests/test_cmpMap.py: 13%

93 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-19 10:51 +0000

1import unittest 

2 

3import numpy as np 

4from numpy.testing import assert_allclose 

5 

6import astshim as ast 

7from astshim.test import MappingTestCase 

8 

9 

class TestCmpMap(MappingTestCase):
    """Exercise the compound mappings: CmpMap, ParallelMap and SeriesMap."""

    def setUp(self):
        """Create the zoom and shift mappings that every test combines."""
        self.nin = 2
        self.zoom = 1.3
        self.shift = [-0.5, 1.2]
        self.zoommap = ast.ZoomMap(self.nin, self.zoom)
        self.shiftmap = ast.ShiftMap(self.shift)

    def _checkCompoundBasics(self, compound, isSeries):
        """Run the simplify, copy and memory checks on one compound mapping.

        Parameters
        ----------
        compound
            A compound mapping built from ``self.shiftmap`` and
            ``self.zoommap``.
        isSeries : `bool`
            Forwarded to ``checkMemoryForCompoundObject`` as ``isSeries``.
        """
        self.checkBasicSimplify(compound)
        self.checkCopy(compound)
        self.checkMemoryForCompoundObject(
            self.shiftmap, self.zoommap, compound, isSeries=isSeries
        )

    def test_SeriesMap(self):
        """Check three ways of chaining the shift and the zoom in series.

        The mappings come from ``ast.SeriesMap``, from ``Mapping.then`` and
        from ``ast.CmpMap`` with a true series flag; each must shift the
        input first and then zoom it.
        """
        viaSeriesMap = ast.SeriesMap(self.shiftmap, self.zoommap)
        self.assertEqual(viaSeriesMap.getRefCount(), 1)
        self._checkCompoundBasics(viaSeriesMap, isSeries=True)

        viaThen = self.shiftmap.then(self.zoommap)
        self._checkCompoundBasics(viaThen, isSeries=True)

        viaCmpMap = ast.CmpMap(self.shiftmap, self.zoommap, True)
        self._checkCompoundBasics(viaCmpMap, isSeries=True)

        points = np.array(
            [
                [1.0, 2.0, -6.0, 30.0, 0.2],
                [3.0, 99.9, -5.1, 21.0, 0.0],
            ],
            dtype=float,
        )
        # One row per axis: transpose so the per-axis shift broadcasts,
        # then transpose back after zooming.
        shifted = points.T + self.shift
        expected = (shifted * self.zoom).T

        seriesMaps = (viaSeriesMap, viaThen, viaCmpMap)
        for mapping in seriesMaps:
            assert_allclose(mapping.applyForward(points), expected)

        for mapping in seriesMaps:
            self.checkRoundTrip(mapping, points)

        for mapping in seriesMaps:
            self.checkMappingPersistence(mapping, points)

    def test_ParallelMap(self):
        """Check three ways of running the shift and the zoom in parallel.

        The mappings come from ``ast.ParallelMap``, from ``Mapping.under``
        and from ``ast.CmpMap`` with a false series flag; each must shift
        the first two axes and zoom the last two.
        """
        viaParallelMap = ast.ParallelMap(self.shiftmap, self.zoommap)
        # Each component placed in the ParallelMap should now report one
        # extra reference.
        self.assertEqual(self.shiftmap.getRefCount(), 2)
        self.assertEqual(self.zoommap.getRefCount(), 2)
        self.assertEqual(viaParallelMap.nIn, self.nin * 2)
        self.assertEqual(viaParallelMap.nOut, self.nin * 2)
        self.assertFalse(viaParallelMap.series)
        self._checkCompoundBasics(viaParallelMap, isSeries=False)

        viaUnder = self.shiftmap.under(self.zoommap)
        self._checkCompoundBasics(viaUnder, isSeries=False)

        points = np.array(
            [
                [3.0, 1.0, -6.0],
                [2.2, 3.0, -5.1],
                [-5.6, 2.0, 30.0],
                [0.32, 99.9, 21.0],
            ],
            dtype=float,
        )
        expected = points.copy()
        # Work through a transposed view so the edits land in ``expected``.
        expectedByPoint = expected.T
        expectedByPoint[:, 0:2] += self.shift
        expectedByPoint[:, 2:4] *= self.zoom

        for mapping in (viaParallelMap, viaUnder):
            assert_allclose(mapping.applyForward(points), expected)

        viaCmpMap = ast.CmpMap(self.shiftmap, self.zoommap, False)
        self._checkCompoundBasics(viaCmpMap, isSeries=False)
        assert_allclose(viaCmpMap.applyForward(points), expected)

        parallelMaps = (viaParallelMap, viaUnder, viaCmpMap)
        for mapping in parallelMaps:
            self.checkRoundTrip(mapping, points)

        for mapping in parallelMaps:
            self.checkMappingPersistence(mapping, points)

    def test_SeriesMapMatrixShiftSimplify(self):
        """Check that a non-square matrix map followed by a shift map
        gives the same output before and after simplification.

        This is ticket DM-10946
        """
        coeff0 = 1.0
        coeff1 = 2.0
        offset = 3.0
        matrixMap = ast.MatrixMap(np.array([[coeff0, coeff1]]))
        self.assertEqual(matrixMap.nIn, 2)
        self.assertEqual(matrixMap.nOut, 1)
        offsetMap = ast.ShiftMap([offset])
        chained = matrixMap.then(offsetMap)

        points = np.array(
            [
                [1.0, 2.0, 3.0],
                [0.0, 1.0, 2.0],
            ],
            dtype=float,
        )
        # A single output axis, so the expected data is one row.
        combined = coeff0 * points[0] + coeff1 * points[1] + offset
        expected = combined.reshape(1, -1)

        assert_allclose(chained.applyForward(points), expected)

        simplifiedMap = chained.simplified()
        assert_allclose(simplifiedMap.applyForward(points), expected)

140 

141 

# Run the test suite when this file is executed directly as a script.
# Coverage note: during the measured run this condition was never true, so
# the branch into unittest.main() was not taken.
if __name__ == "__main__":
    unittest.main()