Coverage for tests/test_frameSet.py: 9%
172 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-08 21:54 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-08 21:54 -0800
1import unittest
3import numpy as np
4from numpy.testing import assert_allclose
6import astshim as ast
7from astshim.test import MappingTestCase
class TestFrameSet(MappingTestCase):
    """Tests for ast.FrameSet construction, frame/mapping access and editing."""

    def test_FrameSetBasics(self):
        """Test a FrameSet built from one Frame and then extended by addFrame.

        Checks that the Frames and Mappings handed to the FrameSet are deep
        copied (changing the originals does not affect the FrameSet), that the
        live-object counts change as expected, and that the BASE and CURRENT
        frames can be retrieved.
        """
        frame = ast.Frame(2, "Ident=base")
        initialNumFrames = frame.getNObject()  # may be >1 when run using pytest
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)

        # Make sure the frame is deep copied
        frame.ident = "newIdent"
        self.assertEqual(frameSet.getFrame(frameSet.BASE).ident, "base")
        self.assertEqual(frame.getRefCount(), 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 1)

        # add a new frame and mapping; make sure they are deep copied
        newFrame = ast.Frame(2, "Ident=current")
        mapping = ast.UnitMap(2, "Ident=mapping")
        initialNumUnitMap = mapping.getNObject()
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        frameSet.addFrame(1, mapping, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        newFrame.ident = "newFrameIdent"
        mapping.ident = "newMappingIdent"
        self.assertEqual(frameSet.getFrame(frameSet.CURRENT).ident, "current")
        self.assertEqual(frameSet.getMapping().ident, "mapping")
        self.assertEqual(newFrame.getRefCount(), 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 3)
        self.assertEqual(mapping.getRefCount(), 1)
        self.assertEqual(mapping.getNObject(), initialNumUnitMap + 1)

        # make sure BASE is available on the class and instance
        self.assertEqual(ast.FrameSet.BASE, frameSet.BASE)

        # Holding on to the result of getFrame keeps one more Frame alive,
        # so the Frame count goes up by one for each retained result.
        baseframe = frameSet.getFrame(frameSet.BASE)
        self.assertEqual(frame.getNObject(), initialNumFrames + 4)
        self.assertEqual(baseframe.ident, "base")
        self.assertEqual(frameSet.base, 1)
        currframe = frameSet.getFrame(frameSet.CURRENT)
        self.assertEqual(frame.getNObject(), initialNumFrames + 5)
        self.assertEqual(currframe.ident, "current")
        self.assertEqual(frameSet.current, 2)

        self.checkCopy(frameSet)

        input_data = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        self.checkMappingPersistence(frameSet, input_data)

    def testFrameSetFrameMappingFrameConstructor(self):
        """Test the FrameSet(baseFrame, mapping, currFrame) constructor.

        The resulting FrameSet has two frames (base=1, current=2), and all
        three constructor arguments are deep copied.
        """
        baseFrame = ast.Frame(2, "Ident=base")
        mapping = ast.UnitMap(2, "Ident=mapping")
        currFrame = ast.Frame(2, "Ident=current")
        frameSet = ast.FrameSet(baseFrame, mapping, currFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(frameSet.base, 1)
        self.assertEqual(frameSet.current, 2)

        # make sure all objects were deep copied
        baseFrame.ident = "newBase"
        mapping.ident = "newMapping"
        currFrame.ident = "newCurrent"
        self.assertEqual(frameSet.getFrame(frameSet.BASE).ident, "base")
        self.assertEqual(frameSet.getFrame(frameSet.CURRENT).ident, "current")
        self.assertEqual(frameSet.getMapping().ident, "mapping")

    def test_FrameSetGetFrame(self):
        """Test that getFrame returns an independent deep copy of a frame."""
        frame = ast.Frame(2, "Ident=base")
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)

        newFrame = ast.Frame(2, "Ident=current")
        frameSet.addFrame(1, ast.UnitMap(2), newFrame)
        self.assertEqual(frameSet.nFrame, 2)

        # check that getFrame returns a deep copy
        baseFrameDeep = frameSet.getFrame(ast.FrameSet.BASE)
        self.assertEqual(baseFrameDeep.ident, "base")
        self.assertEqual(baseFrameDeep.getRefCount(), 1)
        baseFrameDeep.ident = "modifiedBase"
        self.assertEqual(frameSet.getFrame(ast.FrameSet.BASE).ident, "base")
        self.assertEqual(frame.ident, "base")

    def test_FrameSetGetMapping(self):
        """Test that getMapping returns an independent deep copy of a mapping."""
        frame = ast.Frame(2, "Ident=base")
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)

        newFrame = ast.Frame(2)
        mapping = ast.UnitMap(2, "Ident=mapping")
        initialNumUnitMap = mapping.getNObject()  # may be >1 when run using pytest
        frameSet.addFrame(1, mapping, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(mapping.getNObject(), initialNumUnitMap + 1)

        # check that getMapping returns a deep copy
        mappingDeep = frameSet.getMapping(1, 2)
        self.assertEqual(mappingDeep.ident, "mapping")
        mappingDeep.ident = "modifiedMapping"
        self.assertEqual(mapping.ident, "mapping")
        self.assertEqual(mappingDeep.getRefCount(), 1)
        self.assertEqual(mapping.getNObject(), initialNumUnitMap + 2)

    def test_FrameSetRemoveFrame(self):
        """Test removeFrame, including that the last frame cannot be removed."""
        frame = ast.Frame(2, "Ident=base")
        initialNumFrames = frame.getNObject()  # may be >1 when run using pytest
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 1)

        newFrame = ast.Frame(2, "Ident=current")
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        zoomMap = ast.ZoomMap(2, 0.5, "Ident=zoom")
        initialNumZoomMap = zoomMap.getNObject()
        frameSet.addFrame(1, zoomMap, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(frame.getNObject(), initialNumFrames + 3)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)

        # remove the frame named "base", leaving the frame named "current"
        frameSet.removeFrame(1)
        self.assertEqual(frameSet.nFrame, 1)
        # Removing one frame leaves frame, newFrame and a copy of newFrame in
        # the FrameSet; the FrameSet's copy of zoomMap is released as well.
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap)
        frameDeep = frameSet.getFrame(1)
        self.assertEqual(frameDeep.ident, "current")

        # it is not allowed to remove the last frame
        with self.assertRaises(RuntimeError):
            frameSet.removeFrame(1)

    def test_FrameSetRemapFrame(self):
        """Test remapFrame by remapping the base frame with a ShiftMap."""
        frame = ast.Frame(2, "Ident=base")
        initialNumFrames = frame.getNObject()  # may be >1 when run using pytest
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 1)

        newFrame = ast.Frame(2, "Ident=current")
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        zoom = 0.5
        zoomMap = ast.ZoomMap(2, zoom, "Ident=zoom")
        initialNumZoomMap = zoomMap.getNObject()
        frameSet.addFrame(1, zoomMap, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(frame.getNObject(), initialNumFrames + 3)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)

        # one row per axis, one column per point
        input_data = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        predicted_output1 = input_data * zoom
        assert_allclose(frameSet.applyForward(input_data), predicted_output1)
        self.checkMappingPersistence(frameSet, input_data)

        shift = (0.5, -1.5)
        shiftMap = ast.ShiftMap(shift, "Ident=shift")
        initialNumShiftMap = shiftMap.getNObject()
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)
        frameSet.remapFrame(1, shiftMap)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)
        self.assertEqual(shiftMap.getNObject(), initialNumShiftMap + 1)
        # After remapping the base frame by shiftMap, the expected forward
        # result is (input - shift) * zoom; the transposes broadcast the
        # per-axis shift across every point (column).
        predicted_output2 = (input_data.T - shift).T * zoom
        assert_allclose(frameSet.applyForward(input_data), predicted_output2)

    def test_FrameSetPermutationSkyFrame(self):
        """Test permuting FrameSet axes using a SkyFrame

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.
        """
        # test with arbitrary values that will not be wrapped by SkyFrame
        x = 0.257
        y = 0.832
        frame1 = ast.Frame(2)
        unitMap = ast.UnitMap(2)
        frame2 = ast.SkyFrame()
        frameSet = ast.FrameSet(frame1, unitMap, frame2)
        # assert_allclose (not assertAlmostEqual) because the values are
        # sequences: assertAlmostEqual cannot subtract lists, so it only
        # passes on exact equality and raises TypeError otherwise.
        assert_allclose(frameSet.applyForward([x, y]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y])

        # permuting the axes of the current frame also permutes the mapping
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y]), [y, x])
        assert_allclose(frameSet.applyInverse([x, y]), [y, x])

        # permuting again puts things back
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y])

    def test_FrameSetPermutationUnequal(self):
        """Test permuting FrameSet axes when nIn != nOut

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.

        nIn != nOut is used as a regression test for DM-9899, in which
        FrameSet.permAxes failed if nIn != nOut.
        """
        # Initial mapping: 3 inputs, 2 outputs: 1-1, 2-2, 3=z
        # Test using arbitrary values for x,y,z
        x = 75.1
        y = -53.2
        z = 0.123
        frame1 = ast.Frame(3)
        permMap = ast.PermMap([1, 2, -1], [1, 2], [z])
        frame2 = ast.Frame(2)
        frameSet = ast.FrameSet(frame1, permMap, frame2)
        # assert_allclose (not assertAlmostEqual) because the values are
        # sequences: assertAlmostEqual cannot subtract lists.
        assert_allclose(frameSet.applyForward([x, y, z]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y, z])

        # permuting the axes of the current frame also permutes the mapping
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y, z]), [y, x])
        assert_allclose(frameSet.applyInverse([x, y]), [y, x, z])

        # permuting again puts things back
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y, z]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y, z])
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()