Coverage for tests/test_frameSet.py: 10%

172 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-11 00:57 -0700

1import unittest 

2 

3import numpy as np 

4from numpy.testing import assert_allclose 

5 

6import astshim as ast 

7from astshim.test import MappingTestCase 

8 

9 

class TestFrameSet(MappingTestCase):
    """Exercise `ast.FrameSet` construction, copying and frame management.

    Several tests compare `getNObject()` before and after an operation.
    Those comparisons are sensitive to the order in which objects are
    created, so statements in those tests must not be reordered.
    """

    def test_FrameSetBasics(self):
        """Check a one-frame FrameSet, then add a frame and a mapping.

        Verifies that the frame set holds its own copies of the frame and
        mapping passed in: renaming the originals must not change what the
        frame set returns.
        """
        frame = ast.Frame(2, "Ident=base")
        initialNumFrames = frame.getNObject()  # may be >1 when run using pytest
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)

        # Make sure the frame is deep copied
        frame.ident = "newIdent"
        self.assertEqual(frameSet.getFrame(frameSet.BASE).ident, "base")
        self.assertEqual(frame.getRefCount(), 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 1)

        # add a new frame and mapping; make sure they are deep copied
        newFrame = ast.Frame(2, "Ident=current")
        mapping = ast.UnitMap(2, "Ident=mapping")
        initialNumUnitMap = mapping.getNObject()
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        frameSet.addFrame(1, mapping, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        newFrame.ident = "newFrameIdent"
        mapping.ident = "newMappingIdent"
        self.assertEqual(frameSet.getFrame(frameSet.CURRENT).ident, "current")
        self.assertEqual(frameSet.getMapping().ident, "mapping")
        self.assertEqual(newFrame.getRefCount(), 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 3)
        self.assertEqual(mapping.getRefCount(), 1)
        self.assertEqual(mapping.getNObject(), initialNumUnitMap + 1)

        # make sure BASE is available on the class and instance
        self.assertEqual(ast.FrameSet.BASE, frameSet.BASE)

        # Each getFrame call below is expected to raise the object count
        # by one, i.e. to hand back a new copy.
        baseframe = frameSet.getFrame(frameSet.BASE)
        self.assertEqual(frame.getNObject(), initialNumFrames + 4)
        self.assertEqual(baseframe.ident, "base")
        self.assertEqual(frameSet.base, 1)
        currframe = frameSet.getFrame(frameSet.CURRENT)
        self.assertEqual(frame.getNObject(), initialNumFrames + 5)
        self.assertEqual(currframe.ident, "current")
        self.assertEqual(frameSet.current, 2)

        self.checkCopy(frameSet)

        input_data = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        self.checkMappingPersistence(frameSet, input_data)

    def testFrameSetFrameMappingFrameConstructor(self):
        """Check the (base frame, mapping, current frame) constructor.

        The resulting frame set must contain two frames, with the base at
        index 1 and the current frame at index 2, and must be unaffected
        by later changes to the objects it was built from.
        """
        baseFrame = ast.Frame(2, "Ident=base")
        mapping = ast.UnitMap(2, "Ident=mapping")
        currFrame = ast.Frame(2, "Ident=current")
        frameSet = ast.FrameSet(baseFrame, mapping, currFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(frameSet.base, 1)
        self.assertEqual(frameSet.current, 2)

        # make sure all objects were deep copied
        baseFrame.ident = "newBase"
        mapping.ident = "newMapping"
        currFrame.ident = "newCurrent"
        self.assertEqual(frameSet.getFrame(frameSet.BASE).ident, "base")
        self.assertEqual(frameSet.getFrame(frameSet.CURRENT).ident, "current")
        self.assertEqual(frameSet.getMapping().ident, "mapping")

    def test_FrameSetGetFrame(self):
        """Check that `getFrame` returns a copy, not a shared reference.

        Modifying the returned frame must change neither the frame stored
        in the frame set nor the frame originally passed to the
        constructor.
        """
        frame = ast.Frame(2, "Ident=base")
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)

        newFrame = ast.Frame(2, "Ident=current")
        frameSet.addFrame(1, ast.UnitMap(2), newFrame)
        self.assertEqual(frameSet.nFrame, 2)

        # check that getFrame returns a deep copy
        baseFrameDeep = frameSet.getFrame(ast.FrameSet.BASE)
        self.assertEqual(baseFrameDeep.ident, "base")
        self.assertEqual(baseFrameDeep.getRefCount(), 1)
        baseFrameDeep.ident = "modifiedBase"
        self.assertEqual(frameSet.getFrame(ast.FrameSet.BASE).ident, "base")
        self.assertEqual(frame.ident, "base")

    def test_FrameSetGetMapping(self):
        """Check that `getMapping` returns a copy, not a shared reference.

        Modifying the returned mapping must not change the mapping that
        was originally added to the frame set.
        """
        frame = ast.Frame(2, "Ident=base")
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)

        newFrame = ast.Frame(2)
        mapping = ast.UnitMap(2, "Ident=mapping")
        initialNumUnitMap = mapping.getNObject()  # may be >1 when run using pytest
        frameSet.addFrame(1, mapping, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(mapping.getNObject(), initialNumUnitMap + 1)

        # check that getMapping returns a deep copy
        mappingDeep = frameSet.getMapping(1, 2)
        self.assertEqual(mappingDeep.ident, "mapping")
        mappingDeep.ident = "modifiedMapping"
        self.assertEqual(mapping.ident, "mapping")
        self.assertEqual(mappingDeep.getRefCount(), 1)
        self.assertEqual(mapping.getNObject(), initialNumUnitMap + 2)

    def test_FrameSetRemoveFrame(self):
        """Check `removeFrame`, including the refusal to remove the last frame.

        After removing frame 1 the remaining frame must be the one that
        was added as "current", and the object counts must drop
        accordingly.
        """
        frame = ast.Frame(2, "Ident=base")
        initialNumFrames = frame.getNObject()  # may be >1 when run using pytest
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 1)

        newFrame = ast.Frame(2, "Ident=current")
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        zoomMap = ast.ZoomMap(2, 0.5, "Ident=zoom")
        initialNumZoomMap = zoomMap.getNObject()
        frameSet.addFrame(1, zoomMap, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(frame.getNObject(), initialNumFrames + 3)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)

        # remove the frame named "base", leaving the frame named "current"
        frameSet.removeFrame(1)
        self.assertEqual(frameSet.nFrame, 1)
        # Removing one frame leaves frame, newFrame and a copy of newFrame in
        # FrameSet.
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap)
        frameDeep = frameSet.getFrame(1)
        self.assertEqual(frameDeep.ident, "current")

        # it is not allowed to remove the last frame
        with self.assertRaises(RuntimeError):
            frameSet.removeFrame(1)

    def test_FrameSetRemapFrame(self):
        """Check that `remapFrame` changes the transformation as predicted.

        The frame set starts with a zoom between base and current frames;
        remapping frame 1 with a shift must yield the shifted-then-zoomed
        output computed in ``predicted_output2``.
        """
        frame = ast.Frame(2, "Ident=base")
        initialNumFrames = frame.getNObject()  # may be >1 when run using pytest
        frameSet = ast.FrameSet(frame)
        self.assertIsInstance(frameSet, ast.FrameSet)
        self.assertEqual(frameSet.nFrame, 1)
        self.assertEqual(frame.getNObject(), initialNumFrames + 1)

        newFrame = ast.Frame(2, "Ident=current")
        self.assertEqual(frame.getNObject(), initialNumFrames + 2)
        zoom = 0.5
        zoomMap = ast.ZoomMap(2, zoom, "Ident=zoom")
        initialNumZoomMap = zoomMap.getNObject()
        frameSet.addFrame(1, zoomMap, newFrame)
        self.assertEqual(frameSet.nFrame, 2)
        self.assertEqual(frame.getNObject(), initialNumFrames + 3)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)

        input_data = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        predicted_output1 = input_data * zoom
        assert_allclose(frameSet.applyForward(input_data), predicted_output1)
        self.checkMappingPersistence(frameSet, input_data)

        shift = (0.5, -1.5)
        shiftMap = ast.ShiftMap(shift, "Ident=shift")
        initialNumShiftMap = shiftMap.getNObject()
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)
        frameSet.remapFrame(1, shiftMap)
        self.assertEqual(zoomMap.getNObject(), initialNumZoomMap + 1)
        self.assertEqual(shiftMap.getNObject(), initialNumShiftMap + 1)
        # input_data has one row per axis, so transpose to subtract the
        # per-axis shift, then transpose back before applying the zoom.
        predicted_output2 = (input_data.T - shift).T * zoom
        assert_allclose(frameSet.applyForward(input_data), predicted_output2)

    def test_FrameSetPermutationSkyFrame(self):
        """Test permuting FrameSet axes using a SkyFrame

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.
        """
        # Sequences are compared with assert_allclose rather than
        # assertAlmostEqual: the latter cannot subtract two lists, so it
        # passes only on exact equality and raises TypeError otherwise.

        # test with arbitrary values that will not be wrapped by SkyFrame
        x = 0.257
        y = 0.832
        frame1 = ast.Frame(2)
        unitMap = ast.UnitMap(2)
        frame2 = ast.SkyFrame()
        frameSet = ast.FrameSet(frame1, unitMap, frame2)
        assert_allclose(frameSet.applyForward([x, y]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y])

        # permuting the axes of the current frame also permutes the mapping
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y]), [y, x])
        assert_allclose(frameSet.applyInverse([x, y]), [y, x])

        # permuting again puts things back
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y])

    def test_FrameSetPermutationUnequal(self):
        """Test that permuting FrameSet axes with nIn != nOut

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.

        Make nIn != nOut in order to test DM-9899
        FrameSet.permAxes would fail if nIn != nOut
        """
        # Sequences are compared with assert_allclose rather than
        # assertAlmostEqual: the latter cannot subtract two lists, so it
        # passes only on exact equality and raises TypeError otherwise.

        # Initial mapping: 3 inputs, 2 outputs: 1-1, 2-2, 3=z
        # Test using arbitrary values for x,y,z
        x = 75.1
        y = -53.2
        z = 0.123
        frame1 = ast.Frame(3)
        permMap = ast.PermMap([1, 2, -1], [1, 2], [z])
        frame2 = ast.Frame(2)
        frameSet = ast.FrameSet(frame1, permMap, frame2)
        assert_allclose(frameSet.applyForward([x, y, z]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y, z])

        # permuting the axes of the current frame also permutes the mapping
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y, z]), [y, x])
        assert_allclose(frameSet.applyInverse([x, y]), [y, x, z])

        # permuting again puts things back
        frameSet.permAxes([2, 1])
        assert_allclose(frameSet.applyForward([x, y, z]), [x, y])
        assert_allclose(frameSet.applyInverse([x, y]), [x, y, z])

242 

243 

# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()