Coverage for python/astshim/test.py: 10%

167 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-14 01:55 -0800

1import pickle 

2import unittest 

3 

4import numpy as np 

5from numpy.testing import assert_allclose, assert_array_equal 

6 

7from .channel import Channel 

8from .fitsChan import FitsChan 

9from .polyMap import PolyMap 

10from .xmlChan import XmlChan 

11from .stream import StringStream 

12 

13 

class ObjectTestCase(unittest.TestCase):
    """Base class for unit tests of astshim objects.

    Provides helper assertions for comparing, copying and persisting
    astshim objects; subclasses call these from their own test methods.
    """

    def assertObjectsIdentical(self, obj1, obj2, checkType=True):
        """Assert that two astshim objects are identical.

        Identical means the objects are of the same class (if checkType)
        and all properties are identical (including whether set or
        defaulted), as judged by comparing ``show()``, ``str`` and ``repr``.

        @param[in] obj1  First object to compare
        @param[in] obj2  Second object to compare
        @param[in] checkType  If True, also require ``type(obj1) is type(obj2)``
        """
        if checkType:
            self.assertIs(type(obj1), type(obj2))
        self.assertEqual(obj1.show(), obj2.show())
        self.assertEqual(str(obj1), str(obj2))
        self.assertEqual(repr(obj1), repr(obj2))

    def checkCopy(self, obj):
        """Check that an astshim object can be deep-copied.

        Both ``obj.copy()`` and the copy constructor ``type(obj)(obj)`` are
        checked. Each copy must be identical to the original, raise the
        object count by exactly one, leave the original's reference count
        unchanged, and be independent (changing the copy's ``ident`` must
        not change the original's). Deleting the copy must restore the
        original object and reference counts.

        @param[in] obj  Object to be checked
        """
        nobj = obj.getNObject()
        nref = obj.getRefCount()

        def copyIter(obj):
            # A generator hands out one copy at a time; the "nobj + 1"
            # checks below rely on only one copy being alive at once.
            yield obj.copy()
            yield type(obj)(obj)

        for cp in copyIter(obj):
            self.assertObjectsIdentical(obj, cp)
            self.assertEqual(obj.getNObject(), nobj + 1)
            # Object.copy makes a new pointer instead of copying the old one,
            # so the reference count of the old one does not increase
            self.assertEqual(obj.getRefCount(), nref)
            self.assertFalse(obj.same(cp))
            self.assertEqual(cp.getNObject(), nobj + 1)
            self.assertEqual(cp.getRefCount(), 1)
            # changing an attribute of the copy does not affect the original
            originalIdent = obj.ident
            cp.ident = obj.ident + " modified"
            self.assertEqual(obj.ident, originalIdent)

            del cp
            self.assertEqual(obj.getNObject(), nobj)
            self.assertEqual(obj.getRefCount(), nref)

    def checkPersistence(self, obj, typeFromChannel=None):
        """Check that an astshim object can be persisted and unpersisted.

        Check persistence using Channel, FitsChan (with native encoding,
        as the only encoding compatible with all AST objects), XmlChan
        and pickle.

        @param[in] obj  Object to be checked
        @param[in] typeFromChannel  Type of object expected to be read from
            a channel (since some thin wrapper types are read as the
            underlying type), or None if the object read back is expected
            to have the same type as ``obj``. When given, the type read
            from each channel is checked against it and the remaining
            comparison ignores type. The pickle round trip is always
            required to preserve the original type.
        """
        for channelType, options in (
            (Channel, ""),
            (FitsChan, "Encoding=Native"),
            (XmlChan, ""),
        ):
            ss = StringStream()
            chan = channelType(ss, options)
            chan.write(obj)
            ss.sinkToSource()
            if channelType is FitsChan:
                # reset the FitsChan's current card before reading back
                # (presumably so reading starts at the first written card)
                chan.clearCard()
            obj_copy = chan.read()
            if typeFromChannel is not None:
                self.assertIs(type(obj_copy), typeFromChannel)
                self.assertObjectsIdentical(obj, obj_copy, checkType=False)
            else:
                self.assertObjectsIdentical(obj, obj_copy)

        obj_copy = pickle.loads(pickle.dumps(obj))
        self.assertObjectsIdentical(obj, obj_copy)

90 

91 

