Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

import pickle 

import unittest 

 

import numpy as np 

from numpy.testing import assert_allclose, assert_array_equal 

 

from .channel import Channel 

from .fitsChan import FitsChan 

from .polyMap import PolyMap 

from .xmlChan import XmlChan 

from .stream import StringStream 

 

 

class ObjectTestCase(unittest.TestCase):
    """Base class for unit tests of astshim objects
    """

    def assertObjectsIdentical(self, obj1, obj2, checkType=True):
        """Assert that two astshim objects are identical.

        Identical means that, when ``checkType`` is true, both objects have
        exactly the same Python type, and that their ``show()``, ``str()``
        and ``repr()`` outputs all match (so all properties are identical,
        including whether each is set or defaulted).
        """
        if checkType:
            self.assertIs(type(obj1), type(obj2))
        # compare the three textual renderings, in order: show, str, repr
        for render in (lambda o: o.show(), str, repr):
            self.assertEqual(render(obj1), render(obj2))

    def checkCopy(self, obj):
        """Check that an astshim object can be deep-copied

        Both ``obj.copy()`` and the copy constructor ``type(obj)(obj)`` are
        tried, one at a time. Each copy must:
        - be identical to ``obj`` but not the same AST object;
        - raise the object count by one without changing the reference
          count of ``obj``, and have a reference count of 1 itself;
        - be independent: changing its ``ident`` leaves ``obj.ident`` alone.
        Deleting the copy must restore the original counts.
        """
        startNObject = obj.getNObject()
        startRefCount = obj.getRefCount()

        copiers = (lambda o: o.copy(), type(obj))
        for copier in copiers:
            cp = copier(obj)
            self.assertObjectsIdentical(obj, cp)
            self.assertEqual(obj.getNObject(), startNObject + 1)
            # Object.copy makes a new pointer instead of copying the old one,
            # so the reference count of the original does not increase
            self.assertEqual(obj.getRefCount(), startRefCount)
            self.assertFalse(obj.same(cp))
            self.assertEqual(cp.getNObject(), startNObject + 1)
            self.assertEqual(cp.getRefCount(), 1)

            # modifying the copy must leave the original untouched
            identBefore = obj.ident
            cp.ident = obj.ident + " modified"
            self.assertEqual(obj.ident, identBefore)

            # dropping the only reference to the copy restores the counts
            del cp
            self.assertEqual(obj.getNObject(), startNObject)
            self.assertEqual(obj.getRefCount(), startRefCount)

    def checkPersistence(self, obj, typeFromChannel=None):
        """Check that an astshim object can be persisted and unpersisted

        @param[in] obj  Object to be checked
        @param[in] typeFromChannel  Type of object expected to be read from
                        a channel (since some thin wrapper types are read
                        as the underlying type); None if it is expected to
                        be the same type as ``obj``

        Persistence is checked using Channel, FitsChan (with native encoding,
        as the only encoding compatible with all AST objects), XmlChan
        and pickle.
        """
        channelSpecs = [
            (Channel, ""),
            (FitsChan, "Encoding=Native"),
            (XmlChan, ""),
        ]
        for channelType, options in channelSpecs:
            stream = StringStream()
            chan = channelType(stream, options)
            chan.write(obj)
            stream.sinkToSource()
            if channelType is FitsChan:
                chan.clearCard()
            readBack = chan.read()
            if typeFromChannel is None:
                self.assertObjectsIdentical(obj, readBack)
            else:
                self.assertIs(type(readBack), typeFromChannel)
                self.assertObjectsIdentical(obj, readBack, checkType=False)

        unpickled = pickle.loads(pickle.dumps(obj))
        self.assertObjectsIdentical(obj, unpickled)

 

 

