Coverage for python/astshim/test.py: 9%

163 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-21 03:01 -0700

1import pickle 

2import unittest 

3 

4import numpy as np 

5from numpy.testing import assert_allclose, assert_array_equal 

6 

7from ._astshimLib import Channel, FitsChan, PolyMap, XmlChan, StringStream 

8 

9 

class ObjectTestCase(unittest.TestCase):
    """Base class for unit tests of astshim objects.

    Supplies assertion helpers for comparing objects and for checking that
    objects survive deep copies and persistence round trips.
    """

    def assertObjectsIdentical(self, obj1, obj2, checkType=True):
        """Assert that two astshim objects are identical.

        Identical means the objects are of the same class (if ``checkType``)
        and all properties are identical (including whether set or
        defaulted), as judged by comparing ``show()``, ``str`` and ``repr``.
        """
        if checkType:
            self.assertIs(type(obj1), type(obj2))
        # compare the renderings in a fixed order: show(), then str, then repr
        for render in (lambda o: o.show(), str, repr):
            self.assertEqual(render(obj1), render(obj2))

    def checkCopy(self, obj):
        """Check that an astshim object can be deep-copied.

        Both ``obj.copy()`` and the copy constructor ``type(obj)(obj)`` are
        exercised. Each copy must be identical to ``obj`` yet independent of
        it: a distinct object with a reference count of 1, whose attributes
        can be changed without affecting the original. Deleting the copy must
        restore the original object count and reference count.
        """
        startCount = obj.getNObject()
        startRefs = obj.getRefCount()

        def makeCopies(src):
            # A generator (not a list) so that only one copy exists at a time;
            # the object-count assertions below rely on that.
            yield src.copy()
            yield type(src)(src)

        for dup in makeCopies(obj):
            self.assertObjectsIdentical(obj, dup)
            self.assertEqual(obj.getNObject(), startCount + 1)
            # Object.copy makes a new pointer instead of copying the old one,
            # so the reference count of the original does not increase
            self.assertEqual(obj.getRefCount(), startRefs)
            self.assertFalse(obj.same(dup))
            self.assertEqual(dup.getNObject(), startCount + 1)
            self.assertEqual(dup.getRefCount(), 1)

            # modifying the copy must leave the original untouched
            identBefore = obj.ident
            dup.ident = identBefore + " modified"
            self.assertEqual(obj.ident, identBefore)

            # deleting the copy should restore the counts
            del dup
            self.assertEqual(obj.getNObject(), startCount)
            self.assertEqual(obj.getRefCount(), startRefs)

    def checkPersistence(self, obj, typeFromChannel=None):
        """Check that an astshim object can be persisted and unpersisted.

        Round trips are tested through Channel, FitsChan (with native
        encoding, the only encoding compatible with all AST objects),
        XmlChan and pickle.

        @param[in] obj  Object to be checked
        @param[in] typeFromChannel  Type of object expected to be read from
                        a channel (since some thin wrapper types are read
                        back as the underlying type); None if it is the
                        same as the type of ``obj``
        """
        channelSpecs = [(Channel, ""), (FitsChan, "Encoding=Native"), (XmlChan, "")]
        for channelType, options in channelSpecs:
            stream = StringStream()
            chan = channelType(stream, options)
            chan.write(obj)
            stream.sinkToSource()
            if channelType is FitsChan:
                chan.clearCard()
            restored = chan.read()
            if typeFromChannel is None:
                self.assertObjectsIdentical(obj, restored)
            else:
                self.assertIs(type(restored), typeFromChannel)
                self.assertObjectsIdentical(obj, restored, checkType=False)

        # pickle is always checked with the default checkType=True
        self.assertObjectsIdentical(obj, pickle.loads(pickle.dumps(obj)))

86 

87 