class MappingTestCase(ObjectTestCase):
    """Base class for unit tests of mappings.

    Adds mapping-specific helpers (round trips, simplification,
    persistence and memory accounting of compound objects) on top of
    ObjectTestCase.
    """

    def checkRoundTrip(self, amap, poslist, rtol=1e-05, atol=1e-08):
        """Check that a mapping's reverse transform is the opposite of forward.

        @param[in] amap  Mapping to test
        @param[in] poslist  Input positions for a forward transform: a numpy
            array with shape [nin, num points] or a collection that can be
            cast to one; a 1-d collection is treated as shape [1, num points]
        @param[in] rtol  Relative tolerance for numpy.testing.assert_allclose
        @param[in] atol  Absolute tolerance for numpy.testing.assert_allclose
        """
        poslist = np.array(poslist, dtype=float)
        if poslist.ndim == 1:
            # supplied data was a single list of points
            poslist = poslist.reshape(1, len(poslist))
        # forward with applyForward, inverse with applyInverse
        to_poslist = amap.applyForward(poslist)
        rt_poslist = amap.applyInverse(to_poslist)
        assert_allclose(poslist, rt_poslist, rtol=rtol, atol=atol)

        # forward with applyForward, inverse with inverted().applyForward
        amapinv = amap.inverted()
        rt2_poslist = amapinv.applyForward(to_poslist)
        assert_allclose(poslist, rt2_poslist, rtol=rtol, atol=atol)

        # forward and inverse with a compound map of amap.then(amap.inverted())
        acmp = amap.then(amapinv)
        assert_allclose(poslist, acmp.applyForward(poslist), rtol=rtol, atol=atol)

        # test vector versions of forward and inverse
        posvec = list(poslist.flat)
        to_posvec = amap.applyForward(posvec)
        # np.asarray so `.flat` is available even if a plain list is returned
        assert_allclose(to_posvec, list(np.asarray(to_poslist).flat), rtol=rtol, atol=atol)

        rt_posvec = amap.applyInverse(to_posvec)
        assert_allclose(posvec, rt_posvec, rtol=rtol, atol=atol)

    def checkBasicSimplify(self, amap):
        """Check basic simplification for a reversible mapping.

        Check the following:
        - A compound mapping of amap and its inverse simplifies to
          a UnitMap (in either order).
        - A compound mapping of amap (or its inverse) and a unit map
          simplifies to a mapping of the same class as the simplified
          amap (or simplified inverse).

        @param[in] amap  Mapping to test; must have both transforms
        """
        amapinv = amap.inverted()
        cmp1 = amap.then(amapinv)
        unit1 = cmp1.simplified()
        self.assertEqual(unit1.className, "UnitMap")
        self.assertEqual(amap.nIn, cmp1.nIn)
        self.assertEqual(amap.nIn, cmp1.nOut)
        self.assertEqual(cmp1.nIn, unit1.nIn)
        self.assertEqual(cmp1.nOut, unit1.nOut)

        cmp2 = amapinv.then(amap)
        unit2 = cmp2.simplified()
        self.assertEqual(unit2.className, "UnitMap")
        self.assertEqual(amapinv.nIn, cmp2.nIn)
        self.assertEqual(amapinv.nIn, cmp2.nOut)
        self.assertEqual(cmp2.nIn, unit2.nIn)
        self.assertEqual(cmp2.nOut, unit2.nOut)

        for ma, mb, desmap3 in (
            (unit1, amap, amap),
            (amap, unit2, amap),
            (unit2, amapinv, amapinv),
            (amapinv, unit1, amapinv),
        ):
            cmp3 = ma.then(mb)
            cmp3simp = cmp3.simplified()
            # the unit map should drop out, leaving the expected component
            self.assertEqual(cmp3simp.className, desmap3.simplified().className)
            self.assertEqual(ma.nIn, cmp3.nIn)
            self.assertEqual(mb.nOut, cmp3.nOut)
            self.assertEqual(cmp3.nIn, cmp3simp.nIn)
            self.assertEqual(cmp3.nOut, cmp3simp.nOut)

    def checkMappingPersistence(self, amap, poslist):
        """Check that a mapping gives identical answers to unpersisted copies.

        Copies are made with Channel, FitsChan (native encoding), XmlChan
        and pickle.

        @param[in] amap  Mapping to test
        @param[in] poslist  Input positions for a forward transform
            (if it exists), or the inverse transform (if not).
            A numpy array with shape [nAxes, num points]
            or collection that can be cast to same

        Checks each direction, if present. However, for generality,
        does not check that the two directions are inverses of each other;
        call checkRoundTrip for that.

        Does everything checkPersistence does, so no need to call both.

        @throw RuntimeError if amap has neither a forward nor an inverse
            transform
        """
        for channelType, options in (
            (Channel, ""),
            (FitsChan, "Encoding=Native"),
            (XmlChan, ""),
        ):
            ss = StringStream()
            # use the channel type under test (previously always Channel)
            chan = channelType(ss, options)
            chan.write(amap)
            ss.sinkToSource()
            if channelType is FitsChan:
                # reset the FitsChan's current card before reading back,
                # as checkPersistence does
                chan.clearCard()
            self._checkMappingCopy(amap, chan.read(), poslist)

        self._checkMappingCopy(amap, pickle.loads(pickle.dumps(amap)), poslist)

    def _checkMappingCopy(self, amap, amap_copy, poslist):
        """Assert that amap_copy describes and evaluates exactly like amap.

        Class names are compared rather than Python types because some
        thin wrapper types are read back from a channel as the underlying
        type.
        """
        self.assertEqual(amap.className, amap_copy.className)
        self.assertEqual(amap.show(), amap_copy.show())
        self.assertEqual(str(amap), str(amap_copy))
        self.assertEqual(repr(amap), repr(amap_copy))

        if amap.hasForward:
            outPoslist = amap.applyForward(poslist)
            assert_array_equal(outPoslist, amap_copy.applyForward(poslist))

            if amap.hasInverse:
                assert_array_equal(amap.applyInverse(outPoslist),
                                   amap_copy.applyInverse(outPoslist))

        elif amap.hasInverse:
            assert_array_equal(amap.applyInverse(poslist),
                               amap_copy.applyInverse(poslist))

        else:
            raise RuntimeError("mapping has neither forward nor inverse transform")

    def checkMemoryForCompoundObject(self, obj1, obj2, cmpObj, isSeries):
        """Check the memory usage for a compound object.

        @param[in] obj1  First object in compound object
        @param[in] obj2  Second object in compound object
        @param[in] cmpObj  Compound object (SeriesMap, ParallelMap, CmpMap
            or CmpFrame)
        @param[in] isSeries  Is the compound object in series? None to not
            test (e.g. CmpFrame)
        """
        # if obj1 and obj2 are the same type then copying the compound object
        # will increase the NObject of each by 2, otherwise 1
        deltaObj = 2 if type(obj1) is type(obj2) else 1

        initialNumObj1 = obj1.getNObject()
        initialNumObj2 = obj2.getNObject()
        initialNumCmpObj = cmpObj.getNObject()
        initialRefCountObj1 = obj1.getRefCount()
        initialRefCountObj2 = obj2.getRefCount()
        initialRefCountCmpObj = cmpObj.getRefCount()
        if isSeries is True:
            self.assertTrue(cmpObj.series)
        elif isSeries is False:
            self.assertFalse(cmpObj.series)

        # Making a deep copy should increase the object count of the contained
        # objects but should not affect the reference count.
        cp = cmpObj.copy()
        self.assertEqual(cmpObj.getRefCount(), initialRefCountCmpObj)
        self.assertEqual(cmpObj.getNObject(), initialNumCmpObj + 1)
        self.assertEqual(obj1.getRefCount(), initialRefCountObj1)
        self.assertEqual(obj2.getRefCount(), initialRefCountObj2)
        self.assertEqual(obj1.getNObject(), initialNumObj1 + deltaObj)
        self.assertEqual(obj2.getNObject(), initialNumObj2 + deltaObj)

        # deleting the deep copy should restore ref count and nobject
        del cp
        self.assertEqual(cmpObj.getRefCount(), initialRefCountCmpObj)
        self.assertEqual(cmpObj.getNObject(), initialNumCmpObj)
        self.assertEqual(obj1.getRefCount(), initialRefCountObj1)
        self.assertEqual(obj1.getNObject(), initialNumObj1)
        self.assertEqual(obj2.getRefCount(), initialRefCountObj2)
        self.assertEqual(obj2.getNObject(), initialNumObj2)

