Coverage for tests/test_fitTanSipWcsTask.py: 13%

221 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-02 02:53 +0000

1# LSST Data Management System 

2# Copyright 2008, 2009, 2010 LSST Corporation. 

3# 

4# This product includes software developed by the 

5# LSST Project (http://www.lsst.org/). 

6# 

7# This program is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# This program is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the LSST License Statement and 

18# the GNU General Public License along with this program. If not, 

19# see <http://www.lsstcorp.org/LegalNotices/>. 

20# 

21# The classes in this test are a little non-standard to reduce code 

22# duplication and support automated unittest discovery. 

23# A base class includes all the code that implements the testing and 

24# itself inherits from unittest.TestCase. unittest automated discovery 

25# will scan all classes that inherit from unittest.TestCase and invoke 

26# any test methods found. To prevent this base class from being executed 

27# the test methods are placed in a different class that does not inherit 

28# from unittest.TestCase. The actual test classes then inherit from 

29# both the testing class and the implementation class allowing test 

30# discovery to only run tests found in the subclasses. 

31 

32import math 

33import unittest 

34 

35import numpy as np 

36 

37import lsst.pipe.base 

38import lsst.utils.tests 

39import lsst.geom 

40import lsst.afw.geom as afwGeom 

41from lsst.afw.geom.wcsUtils import makeTanSipMetadata 

42import lsst.afw.table as afwTable 

43from lsst.meas.algorithms import LoadReferenceObjectsTask 

44from lsst.meas.base import SingleFrameMeasurementTask 

45from lsst.meas.astrom import FitTanSipWcsTask, setMatchDistance 

46from lsst.meas.astrom.sip import makeCreateWcsWithSip 

47 

48 

