Coverage for tests/test_frameDict.py: 8%

269 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-07-09 05:46 -0700

1import unittest 

2 

3import numpy as np 

4from numpy.testing import assert_allclose 

5 

6import astshim as ast 

7from astshim.test import MappingTestCase 

8from astshim.detail.testUtils import makeFrameDict 

9 

10 

class TestFrameDict(MappingTestCase):
    """Tests for `ast.FrameDict`: construction, lookup of frames by index
    and by (case-blind) domain name, and editing operations such as
    addFrame, removeFrame, remapFrame, setBase/setCurrent and setDomain.
    """

    def setUp(self):
        self.frame1 = ast.Frame(2, "Domain=frame1, Ident=f1")
        self.frame2 = ast.Frame(2, "Domain=frame2, Ident=f2")
        self.zoom = 1.5
        self.zoomMap = ast.ZoomMap(2, self.zoom, "Ident=zoomMap")
        # Baseline object counts; tests assert on changes relative to these.
        self.initialNumFrames = self.frame1.getNObject()  # may be >2 when run using pytest
        self.initialNumZoomMap = self.zoomMap.getNObject()  # may be > 1 when run using pytest

    def checkDict(self, frameDict):
        """Check that every frame in ``frameDict`` round-trips between its
        (1-based) index and its domain name.

        Parameters
        ----------
        frameDict : `ast.FrameDict`
            The FrameDict to check.
        """
        for index in range(1, frameDict.nFrame + 1):
            domain = frameDict.getFrame(index).domain
            self.assertEqual(frameDict.getIndex(domain), index)
            self.assertEqual(frameDict.getFrame(domain).domain, domain)

    def test_FrameDictOneFrameConstructor(self):
        frameDict = ast.FrameDict(self.frame1)
        self.assertIsInstance(frameDict, ast.FrameDict)
        self.assertEqual(frameDict.nFrame, 1)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1"})
        self.assertEqual(frameDict.getIndex("frame1"), 1)  # should be case blind

        with self.assertRaises(IndexError):
            frameDict.getIndex("missingDomain")
        with self.assertRaises(IndexError):
            frameDict.getIndex("")

        # Make sure the frame is deep copied
        self.frame1.domain = "NEWDOMAIN"
        self.assertEqual(frameDict.getFrame("FRAME1").domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(self.frame1.getRefCount(), 1)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 1)

        # make sure BASE and CURRENT are available on the class and instance
        self.assertEqual(ast.FrameDict.BASE, frameDict.BASE)
        self.assertEqual(ast.FrameDict.CURRENT, frameDict.CURRENT)

        self.checkCopy(frameDict)

        indata = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        self.checkMappingPersistence(frameDict, indata)
        self.checkPersistence(frameDict, typeFromChannel=ast.FrameSet)
        self.checkDict(frameDict)

    def test_FrameDictFrameSetConstructor(self):
        frameSet = ast.FrameSet(self.frame1, self.zoomMap, self.frame2)
        frameDict = ast.FrameDict(frameSet)
        indata = np.array([[1.1, 2.1, 3.1], [1.2, 2.2, 3.2]])
        predictedOut = indata * self.zoom
        assert_allclose(frameDict.applyForward(indata), predictedOut)
        assert_allclose(frameDict.applyInverse(predictedOut), indata)
        # The index <-> domain lookup must also work when built from a FrameSet
        self.checkDict(frameDict)

        frameDict2 = makeFrameDict(frameSet)
        self.assertEqual(frameDict2.getRefCount(), 1)

    def test_FrameDictAddFrame(self):
        frameDict = ast.FrameDict(self.frame1)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 1)
        frameDict.addFrame(1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.nFrame, 2)
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.assertEqual(self.frame2.getRefCount(), 1)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 2)

        # make sure all objects were deep copied
        self.frame1.domain = "newBase"
        self.zoomMap.ident = "newMapping"
        self.frame2.domain = "newCurrent"
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getMapping().ident, "zoomMap")
        self.checkPersistence(frameDict, typeFromChannel=ast.FrameSet)
        self.checkDict(frameDict)

        # make sure we can't add a frame with a duplicate domain name
        # and that attempting to do so leaves the FrameDict unchanged
        duplicateFrame = ast.Frame(2, "Domain=FRAME1, Ident=duplicate")
        with self.assertRaises(ValueError):
            frameDict.addFrame(1, self.zoomMap, duplicateFrame)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.assertEqual(frameDict.getFrame("FRAME1").ident, "f1")
        self.checkDict(frameDict)

    def test_FrameDictFrameMappingFrameConstructor(self):
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.nFrame, 2)
        self.assertEqual(frameDict.base, 1)
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.current, 2)
        self.assertEqual(frameDict.getIndex("frame2"), 2)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 2)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)

        # make sure all objects were deep copied
        self.frame1.domain = "newBase"
        self.zoomMap.ident = "newMapping"
        self.frame2.domain = "newCurrent"
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getMapping().ident, "zoomMap")
        self.checkPersistence(frameDict, typeFromChannel=ast.FrameSet)
        self.checkDict(frameDict)

    def test_FrameDictGetMapping(self):
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)

        # make sure the zoomMap in frameDict is a deep copy of self.zoomMap
        self.zoomMap.ident = "newMappingIdent"
        zoomMapList = (  # all should be the same
            frameDict.getMapping(frameDict.BASE, frameDict.CURRENT),
            frameDict.getMapping("FRAME1", "FRAME2"),
            frameDict.getMapping(frameDict.BASE, "frame2"),
            frameDict.getMapping("frame1", frameDict.CURRENT),
        )
        for zoomMap in zoomMapList:
            self.assertEqual(zoomMap.ident, "zoomMap")
        self.assertEqual(self.zoomMap.getRefCount(), 1)

        # make sure the zoomMapList are retrieved in the right direction
        indata = np.array([[1.1, 2.1, 3.1], [1.2, 2.2, 3.2]])
        predictedOut = indata * self.zoom
        for zoomMap in zoomMapList:
            assert_allclose(zoomMap.applyForward(indata), predictedOut)

        # check that getMapping returns a deep copy
        for i, zoomMap in enumerate(zoomMapList):
            zoomMap.ident = "newIdent%s" % (i,)
            self.assertEqual(zoomMap.getRefCount(), 1)
            self.assertEqual(frameDict.getMapping().ident, "zoomMap")
        # 5 = 1 in frameDict plus 4 retrieved copies in zoomMapList
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 5)
        self.checkDict(frameDict)

        # try to get invalid frames by name and index; test all combinations
        # of the "from" and "to" index being valid or invalid
        indexIsValidList = (
            (1, True),
            (3, False),
            ("Frame1", True),
            ("BadFrame", False),
            ("", False),
        )
        for fromIndex, fromValid in indexIsValidList:
            for toIndex, toValid in indexIsValidList:
                # subTest identifies which combination failed
                with self.subTest(fromIndex=fromIndex, toIndex=toIndex):
                    if fromValid and toValid:
                        mapping = frameDict.getMapping(fromIndex, toIndex)
                        self.assertIsInstance(mapping, ast.Mapping)
                    else:
                        with self.assertRaises((IndexError, RuntimeError)):
                            frameDict.getMapping(fromIndex, toIndex)

        # make sure the errors did not mess up the FrameDict
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.checkDict(frameDict)

    def test_FrameDictRemoveFrame(self):
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        zoomMap2 = ast.ZoomMap(2, 1.3, "Ident=zoomMap2")
        frame3 = ast.Frame(2, "Domain=FRAME3, Ident=f3")
        frameDict.addFrame(2, zoomMap2, frame3)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2", "FRAME3"})
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)
        self.assertEqual(frameDict.getIndex("FRAME3"), 3)
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 4)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 3)

        # remove the frame named "FRAME1" by name
        # this will also remove one of the two zoom maps
        frameDict.removeFrame("FRAME1")
        self.checkDict(frameDict)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME2", "FRAME3"})
        self.assertEqual(frameDict.nFrame, 2)
        self.assertEqual(frameDict.getIndex("FRAME2"), 1)
        self.assertEqual(frameDict.getIndex("FRAME3"), 2)
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")
        self.assertEqual(frameDict.getFrame("FRAME3").domain, "FRAME3")
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 3)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 2)

        # remove the frame "FRAME3" by index
        # this will also remove the remaining zoom map
        frameDict.removeFrame(2)
        self.checkDict(frameDict)
        self.assertEqual(frameDict.getAllDomains(), {"FRAME2"})
        self.assertEqual(frameDict.nFrame, 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 1)
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")
        self.assertEqual(self.frame1.getNObject(), self.initialNumFrames + 2)
        self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)
        frameDeep = frameDict.getFrame(1)
        self.assertEqual(frameDeep.domain, "FRAME2")

        # it is not allowed to remove the last frame
        with self.assertRaises(RuntimeError):
            frameDict.removeFrame(1)

        self.checkDict(frameDict)

    def test_FrameDictGetFrameAndGetIndex(self):
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.getIndex("frame1"), 1)
        self.assertEqual(frameDict.getFrame(1).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.BASE).domain, "FRAME1")
        self.assertEqual(frameDict.getFrame("FRAME1").domain, "FRAME1")

        self.assertEqual(frameDict.getIndex("frame2"), 2)
        self.assertEqual(frameDict.getFrame(2).domain, "FRAME2")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "FRAME2")
        self.assertEqual(frameDict.getFrame("FRAME2").domain, "FRAME2")

        # test on invalid indices
        for badDomain in ("badName", ""):
            with self.subTest(badDomain=badDomain):
                with self.assertRaises(IndexError):
                    frameDict.getFrame(badDomain)
                with self.assertRaises(IndexError):
                    frameDict.getIndex(badDomain)
        with self.assertRaises(RuntimeError):
            frameDict.getFrame(3)

        # make sure the errors did not mess up the FrameDict
        self.assertEqual(frameDict.getAllDomains(), {"FRAME1", "FRAME2"})
        self.checkDict(frameDict)

    def test_FrameDictRemapFrame(self):
        for useDomainForRemapFrame in (False, True):
            with self.subTest(useDomainForRemapFrame=useDomainForRemapFrame):
                frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)

                indata = np.array([
                    [0.0, 0.1, -1.5],
                    [5.1, 0.0, 3.1],
                ])
                predictedOut1 = indata * self.zoom
                assert_allclose(frameDict.applyForward(indata), predictedOut1)
                assert_allclose(frameDict.applyInverse(predictedOut1), indata)
                self.checkMappingPersistence(frameDict, indata)

                shift = (0.5, -1.5)
                shiftMap = ast.ShiftMap(shift, "Ident=shift")
                initialNumShiftMap = shiftMap.getNObject()
                self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)
                if useDomainForRemapFrame:
                    frameDict.remapFrame("FRAME1", shiftMap)
                else:
                    frameDict.remapFrame(1, shiftMap)
                self.assertEqual(self.zoomMap.getNObject(), self.initialNumZoomMap + 1)
                self.assertEqual(shiftMap.getNObject(), initialNumShiftMap + 1)
                # indata is (nAxes, nPoints); transpose so the per-axis shift broadcasts
                predictedOut2 = (indata.T - shift).T * self.zoom
                assert_allclose(frameDict.applyForward(indata), predictedOut2)
                assert_allclose(frameDict.applyInverse(predictedOut2), indata)

    def test_FrameDictPermutationSkyFrame(self):
        """Test permuting FrameDict axes using a SkyFrame

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.
        """
        # test with arbitrary values that will not be wrapped by SkyFrame
        x = 0.257
        y = 0.832
        frame1 = ast.Frame(2)
        unitMap = ast.UnitMap(2)
        frame2 = ast.SkyFrame()
        frameDict = ast.FrameDict(frame1, unitMap, frame2)
        # assert_allclose, not assertAlmostEqual: the latter is for scalars and
        # raises TypeError on sequences unless they happen to be exactly equal
        assert_allclose(frameDict.applyForward([x, y]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y])

        # permuting the axes of the current frame also permutes the mapping
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y]), [y, x])
        assert_allclose(frameDict.applyInverse([x, y]), [y, x])

        # permuting again puts things back
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y])

    def test_FrameDictPermutationUnequal(self):
        """Test permuting FrameDict axes with nIn != nOut

        Permuting the axes of the current frame of a frame set
        *in situ* (by calling `permAxes` on the frame set itself)
        should update the connected mappings.

        Make nIn != nOut in order to test DM-9899
        FrameDict.permAxes would fail if nIn != nOut
        """
        # Initial mapping: 3 inputs, 2 outputs: 1-1, 2-2, 3=z
        # Test using arbitrary values for x,y,z
        x = 75.1
        y = -53.2
        z = 0.123
        frame1 = ast.Frame(3)
        permMap = ast.PermMap([1, 2, -1], [1, 2], [z])
        frame2 = ast.Frame(2)
        frameDict = ast.FrameDict(frame1, permMap, frame2)
        # compare sequences with assert_allclose (assertAlmostEqual is scalar-only)
        assert_allclose(frameDict.applyForward([x, y, z]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y, z])

        # permuting the axes of the current frame also permutes the mapping
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y, z]), [y, x])
        assert_allclose(frameDict.applyInverse([x, y]), [y, x, z])

        # permuting again puts things back
        frameDict.permAxes([2, 1])
        assert_allclose(frameDict.applyForward([x, y, z]), [x, y])
        assert_allclose(frameDict.applyInverse([x, y]), [x, y, z])

    def test_FrameDictSetBaseCurrent(self):
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        self.assertEqual(frameDict.base, 1)
        self.assertEqual(frameDict.current, 2)
        self.assertEqual(frameDict.getIndex("frame1"), 1)
        self.assertEqual(frameDict.getIndex("frame2"), 2)

        indata = np.array([
            [0.0, 0.1, -1.5],
            [5.1, 0.0, 3.1],
        ])
        # base -> current is FRAME1 -> FRAME2: the zoom
        predictedOut1 = indata * self.zoom
        assert_allclose(frameDict.applyForward(indata), predictedOut1)

        # base and current are both FRAME1: identity
        frameDict.setCurrent("FRAME1")
        self.assertEqual(frameDict.base, 1)
        self.assertEqual(frameDict.current, 1)
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)

        predictedOut2 = indata
        assert_allclose(frameDict.applyForward(indata), predictedOut2)

        # base -> current is FRAME2 -> FRAME1: the inverse zoom
        frameDict.setBase("FRAME2")
        self.assertEqual(frameDict.base, 2)
        self.assertEqual(frameDict.current, 1)
        self.assertEqual(frameDict.getIndex("FRAME1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)

        predictedOut3 = indata / self.zoom
        assert_allclose(frameDict.applyForward(indata), predictedOut3)

    def test_FrameDictSetDomain(self):
        frameDict = ast.FrameDict(self.frame1, self.zoomMap, self.frame2)
        frameDict.setCurrent("FRAME1")
        frameDict.setDomain("NEWFRAME1")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "FRAME2"})
        self.assertEqual(frameDict.getIndex("newFrame1"), 1)
        self.assertEqual(frameDict.getIndex("FRAME2"), 2)

        frameDict.setCurrent("FRAME2")
        frameDict.setDomain("NEWFRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "NEWFRAME2"})
        self.assertEqual(frameDict.getIndex("NEWFRAME1"), 1)
        self.assertEqual(frameDict.getIndex("NEWFRAME2"), 2)

        # Renaming a domain to itself should have no effect
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        frameDict.setDomain("NEWFRAME2")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "NEWFRAME2"})

        # Make sure setDomain cannot be used to rename a domain to a duplicate
        # and that this leaves the frameDict unchanged
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        with self.assertRaises(ValueError):
            frameDict.setDomain("NEWFRAME1")
        self.assertEqual(frameDict.getFrame(frameDict.CURRENT).domain, "NEWFRAME2")
        self.assertEqual(frameDict.getAllDomains(), {"NEWFRAME1", "NEWFRAME2"})

388 

389 

if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()