264 

265 

def makePolyMapCoeffs(nIn, nOut):
    """Build a PolyMap coefficient array for a sum of squared inputs.

    Each output j (0 <= j < nOut) is

        fj(x) = C0j x0^2 + C1j x1^2 + ... + CNj xN^2,  N = nIn - 1,

    with Cij = 0.001 (i+j+1). Every row has the layout
    ``[coefficient, 1-based output index, power of x0, ..., power of xN]``,
    and exactly one power in each row is 2. Rows are grouped by output,
    then ordered by input.
    """
    step = 0.001
    rows = [
        [step * (i + 1) + step * j, j + 1]
        + [2 if axis == i else 0 for axis in range(nIn)]
        for j in range(nOut)
        for i in range(nIn)
    ]
    return np.array(rows, dtype=float)

285 

286 

def makeTwoWayPolyMap(nIn, nOut):
    """Make an astshim.PolyMap with both transforms, suitable for testing.

    Forward transform:
        fj(x) = C0j x0^2 + C1j x1^2 + C2j x2^2 + ... + CNj xN^2
        where Cij = 0.001 (i+j+1)

    The reverse transform uses the same formula with the roles of i and j
    swapped. It is therefore NOT the inverse of the forward transform,
    merely something easy to evaluate.

    Why this equation:
    - It is well defined for any positive value of nIn, nOut.
    - It stays small for small x, to avoid wraparound of angles for
      SpherePoint endpoints.
    """
    polyMap = PolyMap(makePolyMapCoeffs(nIn, nOut), makePolyMapCoeffs(nOut, nIn))
    assert (polyMap.nIn, polyMap.nOut) == (nIn, nOut)
    assert polyMap.hasForward and polyMap.hasInverse
    return polyMap

311 

312 

def makeForwardPolyMap(nIn, nOut):
    """Make a forward-only astshim.PolyMap, suitable for testing.

    The forward transform matches `makeTwoWayPolyMap`. No reverse
    coefficients are supplied and "IterInverse=0" is passed, so the
    resulting map has no inverse transform.

    Why this equation:
    - It is well defined for any positive value of nIn, nOut.
    - It stays small for small x, to avoid wraparound of angles for
      SpherePoint endpoints.
    """
    polyMap = PolyMap(makePolyMapCoeffs(nIn, nOut), nOut, "IterInverse=0")
    assert (polyMap.nIn, polyMap.nOut) == (nIn, nOut)
    assert polyMap.hasForward and not polyMap.hasInverse
    return polyMap