Coverage for tests/test_cmpMap.py: 13%
93 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-06-13 02:50 -0700
« prev ^ index » next coverage.py v7.5.3, created at 2024-06-13 02:50 -0700
1import unittest
3import numpy as np
4from numpy.testing import assert_allclose
6import astshim as ast
7from astshim.test import MappingTestCase
class TestCmpMap(MappingTestCase):
    """Test compound maps: CmpMap, ParallelMap and SeriesMap.

    Each compound map is built in three equivalent ways: the dedicated
    constructor, the fluent ``then``/``under`` helper, and the generic
    ``ast.CmpMap`` with an explicit series flag. Every variant is checked
    against the same predicted output.
    """

    def setUp(self):
        # Two 2-axis component mappings shared by the compound-map tests.
        self.nin = 2
        self.zoom = 1.3
        self.shift = [-0.5, 1.2]
        self.zoommap = ast.ZoomMap(self.nin, self.zoom)
        self.shiftmap = ast.ShiftMap(self.shift)

    def _checkCompound(self, cmpmap, isSeries):
        """Run the generic simplify, copy and memory checks on ``cmpmap``.

        ``cmpmap`` must have been built from ``self.shiftmap`` followed by
        ``self.zoommap``; ``isSeries`` says whether it combines them in
        series (True) or in parallel (False).
        """
        self.checkBasicSimplify(cmpmap)
        self.checkCopy(cmpmap)
        self.checkMemoryForCompoundObject(self.shiftmap, self.zoommap, cmpmap, isSeries=isSeries)

    def test_SeriesMap(self):
        """A series of shiftmap then zoommap maps x to (x + shift) * zoom."""
        sermap = ast.SeriesMap(self.shiftmap, self.zoommap)
        self.assertEqual(sermap.getRefCount(), 1)
        self._checkCompound(sermap, isSeries=True)

        sermap2 = self.shiftmap.then(self.zoommap)
        self._checkCompound(sermap2, isSeries=True)

        sermap3 = ast.CmpMap(self.shiftmap, self.zoommap, True)
        self._checkCompound(sermap3, isSeries=True)

        # One row per input axis, one column per point; transpose so the
        # per-axis shift broadcasts across points, then transpose back.
        indata = np.array([
            [1.0, 2.0, -6.0, 30.0, 0.2],
            [3.0, 99.9, -5.1, 21.0, 0.0],
        ], dtype=float)
        pred_outdata = ((indata.T + self.shift) * self.zoom).T

        assert_allclose(sermap.applyForward(indata), pred_outdata)
        assert_allclose(sermap2.applyForward(indata), pred_outdata)
        assert_allclose(sermap3.applyForward(indata), pred_outdata)

        self.checkRoundTrip(sermap, indata)
        self.checkRoundTrip(sermap2, indata)
        self.checkRoundTrip(sermap3, indata)

        self.checkMappingPersistence(sermap, indata)
        self.checkMappingPersistence(sermap2, indata)
        self.checkMappingPersistence(sermap3, indata)

    def test_ParallelMap(self):
        """A parallel map shifts axes 0-1 and zooms axes 2-3 independently."""
        parmap = ast.ParallelMap(self.shiftmap, self.zoommap)
        # Adding a mapping to a ParallelMap increases its refcount by 1.
        self.assertEqual(self.shiftmap.getRefCount(), 2)
        self.assertEqual(self.zoommap.getRefCount(), 2)
        self.assertEqual(parmap.nIn, self.nin * 2)
        self.assertEqual(parmap.nOut, self.nin * 2)
        self.assertFalse(parmap.series)
        self._checkCompound(parmap, isSeries=False)

        parmap2 = self.shiftmap.under(self.zoommap)
        self._checkCompound(parmap2, isSeries=False)

        indata = np.array([
            [3.0, 1.0, -6.0],
            [2.2, 3.0, -5.1],
            [-5.6, 2.0, 30.0],
            [0.32, 99.9, 21.0],
        ], dtype=float)
        # Through the .T view each column is one axis, so the in-place edits
        # below shift the first two axes and zoom the last two, writing
        # through the view into pred_outdata.
        pred_outdata = indata.copy()
        pred_outdata.T[:, 0:2] += self.shift
        pred_outdata.T[:, 2:4] *= self.zoom

        assert_allclose(parmap.applyForward(indata), pred_outdata)
        assert_allclose(parmap2.applyForward(indata), pred_outdata)

        parmap3 = ast.CmpMap(self.shiftmap, self.zoommap, False)
        self._checkCompound(parmap3, isSeries=False)
        assert_allclose(parmap3.applyForward(indata), pred_outdata)

        self.checkRoundTrip(parmap, indata)
        self.checkRoundTrip(parmap2, indata)
        self.checkRoundTrip(parmap3, indata)

        self.checkMappingPersistence(parmap, indata)
        self.checkMappingPersistence(parmap2, indata)
        self.checkMappingPersistence(parmap3, indata)

    def test_SeriesMapMatrixShiftSimplify(self):
        """Test that a non-square matrix map followed by a shift map can be
        simplified.

        This is ticket DM-10946
        """
        m1 = 1.0
        m2 = 2.0
        shift = 3.0
        matrixMap = ast.MatrixMap(np.array([[m1, m2]]))
        self.assertEqual(matrixMap.nIn, 2)
        self.assertEqual(matrixMap.nOut, 1)
        shiftMap = ast.ShiftMap([shift])
        seriesMap = matrixMap.then(shiftMap)

        indata = np.array([
            [1.0, 2.0, 3.0],
            [0.0, 1.0, 2.0],
        ], dtype=float)
        # Single output axis, so the prediction is a 1 x nPoints array.
        pred_outdata = (m1 * indata[0] + m2 * indata[1] + shift).reshape(1, -1)

        outdata = seriesMap.applyForward(indata)
        assert_allclose(outdata, pred_outdata)

        # Simplifying must not change the forward transform.
        simplifiedMap = seriesMap.simplified()
        outdata2 = simplifiedMap.applyForward(indata)
        assert_allclose(outdata2, pred_outdata)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()