class MappingTestCase(ObjectTestCase):
    """Base class for unit tests of mappings
    """

    def checkRoundTrip(self, amap, poslist, rtol=1e-05, atol=1e-08):
        """Check that a mapping's reverse transform is the opposite of forward

        amap is the mapping to test
        poslist is a list of input positions for a forward transform;
            a numpy array with shape [nin, num points]
            or collection that can be cast to same;
            a 1-d collection is treated as shape [1, num points]
        rtol is the relative tolerance for numpy.testing.assert_allclose
        atol is the absolute tolerance for numpy.testing.assert_allclose
        """
        poslist = np.array(poslist, dtype=float)
        if poslist.ndim < 2:
            # supplied data was a single list of points (or a scalar);
            # reshape instead of assigning to .shape, which numpy discourages
            poslist = poslist.reshape(1, -1)

        # forward with applyForward, inverse with applyInverse
        to_poslist = amap.applyForward(poslist)
        rt_poslist = amap.applyInverse(to_poslist)
        assert_allclose(poslist, rt_poslist, rtol=rtol, atol=atol)

        # forward with applyForward, inverse with getInverse().applyForward
        amapinv = amap.getInverse()
        rt2_poslist = amapinv.applyForward(to_poslist)
        assert_allclose(poslist, rt2_poslist, rtol=rtol, atol=atol)

        # forward and inverse with a compound map of amap.then(amap.getInverse())
        acmp = amap.then(amapinv)
        assert_allclose(poslist, acmp.applyForward(poslist), rtol=rtol, atol=atol)

        # test vector versions of forward and inverse
        posvec = list(poslist.flat)
        to_posvec = amap.applyForward(posvec)
        # cast to_poslist with np.asarray in case it comes back as a plain
        # list, which has no `flat` attribute
        assert_allclose(to_posvec, np.asarray(to_poslist, dtype=float).ravel(),
                        rtol=rtol, atol=atol)

        rt_posvec = amap.applyInverse(to_posvec)
        assert_allclose(posvec, rt_posvec, rtol=rtol, atol=atol)

    def _checkSimplifiesToUnitMap(self, first, second):
        """Assert first.then(second) simplifies to a UnitMap; return the UnitMap."""
        compound = first.then(second)
        unit = compound.simplify()
        self.assertEqual(unit.className, "UnitMap")
        self.assertEqual(first.nIn, compound.nIn)
        self.assertEqual(first.nIn, compound.nOut)
        self.assertEqual(compound.nIn, unit.nIn)
        self.assertEqual(compound.nOut, unit.nOut)
        return unit

    def checkBasicSimplify(self, amap):
        """Check basic simplification for a reversible mapping

        Check the following:
        - A compound mapping of amap and its inverse (in either order)
          simplifies to a UnitMap with matching nIn and nOut
        - A compound mapping of a map (amap or its inverse) and a UnitMap
          (in either order) simplifies to the same class that the map
          itself simplifies to
        """
        amapinv = amap.getInverse()
        unit1 = self._checkSimplifiesToUnitMap(amap, amapinv)
        unit2 = self._checkSimplifiesToUnitMap(amapinv, amap)

        for ma, mb, desmap3 in (
            (unit1, amap, amap),
            (amap, unit2, amap),
            (unit2, amapinv, amapinv),
            (amapinv, unit1, amapinv),
        ):
            cmp3 = ma.then(mb)
            cmp3simp = cmp3.simplify()
            # compare against the map the UnitMap was combined with
            # (desmap3), which is amapinv for the last two cases
            self.assertEqual(cmp3simp.className, desmap3.simplify().className)
            self.assertEqual(ma.nIn, cmp3.nIn)
            self.assertEqual(mb.nOut, cmp3.nOut)
            self.assertEqual(cmp3.nIn, cmp3simp.nIn)
            self.assertEqual(cmp3.nOut, cmp3simp.nOut)

    def _checkMappingCopy(self, amap, amap_copy, poslist):
        """Assert amap_copy matches amap and transforms poslist identically.

        Raises RuntimeError if amap has neither a forward nor an inverse transform.
        """
        self.assertEqual(amap.className, amap_copy.className)
        self.assertEqual(amap.show(), amap_copy.show())
        self.assertEqual(str(amap), str(amap_copy))
        self.assertEqual(repr(amap), repr(amap_copy))

        if amap.hasForward:
            outPoslist = amap.applyForward(poslist)
            assert_array_equal(outPoslist, amap_copy.applyForward(poslist))

            if amap.hasInverse:
                assert_array_equal(amap.applyInverse(outPoslist),
                                   amap_copy.applyInverse(outPoslist))

        elif amap.hasInverse:
            assert_array_equal(amap.applyInverse(poslist),
                               amap_copy.applyInverse(poslist))

        else:
            raise RuntimeError("mapping has neither forward nor inverse transform")

    def checkMappingPersistence(self, amap, poslist):
        """Check that a mapping gives identical answers to unpersisted copy

        poslist is a list of input position for a forward transform
            (if it exists), or the inverse transform (if not).
            A numpy array with shape [nAxes, num points]
            or collection that can be cast to same

        Checks each direction, if present. However, for generality,
        does not check that the two directions are inverses of each other;
        call checkRoundTrip for that.

        The mapping is round-tripped through Channel, FitsChan (with native
        encoding), XmlChan and pickle, so this does everything
        checkPersistence does; no need to call both.

        Raises RuntimeError if amap has neither a forward nor an inverse transform.
        """
        for channelType, options in (
            (Channel, ""),
            (FitsChan, "Encoding=Native"),
            (XmlChan, ""),
        ):
            ss = StringStream()
            # use the channel type under test (previously always Channel)
            chan = channelType(ss, options)
            chan.write(amap)
            ss.sinkToSource()
            if channelType is FitsChan:
                # clear the current card before reading back, as checkPersistence does
                chan.clearCard()
            amap_copy = chan.read()
            self._checkMappingCopy(amap, amap_copy, poslist)

        amap_copy = pickle.loads(pickle.dumps(amap))
        self._checkMappingCopy(amap, amap_copy, poslist)

    def checkMemoryForCompoundObject(self, obj1, obj2, cmpObj, isSeries):
        """Check the memory usage for a compound object

        Deep-copying cmpObj must raise its object count by 1 and the object
        counts of obj1 and obj2 by 1 each (2 each if they share a type),
        without changing any reference count. Deleting the copy must restore
        all counts.

        obj1: first object in compound object
        obj2: second object in compound object
        cmpObj: compound object (SeriesMap, ParallelMap, CmpMap or CmpFrame)
        isSeries: is compound object in series? None to not test (e.g. CmpFrame)
        """
        # if obj1 and obj2 are the same type then copying the compound object
        # will increase the NObject of each by 2, otherwise 1
        deltaObj = 2 if type(obj1) is type(obj2) else 1

        initialNumObj1 = obj1.getNObject()
        initialNumObj2 = obj2.getNObject()
        initialNumCmpObj = cmpObj.getNObject()
        initialRefCountObj1 = obj1.getRefCount()
        initialRefCountObj2 = obj2.getRefCount()
        initialRefCountCmpObj = cmpObj.getRefCount()

        if isSeries is not None:
            if isSeries:
                self.assertTrue(cmpObj.series)
            else:
                self.assertFalse(cmpObj.series)

        # making a deep copy should increase the object count of the contained objects
        # but should not affect the reference count
        cp = cmpObj.copy()
        self.assertEqual(cmpObj.getRefCount(), initialRefCountCmpObj)
        self.assertEqual(cmpObj.getNObject(), initialNumCmpObj + 1)
        self.assertEqual(obj1.getRefCount(), initialRefCountObj1)
        self.assertEqual(obj2.getRefCount(), initialRefCountObj2)
        self.assertEqual(obj1.getNObject(), initialNumObj1 + deltaObj)
        self.assertEqual(obj2.getNObject(), initialNumObj2 + deltaObj)

        # deleting the deep copy should restore ref count and nobject
        del cp
        self.assertEqual(cmpObj.getRefCount(), initialRefCountCmpObj)
        self.assertEqual(cmpObj.getNObject(), initialNumCmpObj)
        self.assertEqual(obj1.getRefCount(), initialRefCountObj1)
        self.assertEqual(obj1.getNObject(), initialNumObj1)
        self.assertEqual(obj2.getRefCount(), initialRefCountObj2)
        self.assertEqual(obj2.getNObject(), initialNumObj2)

 

 