class MappingTestCase(ObjectTestCase):
    """Base class for unit tests of mappings
    """

    def checkRoundTrip(self, amap, poslist, rtol=1e-05, atol=1e-08):
        """Check that a mapping's reverse transform is the opposite of forward

        amap is the mapping to test
        poslist is a list of input positions for a forward transform;
            a numpy array with shape [nin, num points]
            or collection that can be cast to same
            (a 1-D collection is treated as shape [1, num points])
        rtol is the relative tolerance for numpy.testing.assert_allclose
        atol is the absolute tolerance for numpy.testing.assert_allclose

        The round trip is checked four ways: applyForward then applyInverse;
        applyForward then inverted().applyForward; the compound mapping
        amap.then(amap.inverted()); and the flat-vector forms of
        applyForward/applyInverse.
        """
        poslist = np.array(poslist, dtype=float)
        if poslist.ndim == 1:
            # supplied data was a single list of points
            poslist = poslist.reshape(1, len(poslist))
        # forward with applyForward, inverse with applyInverse
        to_poslist = amap.applyForward(poslist)
        rt_poslist = amap.applyInverse(to_poslist)
        assert_allclose(poslist, rt_poslist, rtol=rtol, atol=atol)

        # forward with applyForward, inverse with inverted().applyForward
        amapinv = amap.inverted()
        rt2_poslist = amapinv.applyForward(to_poslist)
        assert_allclose(poslist, rt2_poslist, rtol=rtol, atol=atol)

        # forward and inverse with a compound map of amap.then(amap.inverted())
        acmp = amap.then(amapinv)
        assert_allclose(poslist, acmp.applyForward(poslist), rtol=rtol, atol=atol)

        # test vector versions of forward and inverse; poslist is always 2-D
        # here, so flatten the array results in the same order as posvec
        posvec = list(poslist.flat)
        to_posvec = amap.applyForward(posvec)
        assert_allclose(to_posvec, list(to_poslist.flat), rtol=rtol, atol=atol)

        rt_posvec = amap.applyInverse(to_posvec)
        assert_allclose(posvec, rt_posvec, rtol=rtol, atol=atol)

    def checkBasicSimplify(self, amap):
        """Check basic simplification for a reversible mapping

        Check the following:
        - A compound mapping of a mapping and its inverse simplifies to
          a UnitMap.
        - A compound mapping of a mapping and a UnitMap simplifies to the
          same class as the simplified mapping itself; this is checked for
          both amap and its inverse.
        """
        amapinv = amap.inverted()
        cmp1 = amap.then(amapinv)
        unit1 = cmp1.simplified()
        self.assertEqual(unit1.className, "UnitMap")
        self.assertEqual(amap.nIn, cmp1.nIn)
        self.assertEqual(amap.nIn, cmp1.nOut)
        self.assertEqual(cmp1.nIn, unit1.nIn)
        self.assertEqual(cmp1.nOut, unit1.nOut)

        cmp2 = amapinv.then(amap)
        unit2 = cmp2.simplified()
        self.assertEqual(unit2.className, "UnitMap")
        self.assertEqual(amapinv.nIn, cmp2.nIn)
        self.assertEqual(amapinv.nIn, cmp2.nOut)
        self.assertEqual(cmp2.nIn, unit2.nIn)
        self.assertEqual(cmp2.nOut, unit2.nOut)

        for ma, mb, desmap3 in (
            (unit1, amap, amap),
            (amap, unit2, amap),
            (unit2, amapinv, amapinv),
            (amapinv, unit1, amapinv),
        ):
            cmp3 = ma.then(mb)
            cmp3simp = cmp3.simplified()
            # compare against the expected mapping for this case (amap or its
            # inverse), not always against amap
            self.assertEqual(cmp3simp.className, desmap3.simplified().className)
            self.assertEqual(ma.nIn, cmp3.nIn)
            self.assertEqual(mb.nOut, cmp3.nOut)
            self.assertEqual(cmp3.nIn, cmp3simp.nIn)
            self.assertEqual(cmp3.nOut, cmp3simp.nOut)

    def _roundTripCopies(self, obj):
        """Yield copies of obj that have been persisted and unpersisted

        Uses Channel, FitsChan (native encoding) and XmlChan, in that order,
        followed by a pickle round trip.
        """
        for channelType, options in (
            (Channel, ""),
            (FitsChan, "Encoding=Native"),
            (XmlChan, ""),
        ):
            ss = StringStream()
            chan = channelType(ss, options)
            chan.write(obj)
            ss.sinkToSource()
            if channelType is FitsChan:
                # clear the current card so reading starts from the first card
                chan.clearCard()
            yield chan.read()
        yield pickle.loads(pickle.dumps(obj))

    def checkMappingPersistence(self, amap, poslist):
        """Check that a mapping gives identical answers to unpersisted copy

        poslist is a list of input position for a forward transform
            (if it exists), or the inverse transform (if not).
            A numpy array with shape [nAxes, num points]
            or collection that can be cast to same

        Copies are made through Channel, FitsChan (native encoding), XmlChan
        and pickle. Each copy must have the same className, show(), str and
        repr as amap and must give identical results in each direction that
        is present. However, for generality, this does not check that the two
        directions are inverses of each other; call checkRoundTrip for that.

        Raises RuntimeError if amap has neither a forward nor an inverse
        transform.
        """
        if not (amap.hasForward or amap.hasInverse):
            raise RuntimeError("mapping has neither forward nor inverse transform")

        for amap_copy in self._roundTripCopies(amap):
            self.assertEqual(amap.className, amap_copy.className)
            self.assertEqual(amap.show(), amap_copy.show())
            self.assertEqual(str(amap), str(amap_copy))
            self.assertEqual(repr(amap), repr(amap_copy))

            if amap.hasForward:
                outPoslist = amap.applyForward(poslist)
                assert_array_equal(outPoslist, amap_copy.applyForward(poslist))

                if amap.hasInverse:
                    assert_array_equal(amap.applyInverse(outPoslist),
                                       amap_copy.applyInverse(outPoslist))
            else:
                assert_array_equal(amap.applyInverse(poslist),
                                   amap_copy.applyInverse(poslist))

    def checkMemoryForCompoundObject(self, obj1, obj2, cmpObj, isSeries):
        """Check the memory usage for a compound object

        obj1: first object in compound object
        obj2: second object in compound object
        cmpObj: compound object (SeriesMap, ParallelMap, CmpMap or CmpFrame)
        isSeries: is compound object in series? None to not test
            (e.g. CmpFrame)
        """
        # if obj1 and obj2 are the same type then copying the compound object
        # will increase the NObject of each by 2, otherwise 1
        deltaObj = 2 if type(obj1) is type(obj2) else 1

        initialNumObj1 = obj1.getNObject()
        initialNumObj2 = obj2.getNObject()
        initialNumCmpObj = cmpObj.getNObject()
        initialRefCountObj1 = obj1.getRefCount()
        initialRefCountObj2 = obj2.getRefCount()
        initialRefCountCmpObj = cmpObj.getRefCount()
        if isSeries is True:
            self.assertTrue(cmpObj.series)
        elif isSeries is False:
            self.assertFalse(cmpObj.series)

        # Making a deep copy should increase the object count of the contained
        # objects but should not affect the reference count.
        cp = cmpObj.copy()
        self.assertEqual(cmpObj.getRefCount(), initialRefCountCmpObj)
        self.assertEqual(cmpObj.getNObject(), initialNumCmpObj + 1)
        self.assertEqual(obj1.getRefCount(), initialRefCountObj1)
        self.assertEqual(obj2.getRefCount(), initialRefCountObj2)
        self.assertEqual(obj1.getNObject(), initialNumObj1 + deltaObj)
        self.assertEqual(obj2.getNObject(), initialNumObj2 + deltaObj)

        # deleting the deep copy should restore ref count and nobject
        del cp
        self.assertEqual(cmpObj.getRefCount(), initialRefCountCmpObj)
        self.assertEqual(cmpObj.getNObject(), initialNumCmpObj)
        self.assertEqual(obj1.getRefCount(), initialRefCountObj1)
        self.assertEqual(obj1.getNObject(), initialNumObj1)
        self.assertEqual(obj2.getRefCount(), initialRefCountObj2)
        self.assertEqual(obj2.getNObject(), initialNumObj2)