class BaseTestCase:
    """Shared fixture and checking code for the CreateWcsWithSip tests.

    Use involves setting one class attribute:

    * MatchClass: match class, e.g. ReferenceMatch or SourceMatch

    This test is a bit messy because it exercises two templatings of
    makeCreateWcsWithSip, the underlying TAN-SIP WCS fitter, but only one of
    those is supported by FitTanSipWcsTask.

    This class holds no ``test*`` methods itself; it only provides
    ``setUp``/``tearDown`` and the ``doTest`` driver that concrete test
    classes call.
    """
    # Match class used to build self.matches; concrete subclasses override it.
    MatchClass = None

    def setUp(self):
        """Build the initial pure-TAN WCS and load the catalogs."""
        crval = lsst.geom.SpherePoint(44, 45, lsst.geom.degrees)
        crpix = lsst.geom.Point2D(15000, 4000)

        scale = 1 * lsst.geom.arcseconds
        cdMatrix = afwGeom.makeCdMatrix(scale=scale, flipX=True)
        self.tanWcs = afwGeom.makeSkyWcs(crpix=crpix, crval=crval, cdMatrix=cdMatrix)
        self.loadData()

    def loadData(self, rangePix=3000, numPoints=25):
        """Load catalogs and make the match list.

        This is a separate function so data can be reloaded if fitting more
        than once (each time a WCS is fit it may update the source catalog,
        reference catalog and match list).

        Parameters
        ----------
        rangePix : `float`, optional
            Upper end of the pixel grid; sources are placed on
            ``np.linspace(0, rangePix, numPoints)`` along both axes.
        numPoints : `int`, optional
            Number of grid points per axis, so ``numPoints**2`` matches are
            created.

        Raises
        ------
        RuntimeError
            If ``self.MatchClass`` is neither `afwTable.ReferenceMatch` nor
            `afwTable.SourceMatch`.
        """
        if self.MatchClass == afwTable.ReferenceMatch:
            refSchema = LoadReferenceObjectsTask.makeMinimalSchema(
                filterNameList=["r"], addIsPhotometric=True, addCentroid=True)
            self.refCat = afwTable.SimpleCatalog(refSchema)
        elif self.MatchClass == afwTable.SourceMatch:
            refSchema = afwTable.SourceTable.makeMinimalSchema()
            self.refCat = afwTable.SourceCatalog(refSchema)
        else:
            raise RuntimeError("Unsupported MatchClass=%r" % (self.MatchClass,))

        srcSchema = afwTable.SourceTable.makeMinimalSchema()
        # Constructed only for its effect on srcSchema; the task object itself
        # is discarded. The slot_Centroid fields read below are expected to
        # come from this call.
        SingleFrameMeasurementTask(schema=srcSchema)
        self.srcCoordKey = afwTable.CoordKey(srcSchema["coord"])
        self.srcCentroidKey = afwTable.Point2DKey(srcSchema["slot_Centroid"])
        self.srcCentroidKey_xErr = srcSchema["slot_Centroid_xErr"].asKey()
        self.srcCentroidKey_yErr = srcSchema["slot_Centroid_yErr"].asKey()
        self.sourceCat = afwTable.SourceCatalog(srcSchema)

        self.matches = []

        for i in np.linspace(0., rangePix, numPoints):
            for j in np.linspace(0., rangePix, numPoints):
                src = self.sourceCat.addNew()
                refObj = self.refCat.addNew()

                src.set(self.srcCentroidKey, lsst.geom.Point2D(i, j))
                src.set(self.srcCentroidKey_xErr, 0.1)
                src.set(self.srcCentroidKey_yErr, 0.1)

                # The reference object sits exactly where the initial TAN WCS
                # maps the source pixel position.
                c = self.tanWcs.pixelToSky(i, j)
                refObj.setCoord(c)

                self.matches.append(self.MatchClass(refObj, src, 0.0))

    def tearDown(self):
        """Drop the catalogs, match list and WCS created by setUp/loadData."""
        del self.refCat
        del self.sourceCat
        del self.matches
        del self.tanWcs

    def checkResults(self, fitRes, catsUpdated):
        """Check the results of a WCS fit against self.matches.

        Parameters
        ----------
        fitRes : struct-like
            An object with two fields:

            - ``wcs``: the fit TAN-SIP WCS
            - ``scatterOnSky``: on-sky scatter, an angle
        catsUpdated : `bool`
            If True then the coord field of self.sourceCat and the centroid
            fields of self.refCat have been updated, and are read directly
            from the catalogs; otherwise they are computed from ``fitRes.wcs``.

        Notes
        -----
        Asserts that the scatter and the maximum angular separation are below
        0.001 arcsec, that the maximum pixel separation is below 0.005, and
        that the distance stored in each match agrees with the measured
        separation to within 1e-7 arcsec (``catsUpdated``) or 0.001 arcsec.
        """
        self.assertLess(fitRes.scatterOnSky.asArcseconds(), 0.001)
        tanSipWcs = fitRes.wcs
        maxAngSep = 0*lsst.geom.radians
        maxPixSep = 0
        refCoordKey = afwTable.CoordKey(self.refCat.schema["coord"])
        if catsUpdated:
            refCentroidKey = afwTable.Point2DKey(self.refCat.schema["centroid"])
        maxDistErr = 0*lsst.geom.radians
        for refObj, src, distRad in self.matches:
            srcPixPos = src.get(self.srcCentroidKey)
            refCoord = refObj.get(refCoordKey)
            if catsUpdated:
                refPixPos = refObj.get(refCentroidKey)
                srcCoord = src.get(self.srcCoordKey)
            else:
                refPixPos = tanSipWcs.skyToPixel(refCoord)
                srcCoord = tanSipWcs.pixelToSky(srcPixPos)

            angSep = refCoord.separation(srcCoord)
            dist = distRad*lsst.geom.radians
            distErr = abs(dist - angSep)
            maxDistErr = max(maxDistErr, distErr)
            maxAngSep = max(maxAngSep, angSep)

            pixSep = math.hypot(*(srcPixPos - refPixPos))
            maxPixSep = max(maxPixSep, pixSep)

        print("max angular separation = %0.4f arcsec" % (maxAngSep.asArcseconds(),))
        print("max pixel separation = %0.3f" % (maxPixSep,))
        self.assertLess(maxAngSep.asArcseconds(), 0.001)
        self.assertLess(maxPixSep, 0.005)
        if catsUpdated:
            allowedDistErr = 1e-7
        else:
            allowedDistErr = 0.001
        self.assertLess(maxDistErr.asArcseconds(), allowedDistErr,
                        "Computed distance in match list is off by %s arcsec" % (maxDistErr.asArcseconds(),))

    def doTest(self, name, func, order=3, numIter=4, specifyBBox=False, doPlot=False, doPrint=False):
        """Apply func(x, y) to each source in self.sourceCat, then fit and
        check the resulting WCS.

        Parameters
        ----------
        name : `str`
            Test name; only used to label plot files when ``doPlot`` is True.
        func : callable
            Distortion ``func(x, y) -> (x, y)`` applied to every source
            centroid before fitting.
        order : `int`, optional
            SIP polynomial order passed to the fitters.
        numIter : `int`, optional
            Number of makeCreateWcsWithSip iterations, and the ``numIter``
            config value for FitTanSipWcsTask.
        specifyBBox : `bool`, optional
            If True pass the bounding box of the distorted sources to
            makeCreateWcsWithSip.
        doPlot : `bool`, optional
            If True write diagnostic plots via `plotWcs`.
        doPrint : `bool`, optional
            If True print the fit TAN-SIP metadata.
        """
        bbox = lsst.geom.Box2I()
        for refObj, src, d in self.matches:
            origPos = src.get(self.srcCentroidKey)
            # Evaluate the distortion once per source.
            distortedPos = lsst.geom.Point2D(*func(*origPos))
            src.set(self.srcCentroidKey, distortedPos)
            bbox.include(lsst.geom.Point2I(distortedPos))

        # Stage 1: drive the low-level fitter directly, feeding each fit back
        # in as the starting WCS for the next iteration.
        tanSipWcs = self.tanWcs
        for _ in range(numIter):
            if specifyBBox:
                sipObject = makeCreateWcsWithSip(self.matches, tanSipWcs, order, bbox)
            else:
                sipObject = makeCreateWcsWithSip(self.matches, tanSipWcs, order)
            tanSipWcs = sipObject.getNewWcs()
        setMatchDistance(self.matches)
        fitRes = lsst.pipe.base.Struct(
            wcs=tanSipWcs,
            scatterOnSky=sipObject.getScatterOnSky(),
        )

        if doPrint:
            print("TAN-SIP metadata fit over bbox=", bbox)
            metadata = makeTanSipMetadata(
                crpix=tanSipWcs.getPixelOrigin(),
                crval=tanSipWcs.getSkyOrigin(),
                cdMatrix=tanSipWcs.getCdMatrix(),
                sipA=sipObject.getSipA(),
                sipB=sipObject.getSipB(),
                sipAp=sipObject.getSipAp(),
                sipBp=sipObject.getSipBp(),
            )
            print(metadata.toString())

        if doPlot:
            self.plotWcs(tanSipWcs, name=name)

        self.checkResults(fitRes, catsUpdated=False)

        # Stage 2: FitTanSipWcsTask is only exercised for ReferenceMatch (see
        # the class docstring); checkResults(catsUpdated=True) also needs the
        # "centroid" field that only the ReferenceMatch schema requests.
        if self.MatchClass == afwTable.ReferenceMatch:
            # reset source coord and reference centroid based on initial WCS
            # NOTE(review): self.loadData() below replaces both catalogs, so
            # these updates are applied to objects that are then discarded,
            # and the task is fit to the undistorted grid -- confirm intent.
            afwTable.updateRefCentroids(wcs=self.tanWcs, refList=self.refCat)
            afwTable.updateSourceCoords(wcs=self.tanWcs, sourceList=self.sourceCat)

            fitterConfig = FitTanSipWcsTask.ConfigClass()
            fitterConfig.order = order
            fitterConfig.numIter = numIter
            fitter = FitTanSipWcsTask(config=fitterConfig)
            self.loadData()
            # NOTE(review): the original had separate specifyBBox / else
            # branches here that were byte-identical (both passed bbox=bbox).
            # They are merged with behaviour unchanged; if the else branch was
            # meant to omit bbox, that needs a deliberate follow-up.
            fitRes = fitter.fitWcs(
                matches=self.matches,
                initWcs=self.tanWcs,
                bbox=bbox,
                refCat=self.refCat,
                sourceCat=self.sourceCat,
            )

            self.checkResults(fitRes, catsUpdated=True)

    def plotWcs(self, tanSipWcs, name=""):
        """Write four diagnostic PNG plots comparing sources and references.

        Parameters
        ----------
        tanSipWcs : WCS object
            The fit WCS, used via ``skyToPixel`` and ``pixelToSky``.
        name : `str`, optional
            Suffix for the output file names, which have the form
            ``testCreateWcsWithSip_<MatchClass name>_<name>_<n>.png``.
        """
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        fileNamePrefix = "testCreateWcsWithSip_%s_%s" % (self.MatchClass.__name__, name)

        def savePlot(pnum):
            """Save the current figure as plot number pnum; return pnum + 1."""
            fileName = "%s_%i.png" % (fileNamePrefix, pnum)
            plt.savefig(fileName)
            print("Wrote", fileName)
            return pnum + 1

        pnum = 1

        xs, ys, xc, yc = [], [], [], []
        rs, ds, rc, dc = [], [], [], []
        for ref, src, d in self.matches:
            xs.append(src.getX())
            ys.append(src.getY())
            refPixPos = tanSipWcs.skyToPixel(ref.getCoord())
            xc.append(refPixPos[0])
            yc.append(refPixPos[1])
            rc.append(ref.getRa())
            dc.append(ref.getDec())
            srd = tanSipWcs.pixelToSky(src.get(self.srcCentroidKey))
            rs.append(srd.getRa())
            ds.append(srd.getDec())
        xs = np.array(xs)
        ys = np.array(ys)
        xc = np.array(xc)
        yc = np.array(yc)

        # Plot 1: source pixel positions vs. reference positions mapped to pixels.
        plt.clf()
        plt.plot(xs, ys, "r.")
        plt.plot(xc, yc, "bx")
        pnum = savePlot(pnum)

        # Plot 2: x residual (reference - source) as a function of source x.
        plt.clf()
        plt.plot(xs, xc-xs, "b.")
        plt.xlabel("x(source)")
        plt.ylabel("x(ref - src)")
        pnum = savePlot(pnum)

        # Plot 3: sky positions of sources (via the fit WCS) and references.
        plt.clf()
        plt.plot(rs, ds, "r.")
        plt.plot(rc, dc, "bx")
        pnum = savePlot(pnum)

        # Plot 4: x round-trip error (pixel -> sky -> pixel) along five rows.
        plt.clf()
        for y in np.linspace(0, 4000, 5):
            x0 = []
            x1 = []
            for x in np.linspace(0., 4000., 401):
                x0.append(x)
                rd = tanSipWcs.pixelToSky(x, y)
                xy = tanSipWcs.skyToPixel(rd)
                x1.append(xy[0])
            x0 = np.array(x0)
            x1 = np.array(x1)
            plt.plot(x0, x1-x0, "b-")
        pnum = savePlot(pnum)