def makePolyMapCoeffs(nIn, nOut):
    """Make an array of coefficients for astshim.PolyMap for the following equation:

        fj(x) = C0j x0^2 + C1j x1^2 + C2j x2^2 + ... + CNj xN^2

    where:
    * i ranges from 0 to N=nIn-1
    * j ranges from 0 to nOut-1,
    * Cij = 0.001 (i+j+1)

    Each row of the returned float array is [Cij, j + 1, p0, ..., pN]:
    the coefficient, the output number (j + 1), then the power of each
    input (2 for input i, 0 for every other input). Rows are ordered by
    output j first, then by input i.
    """
    scale = 0.001
    rows = [
        [scale * (i + 1) + scale * j, j + 1] + [2 if k == i else 0 for k in range(nIn)]
        for j in range(nOut)
        for i in range(nIn)
    ]
    return np.array(rows, dtype=float)

 

 

def makeTwoWayPolyMap(nIn, nOut):
    """Make an astshim.PolyMap with both transforms, suitable for testing

    Forward transform:
        fj(x) = C0j x0^2 + C1j x1^2 + C2j x2^2 + ... + CNj xN^2 where Cij = 0.001 (i+j+1)

    Reverse transform: the same equation with i and j swapped (coefficients
    from makePolyMapCoeffs(nOut, nIn)). It is therefore NOT the inverse of the
    forward direction, just something that is easy to evaluate.

    The equation is chosen for the following reasons:
    - It is well defined for any positive value of nIn, nOut
    - It stays small for small x, to avoid wraparound of angles for SpherePoint endpoints
    """
    polyMap = PolyMap(
        makePolyMapCoeffs(nIn, nOut),  # forward coefficients
        makePolyMapCoeffs(nOut, nIn),  # reverse coefficients (i and j swapped)
    )
    # sanity-check the constructed map
    assert (polyMap.nIn, polyMap.nOut) == (nIn, nOut)
    assert polyMap.hasForward and polyMap.hasInverse
    return polyMap

 

 

def makeForwardPolyMap(nIn, nOut):
    """Make a forward-only astshim.PolyMap suitable for testing

    The forward transform is the same as for `makeTwoWayPolyMap`;
    this map has no reverse transform.

    The equation is chosen for the following reasons:
    - It is well defined for any positive value of nIn, nOut
    - It stays small for small x, to avoid wraparound of angles for SpherePoint endpoints
    """
    # "IterInverse=0" so that no inverse is provided (checked below)
    polyMap = PolyMap(makePolyMapCoeffs(nIn, nOut), nOut, "IterInverse=0")
    # sanity-check the constructed map
    assert (polyMap.nIn, polyMap.nOut) == (nIn, nOut)
    assert polyMap.hasForward and not polyMap.hasInverse
    return polyMap