260 

261 

def makePolyMapCoeffs(nIn, nOut):
    """Make an array of coefficients for astshim.PolyMap for the following
    equation:

        fj(x) = C0j x0^2 + C1j x1^2 + C2j x2^2 + ... + CNj xN^2
    where:
    * i ranges from 0 to N=nIn-1
    * j ranges from 0 to nOut-1,
    * Cij = 0.001 (i+j+1)

    Each row of the returned float array is
    ``[Cij, j + 1, p0, p1, ..., pN]``, where the power ``pk`` is 2 for
    ``k == i`` and 0 otherwise. There is one row per (j, i) pair; rows are
    ordered by output index j first, then by input index i.
    """
    base = 0.001
    rows = [
        [base * (i + 1) + base * j, j + 1] + [2 if k == i else 0 for k in range(nIn)]
        for j in range(nOut)
        for i in range(nIn)
    ]
    return np.array(rows, dtype=float)

281 

282 

def makeTwoWayPolyMap(nIn, nOut):
    """Make an astshim.PolyMap suitable for testing, with both directions

    The forward transform is as follows:
        fj(x) = C0j x0^2 + C1j x1^2 + C2j x2^2 + ...
                + CNj xN^2 where Cij = 0.001 (i+j+1)

    The reverse transform is the same equation with i and j reversed.
    It is therefore NOT the inverse of the forward direction, but it is
    something that can be evaluated easily.

    The equation is chosen for the following reasons:
    - It is well defined for any positive value of nIn, nOut.
    - It stays small for small x, to avoid wraparound of angles for
      SpherePoint endpoints.

    The result is sanity-checked with ``assert``: it must have the requested
    nIn/nOut and both a forward and an inverse transform.
    """
    polyMap = PolyMap(makePolyMapCoeffs(nIn, nOut), makePolyMapCoeffs(nOut, nIn))
    assert (polyMap.nIn, polyMap.nOut) == (nIn, nOut)
    assert polyMap.hasForward and polyMap.hasInverse
    return polyMap

307 

308 

def makeForwardPolyMap(nIn, nOut):
    """Make a forward-only astshim.PolyMap suitable for testing

    The forward transform is the same as for `makeTwoWayPolyMap`.
    No reverse coefficients are supplied, and the option "IterInverse=0" is
    passed; the result is asserted to have no inverse transform.

    The equation is chosen for the following reasons:
    - It is well defined for any positive value of nIn, nOut.
    - It stays small for small x, to avoid wraparound of angles for
      SpherePoint endpoints.
    """
    polyMap = PolyMap(makePolyMapCoeffs(nIn, nOut), nOut, "IterInverse=0")
    assert (polyMap.nIn, polyMap.nOut) == (nIn, nOut)
    assert polyMap.hasForward and not polyMap.hasInverse
    return polyMap