305 

306 

class SideLoadTestCases:
    """Base class implementations of testing methods.

    Explicitly does not inherit from unittest.TestCase, so these methods are
    only discovered through concrete classes that also provide ``doTest``.
    Each test applies a distortion to the source positions and asks
    ``self.doTest`` to fit and check the resulting WCS.
    """

    def testTrivial(self):
        """Add no distortion"""
        for order in (2, 4, 6):
            self.doTest("testTrivial", lambda x, y: (x, y), order=order)

    def testOffset(self):
        """Add an offset"""
        for order in (2, 4, 6):
            self.doTest("testOffset", lambda x, y: (x + 5, y + 7), order=order)

    def testLinearX(self):
        """Scale x, offset y"""
        for order in (2, 6):
            self.doTest("testLinearX", lambda x, y: (2*x, y + 7), order=order)

    def testLinearXY(self):
        """Scale x and y"""
        self.doTest("testLinearXY", lambda x, y: (2*x, 3*y))

    def testLinearYX(self):
        """Mix the axes linearly (shear): add 0.2*y to x and 0.3*x to y"""
        for order in (2, 6):
            self.doTest("testLinearYX", lambda x, y: (x + 0.2*y, y + 0.3*x), order=order)

    def testQuadraticX(self):
        """Add quadratic distortion in x"""
        for order in (4, 5):
            self.doTest("testQuadraticX", lambda x, y: (x + 1e-5*x**2, y), order=order)

    def testRadial(self):
        """Add radial distortion"""
        radialTransform = afwGeom.makeRadialTransform([0, 1.01, 1e-8])

        def radialDistortion(x, y):
            x, y = radialTransform.applyForward(lsst.geom.Point2D(x, y))
            return (x, y)

        for order in (4, 5, 6):
            # Print the fit metadata for one representative order only.
            doPrint = order == 5
            self.doTest("testRadial", radialDistortion, order=order, doPrint=doPrint)

352 

353# The test classes inherit from two base classes and differ in the match 

354# class being used. 

355 

356 

class CreateWcsWithSipTestCaseReferenceMatch(BaseTestCase, SideLoadTestCases, lsst.utils.tests.TestCase):
    """Run the side-loaded tests with afwTable.ReferenceMatch as the match class."""
    MatchClass = afwTable.ReferenceMatch

359 

360 

class CreateWcsWithSipTestCaseSourceMatch(BaseTestCase, SideLoadTestCases, lsst.utils.tests.TestCase):
    """Run the side-loaded tests with afwTable.SourceMatch as the match class."""
    MatchClass = afwTable.SourceMatch

363 

364 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pull lsst.utils.tests.MemoryTestCase into this module unchanged."""
    pass

367 

368 

def setup_module(module):
    """Initialise lsst.utils.tests for this module.

    The name matches pytest's module-level setup hook; ``module`` is accepted
    for that interface and is not used here.
    """
    lsst.utils.tests.init()

371 

372 

# Script entry point: initialise the LSST test utilities, then hand over to
# unittest's command-line runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()