Coverage for tests/test_frameDict.py: 7%
269 statements
coverage.py v7.5.3, created at 2024-06-13 02:50 -0700
1import unittest
3import numpy as np
4from numpy.testing import assert_allclose
6import astshim as ast
7from astshim.test import MappingTestCase
8from astshim.detail.testUtils import makeFrameDict
class TestFrameDict(MappingTestCase):
    """Unit tests for `ast.FrameDict`.

    These tests exercise construction (from one frame, from a FrameSet, and
    from frame/mapping/frame), lookup of frames by index and by
    case-insensitive domain name, adding/removing/remapping frames, renaming
    domains, base/current selection, axis permutation, and persistence.
    Live-object and reference counts are checked to confirm that the
    FrameDict deep-copies the frames and mappings it is given.
    """

    def setUp(self):
        """Create two 2-axis frames and a zoom map used by the tests."""
        self.frame1 = ast.Frame(2, "Domain=frame1, Ident=f1")
        self.frame2 = ast.Frame(2, "Domain=frame2, Ident=f2")
        self.zoom = 1.5
        self.zoomMap = ast.ZoomMap(2, self.zoom, "Ident=zoomMap")
        # Baseline live-object counts. Tests compare against these baselines
        # (rather than fixed numbers) to verify deep copies, because objects
        # from other tests may still be alive.
        self.initialNumFrames = self.frame1.getNObject()  # may be >2 when run using pytest
        self.initialNumZoomMap = self.zoomMap.getNObject()  # may be > 1 when run using pytest

    def checkDict(self, frameDict):
        """Check that domain <-> index lookups agree for every frame.

        Parameters
        ----------
        frameDict : `ast.FrameDict`
            FrameDict to check. For each 1-based frame index, the domain of
            that frame must map back to the same index, and looking the frame
            up by domain must return a frame with that domain.
        """
        for index in range(1, frameDict.nFrame + 1):
            domain = frameDict.getFrame(index).domain
            self.assertEqual(frameDict.getIndex(domain), index)
            self.assertEqual(frameDict.getFrame(domain).domain, domain)

    def test_FrameDictOneFrameConstructor(self):
        """Construct a FrameDict from a single frame."""
        frameDict = ast.FrameDict(self.frame1)
        self.assertIsInstance(frameDict, ast.FrameDict)
        self.assertEqual(frameDict.nFrame, 1)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1"})
        self.assertEqual(frameDict.getIndex("frame1"), 1)  # should be case blind

        with self.assertRaises(IndexError):
            frameDict.getIndex("missingDomain")
        with self.assertRaises(IndexError):
            frameDict.getIndex("")

        # Make sure the frame is deep copied
        self.frame1.domain = "NEWDOMAIN"
        self.assertEqual(frameDict.getFrame("FRAME1").domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(self.frame1.getRefCount(), 1)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 1)

        # make sure BASE and CURRENT are available on the class and instance
        self.assertEqual(ast.FrameDict.BASE, frameDict.BASE)
        self.assertEqual(ast.FrameDict.CURRENT, frameDict.CURRENT)

        self.checkCopy(frameDict)

        indata = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        self.checkMappingPersistence(frameDict, indata)
        self.checkPersistence(frameDict, typeFromChannel=ast.FrameSet)
        self.checkDict(frameDict)

    def test_FrameDictFrameSetConstructor(self):
        """Construct a FrameDict from an existing FrameSet."""
        frameSet = ast.FrameSet(self.frame1, self.zoomMap, self.frame2)
        frameDict = ast.FrameDict(frameSet)
        indata = np.array([[1.1, 2.1, 3.1], [1.2, 2.2, 3.2]])
        predictedOut = indata * self.zoom
        assert_allclose(frameDict.applyForward(indata), predictedOut)
        assert_allclose(frameDict.applyInverse(predictedOut), indata)

        frameDict2 = makeFrameDict(frameSet)
        self.assertEqual(frameDict2.getRefCount(), 1)

    def test_FrameDictAddFrame(self):
        """Add a frame, check deep copying, and reject duplicate domains."""
        frameDict = ast.FrameDict(self.frame1)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 1)
        frameDict.addFrame(1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.nFrame, 2)
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.assertEqual(self.frame2.getRefCount(), 1)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 2)

        # make sure all objects were deep copied
        self.frame1.domain = "newBase"
        self.zoomMap.ident = "newMapping"
        self.frame2.domain = "newCurrent"
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getMapping().ident, "zoomMap")
        self.checkPersistence(frameDict, typeFromChannel=ast.FrameSet)
        self.checkDict(frameDict)

        # make sure we can't add a frame with a duplicate domain name
        # and that attempting to do so leaves the FrameDict unchanged
        duplicateFrame = ast.Frame(2, "Domain=FRAME1, Ident=duplicate")
        with self.assertRaises(ValueError):
            frameDict.addFrame(1, self.zoomMap, duplicateFrame)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.assertEqual(frameDict.getFrame("FRAME1").ident, "f1")
        self.checkDict(frameDict)

    def test_FrameDictFrameMappingFrameConstructor(self):
        """Construct a FrameDict from frame, mapping, frame."""
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.nFrame, 2)
        self.assertEqual(frameDict.base, 1)
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.current, 2)
        self.assertEqual(frameDict.getIndex("frame2"), 2)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 2)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)

        # make sure all objects were deep copied
        self.frame1.domain = "newBase"
        self.zoomMap.ident = "newMapping"
        self.frame2.domain = "newCurrent"
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getMapping().ident, "zoomMap")
        self.checkPersistence(frameDict, typeFromChannel=ast.FrameSet)
        self.checkDict(frameDict)

    def test_FrameDictGetMapping(self):
        """Retrieve mappings by index and domain, valid and invalid."""
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)

        # make sure the zoomMap in frameDict is a deep copy of self.zoomMap
        self.zoomMap.ident = "newMappingIdent"
        zoomMapList = (  # all should be the same
            frameDict.getMapping(frameDict.BASE, frameDict.CURRENT),
            frameDict.getMapping("FRAME1", "FRAME2"),
            frameDict.getMapping(frameDict.BASE, "frame2"),
            frameDict.getMapping("frame1", frameDict.CURRENT),
        )
        for zoomMap in zoomMapList:
            self.assertEqual(zoomMap.ident, "zoomMap")
        # self.zoomMap is not touched by the loop above, so check it once
        self.assertEqual(self.zoomMap.getRefCount(), 1)

        # make sure the zoomMapList are retrieved in the right direction
        indata = np.array([[1.1, 2.1, 3.1], [1.2, 2.2, 3.2]])
        predictedOut = indata * self.zoom
        for zoomMap in zoomMapList:
            assert_allclose(zoomMap.applyForward(indata), predictedOut)

        # check that getMapping returns a deep copy
        for i, zoomMap in enumerate(zoomMapList):
            zoomMap.ident = "newIdent%s" % (i,)
            self.assertEqual(zoomMap.getRefCount(), 1)
            self.assertEqual(frameDict.getMapping().ident, "zoomMap")
        # 5 = 1 in frameDict plus 4 retrieved copies in zoomMapList
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 5)
        self.checkDict(frameDict)

        # try to get invalid frames by name and index; test all combinations
        # of the "from" and "to" index being valid or invalid
        indexIsValidList = (
            (1, True),
            (3, False),
            ("Frame1", True),
            ("BadFrame", False),
            ("", False),
        )
        for fromIndex, fromValid in indexIsValidList:
            for toIndex, toValid in indexIsValidList:
                # subTest reports exactly which (from, to) pair failed
                with self.subTest(fromIndex=fromIndex, toIndex=toIndex):
                    if fromValid and toValid:
                        mapping = frameDict.getMapping(fromIndex, toIndex)
                        self.assertIsInstance(mapping, ast.Mapping)
                    else:
                        with self.assertRaises((IndexError, RuntimeError)):
                            frameDict.getMapping(fromIndex, toIndex)

        # make sure the errors did not mess up the FrameDict
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.checkDict(frameDict)

    def test_FrameDictRemoveFrame(self):
        """Remove frames by domain and by index; the last frame cannot go."""
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        zoomMap2 = ast.ZoomMap(2, 1.3, "Ident=zoomMap2")
        frame3 = ast.Frame(2, "Domain=FRAME3, Ident=f3")
        frameDict.addFrame(2, zoomMap2, frame3)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2", "FRAME3"})
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)
        self.assertEqual(frameDict.getIndex("FRAME3"), 3)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 4)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 3)

        # remove the frame named "FRAME1" by name
        # this will also remove one of the two zoom maps
        frameDict.removeFrame("FRAME1")
        self.checkDict(frameDict)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME2", "FRAME3"})
        self.assertEqual(frameDict.nFrame, 2)
        self.assertEqual(frameDict.getIndex("FRAME2"), 1)
        self.assertEqual(frameDict.getIndex("FRAME3"), 2)
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")
        self.assertEqual(frameDict.getFrame("FRAME3").domain, "FRAME3")
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 3)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 2)

        # remove the frame "FRAME3" by index
        # this will also remove the remaining zoom map
        frameDict.removeFrame(2)
        self.checkDict(frameDict)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME2"})
        self.assertEqual(frameDict.nFrame, 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 1)
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 2)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)
        frameDeep = frameDict.getFrame(1)
        self.assertEqual(frameDeep.domain, "FRAME2")

        # it is not allowed to remove the last frame
        with self.assertRaises(RuntimeError):
            frameDict.removeFrame(1)

        self.checkDict(frameDict)

    def test_FrameDictGetFrameAndGetIndex(self):
        """Look up frames and indices by index, BASE/CURRENT, and domain."""
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.getIndex("frame1"), 1)
        self.assertEqual(frameDict.getFrame(1).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame("FRAME1").domain, "FRAME1")

        self.assertEqual(frameDict.getIndex("frame2"), 2)
        self.assertEqual(frameDict.getFrame(2).domain, "FRAME2")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")

        # test on invalid indices
        for badDomain in ("badName", ""):
            with self.subTest(badDomain=badDomain):
                with self.assertRaises(IndexError):
                    frameDict.getFrame(badDomain)
                with self.assertRaises(IndexError):
                    frameDict.getIndex(badDomain)
        with self.assertRaises(RuntimeError):
            frameDict.getFrame(3)

        # make sure the errors did not mess up the FrameDict
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.checkDict(frameDict)

    def test_FrameDictRemapFrame(self):
        """Remap the base frame by index and by domain with a ShiftMap."""
        for useDomainForRemapFrame in (False, True):
            with self.subTest(useDomainForRemapFrame=useDomainForRemapFrame):
                frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)

                indata = np.array([
                    [0.0, 0.1, -1.5],
                    [5.1, 0.0, 3.1],
                ])
                predictedOut1 = indata * self.zoom
                assert_allclose(frameDict.applyForward(indata), predictedOut1)
                assert_allclose(frameDict.applyInverse(predictedOut1), indata)
                self.checkMappingPersistence(frameDict, indata)

                shift = (0.5, -1.5)
                shiftMap = ast.ShiftMap(shift, "Ident=shift")
                initialNumShiftMap = shiftMap.getNObject()
                self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)
                if useDomainForRemapFrame:
                    frameDict.remapFrame("FRAME1", shiftMap)
                else:
                    frameDict.remapFrame(1, shiftMap)
                self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)
                self.assertEqual(shiftMap.getNObject(), initialNumShiftMap + 1)
                predictedOut2 = (indata.T - shift).T * self.zoom
                assert_allclose(frameDict.applyForward(indata), predictedOut2)
                assert_allclose(frameDict.applyInverse(predictedOut2), indata)

    def test_FrameDictPermutationSkyFrame(self):
        """Test permuting FrameDict axes using a SkyFrame

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.
        """
        # test with arbitrary values that will not be wrapped by SkyFrame
        x = 0.257
        y = 0.832
        frame1 = ast.Frame(2)
        unitMap = ast.UnitMap(2)
        frame2 = ast.SkyFrame()
        frameDict = ast.FrameDict(frame1, unitMap, frame2)
        # assert_allclose rather than assertAlmostEqual: the latter does not
        # support sequences (it only passes on exact equality and raises
        # TypeError otherwise)
        assert_allclose(frameDict.applyForward([x, y]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y])

        # permuting the axes of the current frame also permutes the mapping
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y]), [y, x])
        assert_allclose(frameDict.applyInverse([x, y]), [y, x])

        # permuting again puts things back
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y])

    def test_FrameDictPermutationUnequal(self):
        """Test permuting FrameDict axes with nIn != nOut

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.

        Make nIn != nOut in order to test DM-9899
        FrameDict.permAxes would fail if nIn != nOut
        """
        # Initial mapping: 3 inputs, 2 outputs: 1-1, 2-2, 3=z
        # Test using arbitrary values for x,y,z
        x = 75.1
        y = -53.2
        z = 0.123
        frame1 = ast.Frame(3)
        permMap = ast.PermMap([1, 2, -1], [1, 2], [z])
        frame2 = ast.Frame(2)
        frameDict = ast.FrameDict(frame1, permMap, frame2)
        # assert_allclose rather than assertAlmostEqual, which does not
        # support sequences
        assert_allclose(frameDict.applyForward([x, y, z]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y, z])

        # permuting the axes of the current frame also permutes the mapping
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y, z]), [y, x])
        assert_allclose(frameDict.applyInverse([x, y]), [y, x, z])

        # permuting again puts things back
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y, z]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y, z])

    def test_FrameDictSetBaseCurrent(self):
        """Select base and current frames by domain and check the mapping."""
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.base, 1)
        self.assertEqual(frameDict.current, 2)
        self.assertEqual(frameDict.getIndex("frame1"), 1)
        self.assertEqual(frameDict.getIndex("frame2"), 2)

        indata = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        # multiplication already yields a new array; no copy needed
        predictedOut1 = indata * self.zoom
        assert_allclose(frameDict.applyForward(indata), predictedOut1)

        # base == current: the mapping is the identity
        frameDict.setCurrent("FRAME1")
        self.assertEqual(frameDict.base, 1)
        self.assertEqual(frameDict.current, 1)
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)

        predictedOut2 = indata.copy()
        assert_allclose(frameDict.applyForward(indata), predictedOut2)

        # base is now FRAME2, so forward goes FRAME2 -> FRAME1 (inverse zoom)
        frameDict.setBase("FRAME2")
        self.assertEqual(frameDict.base, 2)
        self.assertEqual(frameDict.current, 1)
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)

        predictedOut3 = indata / self.zoom
        assert_allclose(frameDict.applyForward(indata), predictedOut3)

    def test_FrameDictSetDomain(self):
        """Rename domains of the current frame; reject duplicates."""
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        frameDict.setCurrent("FRAME1")
        frameDict.setDomain("NEWFRAME1")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "FRAME2"})
        self.assertEqual(frameDict.getIndex("newFrame1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)

        frameDict.setCurrent("FRAME2")
        frameDict.setDomain("NEWFRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "NEWFRAME2"})
        self.assertEqual(frameDict.getIndex("NEWFRAME1"), 1)
        self.assertEqual(frameDict.getIndex("NEWFRAME2"), 2)

        # Renaming a domain to itself should have no effect
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        frameDict.setDomain("NEWFRAME2")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "NEWFRAME2"})

        # Make sure setDomain cannot be used to rename a domain to a duplicate
        # and that this leaves the frameDict unchanged
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        with self.assertRaises(ValueError):
            frameDict.setDomain("NEWFRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "NEWFRAME2"})
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()