Coverage for tests/test_fitTanSipWcsTask.py: 13%

222 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-04 02:55 -0700

1# LSST Data Management System 

2# Copyright 2008, 2009, 2010 LSST Corporation. 

3# 

4# This product includes software developed by the 

5# LSST Project (http://www.lsst.org/). 

6# 

7# This program is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# This program is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the LSST License Statement and 

18# the GNU General Public License along with this program. If not, 

19# see <http://www.lsstcorp.org/LegalNotices/>. 

20# 

21# The classes in this test are a little non-standard to reduce code 

22# duplication and support automated unittest discovery. 

23# A base class includes all the code that implements the testing and 

24# itself inherits from unittest.TestCase. unittest automated discovery 

25# will scan all classes that inherit from unittest.TestCase and invoke 

26# any test methods found. To prevent this base class from being executed 

27# the test methods are placed in a different class that does not inherit 

28# from unittest.TestCase. The actual test classes then inherit from 

29# both the testing class and the implementation class allowing test 

30# discovery to only run tests found in the subclasses. 

31 

32import math 

33import unittest 

34 

35import numpy as np 

36 

37import lsst.pipe.base 

38import lsst.utils.tests 

39import lsst.geom 

40import lsst.afw.geom as afwGeom 

41from lsst.afw.geom.wcsUtils import makeTanSipMetadata 

42import lsst.afw.table as afwTable 

43from lsst.meas.algorithms import convertReferenceCatalog 

44from lsst.meas.base import SingleFrameMeasurementTask 

45from lsst.meas.astrom import FitTanSipWcsTask, setMatchDistance 

46from lsst.meas.astrom.sip import makeCreateWcsWithSip 

47 

48 

49class BaseTestCase: 

50 

51 """A test case for CreateWcsWithSip 

52 

53 Use involves setting one class attribute: 

54 * MatchClass: match class, e.g. ReferenceMatch or SourceMatch 

55 

56 This test is a bit messy because it exercises two templatings of makeCreateWcsWithSip, 

57 the underlying TAN-SIP WCS fitter, but only one of those is supported by FitTanSipWcsTask 

58 """ 

59 MatchClass = None 

60 

61 def setUp(self): 

62 crval = lsst.geom.SpherePoint(44, 45, lsst.geom.degrees) 

63 crpix = lsst.geom.Point2D(15000, 4000) 

64 

65 scale = 1 * lsst.geom.arcseconds 

66 cdMatrix = afwGeom.makeCdMatrix(scale=scale, flipX=True) 

67 self.tanWcs = afwGeom.makeSkyWcs(crpix=crpix, crval=crval, cdMatrix=cdMatrix) 

68 self.loadData() 

69 

70 def loadData(self, rangePix=3000, numPoints=25): 

71 """Load catalogs and make the match list 

72 

73 This is a separate function so data can be reloaded if fitting more than once 

74 (each time a WCS is fit it may update the source catalog, reference catalog and match list) 

75 """ 

76 if self.MatchClass == afwTable.ReferenceMatch: 

77 refSchema = convertReferenceCatalog._makeSchema(filterNameList=["r"], addIsPhotometric=True, 

78 addCentroid=True) 

79 self.refCat = afwTable.SimpleCatalog(refSchema) 

80 elif self.MatchClass == afwTable.SourceMatch: 

81 refSchema = afwTable.SourceTable.makeMinimalSchema() 

82 self.refCat = afwTable.SourceCatalog(refSchema) 

83 else: 

84 raise RuntimeError("Unsupported MatchClass=%r" % (self.MatchClass,)) 

85 srcSchema = afwTable.SourceTable.makeMinimalSchema() 

86 SingleFrameMeasurementTask(schema=srcSchema) 

87 afwTable.CoordKey.addErrorFields(srcSchema) 

88 self.srcCoordKey = afwTable.CoordKey(srcSchema["coord"]) 

89 self.srcCentroidKey = afwTable.Point2DKey(srcSchema["slot_Centroid"]) 

90 self.srcCentroidKey_xErr = srcSchema["slot_Centroid_xErr"].asKey() 

91 self.srcCentroidKey_yErr = srcSchema["slot_Centroid_yErr"].asKey() 

92 self.sourceCat = afwTable.SourceCatalog(srcSchema) 

93 

94 self.matches = [] 

95 

96 for i in np.linspace(0., rangePix, numPoints): 

97 for j in np.linspace(0., rangePix, numPoints): 

98 src = self.sourceCat.addNew() 

99 refObj = self.refCat.addNew() 

100 

101 src.set(self.srcCentroidKey, lsst.geom.Point2D(i, j)) 

102 src.set(self.srcCentroidKey_xErr, 0.1) 

103 src.set(self.srcCentroidKey_yErr, 0.1) 

104 

105 c = self.tanWcs.pixelToSky(i, j) 

106 refObj.setCoord(c) 

107 

108 if False: 

109 print("x,y = (%.1f, %.1f) pixels -- RA,Dec = (%.3f, %.3f) deg" % 

110 (i, j, c.toFk5().getRa().asDegrees(), c.toFk5().getDec().asDegrees())) 

111 

112 self.matches.append(self.MatchClass(refObj, src, 0.0)) 

113 

114 def tearDown(self): 

115 del self.refCat 

116 del self.sourceCat 

117 del self.matches 

118 del self.tanWcs 

119 

120 def checkResults(self, fitRes, catsUpdated): 

121 """Check results 

122 

123 @param[in] fitRes a object with two fields: 

124 - wcs fit TAN-SIP WCS, an lsst.afw.geom.SkyWcs 

125 - scatterOnSky median on-sky scatter, an lsst.afw.geom.Angle 

126 @param[in] catsUpdated if True then coord field of self.sourceCat and centroid fields of self.refCat 

127 have been updated 

128 """ 

129 self.assertLess(fitRes.scatterOnSky.asArcseconds(), 0.001) 

130 tanSipWcs = fitRes.wcs 

131 maxAngSep = 0*lsst.geom.radians 

132 maxPixSep = 0 

133 refCoordKey = afwTable.CoordKey(self.refCat.schema["coord"]) 

134 if catsUpdated: 

135 refCentroidKey = afwTable.Point2DKey(self.refCat.schema["centroid"]) 

136 maxDistErr = 0*lsst.geom.radians 

137 for refObj, src, distRad in self.matches: 

138 srcPixPos = src.get(self.srcCentroidKey) 

139 refCoord = refObj.get(refCoordKey) 

140 if catsUpdated: 

141 refPixPos = refObj.get(refCentroidKey) 

142 srcCoord = src.get(self.srcCoordKey) 

143 else: 

144 refPixPos = tanSipWcs.skyToPixel(refCoord) 

145 srcCoord = tanSipWcs.pixelToSky(srcPixPos) 

146 

147 angSep = refCoord.separation(srcCoord) 

148 dist = distRad*lsst.geom.radians 

149 distErr = abs(dist - angSep) 

150 maxDistErr = max(maxDistErr, distErr) 

151 maxAngSep = max(maxAngSep, angSep) 

152 

153 pixSep = math.hypot(*(srcPixPos - refPixPos)) 

154 maxPixSep = max(maxPixSep, pixSep) 

155 

156 print("max angular separation = %0.4f arcsec" % (maxAngSep.asArcseconds(),)) 

157 print("max pixel separation = %0.3f" % (maxPixSep,)) 

158 self.assertLess(maxAngSep.asArcseconds(), 0.001) 

159 self.assertLess(maxPixSep, 0.005) 

160 if catsUpdated: 

161 allowedDistErr = 1e-7 

162 else: 

163 allowedDistErr = 0.001 

164 self.assertLess(maxDistErr.asArcseconds(), allowedDistErr, 

165 "Computed distance in match list is off by %s arcsec" % (maxDistErr.asArcseconds(),)) 

166 

167 def doTest(self, name, func, order=3, numIter=4, specifyBBox=False, doPlot=False, doPrint=False): 

168 """Apply func(x, y) to each source in self.sourceCat, then fit and check the resulting WCS 

169 """ 

170 bbox = lsst.geom.Box2I() 

171 for refObj, src, d in self.matches: 

172 origPos = src.get(self.srcCentroidKey) 

173 x, y = func(*origPos) 

174 distortedPos = lsst.geom.Point2D(*func(*origPos)) 

175 src.set(self.srcCentroidKey, distortedPos) 

176 bbox.include(lsst.geom.Point2I(lsst.geom.Point2I(distortedPos))) 

177 

178 tanSipWcs = self.tanWcs 

179 for i in range(numIter): 

180 if specifyBBox: 

181 sipObject = makeCreateWcsWithSip(self.matches, tanSipWcs, order, bbox) 

182 else: 

183 sipObject = makeCreateWcsWithSip(self.matches, tanSipWcs, order) 

184 tanSipWcs = sipObject.getNewWcs() 

185 setMatchDistance(self.matches) 

186 fitRes = lsst.pipe.base.Struct( 

187 wcs=tanSipWcs, 

188 scatterOnSky=sipObject.getScatterOnSky(), 

189 ) 

190 

191 if doPrint: 

192 print("TAN-SIP metadata fit over bbox=", bbox) 

193 metadata = makeTanSipMetadata( 

194 crpix=tanSipWcs.getPixelOrigin(), 

195 crval=tanSipWcs.getSkyOrigin(), 

196 cdMatrix=tanSipWcs.getCdMatrix(), 

197 sipA=sipObject.getSipA(), 

198 sipB=sipObject.getSipB(), 

199 sipAp=sipObject.getSipAp(), 

200 sipBp=sipObject.getSipBp(), 

201 ) 

202 print(metadata.toString()) 

203 

204 if doPlot: 

205 self.plotWcs(tanSipWcs, name=name) 

206 

207 self.checkResults(fitRes, catsUpdated=False) 

208 

209 if self.MatchClass == afwTable.ReferenceMatch: 

210 # reset source coord and reference centroid based on initial WCS 

211 afwTable.updateRefCentroids(wcs=self.tanWcs, refList=self.refCat) 

212 afwTable.updateSourceCoords(wcs=self.tanWcs, sourceList=self.sourceCat) 

213 

214 fitterConfig = FitTanSipWcsTask.ConfigClass() 

215 fitterConfig.order = order 

216 fitterConfig.numIter = numIter 

217 fitter = FitTanSipWcsTask(config=fitterConfig) 

218 self.loadData() 

219 if specifyBBox: 

220 fitRes = fitter.fitWcs( 

221 matches=self.matches, 

222 initWcs=self.tanWcs, 

223 bbox=bbox, 

224 refCat=self.refCat, 

225 sourceCat=self.sourceCat, 

226 ) 

227 else: 

228 fitRes = fitter.fitWcs( 

229 matches=self.matches, 

230 initWcs=self.tanWcs, 

231 bbox=bbox, 

232 refCat=self.refCat, 

233 sourceCat=self.sourceCat, 

234 ) 

235 

236 self.checkResults(fitRes, catsUpdated=True) 

237 

238 def plotWcs(self, tanSipWcs, name=""): 

239 import matplotlib 

240 matplotlib.use("Agg") 

241 import matplotlib.pyplot as plt 

242 fileNamePrefix = "testCreateWcsWithSip_%s_%s" % (self.MatchClass.__name__, name) 

243 pnum = 1 

244 

245 xs, ys, xc, yc = [], [], [], [] 

246 rs, ds, rc, dc = [], [], [], [] 

247 for ref, src, d in self.matches: 

248 xs.append(src.getX()) 

249 ys.append(src.getY()) 

250 refPixPos = tanSipWcs.skyToPixel(ref.getCoord()) 

251 xc.append(refPixPos[0]) 

252 yc.append(refPixPos[1]) 

253 rc.append(ref.getRa()) 

254 dc.append(ref.getDec()) 

255 srd = tanSipWcs.pixelToSky(src.get(self.srcCentroidKey)) 

256 rs.append(srd.getRa()) 

257 ds.append(srd.getDec()) 

258 xs = np.array(xs) 

259 ys = np.array(ys) 

260 xc = np.array(xc) 

261 yc = np.array(yc) 

262 

263 plt.clf() 

264 plt.plot(xs, ys, "r.") 

265 plt.plot(xc, yc, "bx") 

266 fileName = "%s_%i.png" % (fileNamePrefix, pnum) 

267 plt.savefig(fileName) 

268 print("Wrote", fileName) 

269 pnum += 1 

270 

271 plt.clf() 

272 plt.plot(xs, xc-xs, "b.") 

273 fileName = "%s_%i.png" % (fileNamePrefix, pnum) 

274 plt.xlabel("x(source)") 

275 plt.ylabel("x(ref - src)") 

276 plt.savefig(fileName) 

277 print("Wrote", fileName) 

278 pnum += 1 

279 

280 plt.clf() 

281 plt.plot(rs, ds, "r.") 

282 plt.plot(rc, dc, "bx") 

283 fileName = "%s_%i.png" % (fileNamePrefix, pnum) 

284 plt.savefig(fileName) 

285 print("Wrote", fileName) 

286 pnum += 1 

287 

288 plt.clf() 

289 for y in np.linspace(0, 4000, 5): 

290 x0, y0 = [], [] 

291 x1, y1 = [], [] 

292 for x in np.linspace(0., 4000., 401): 

293 x0.append(x) 

294 y0.append(y) 

295 rd = tanSipWcs.pixelToSky(x, y) 

296 xy = tanSipWcs.skyToPixel(rd) 

297 x1.append(xy[0]) 

298 y1.append(xy[1]) 

299 x0 = np.array(x0) 

300 x1 = np.array(x1) 

301 plt.plot(x0, x1-x0, "b-") 

302 fileName = "%s_%i.png" % (fileNamePrefix, pnum) 

303 plt.savefig(fileName) 

304 print("Wrote", fileName) 

305 pnum += 1 

306 

307 

308class SideLoadTestCases: 

309 

310 """Base class implementations of testing methods. 

311 

312 Explicitly does not inherit from unittest.TestCase""" 

313 

314 def testTrivial(self): 

315 """Add no distortion""" 

316 for order in (2, 4, 6): 

317 self.doTest("testTrivial", lambda x, y: (x, y), order=order) 

318 

319 def testOffset(self): 

320 """Add an offset""" 

321 for order in (2, 4, 6): 

322 self.doTest("testOffset", lambda x, y: (x + 5, y + 7), order=order) 

323 

324 def testLinearX(self): 

325 """Scale x, offset y""" 

326 for order in (2, 6): 

327 self.doTest("testLinearX", lambda x, y: (2*x, y + 7), order=order) 

328 

329 def testLinearXY(self): 

330 """Scale x and y""" 

331 self.doTest("testLinearXY", lambda x, y: (2*x, 3*y)) 

332 

333 def testLinearYX(self): 

334 """Add an offset to each point; scale in y and x""" 

335 for order in (2, 6): 

336 self.doTest("testLinearYX", lambda x, y: (x + 0.2*y, y + 0.3*x), order=order) 

337 

338 def testQuadraticX(self): 

339 """Add quadratic distortion in x""" 

340 for order in (4, 5): 

341 self.doTest("testQuadraticX", lambda x, y: (x + 1e-5*x**2, y), order=order) 

342 

343 def testRadial(self): 

344 """Add radial distortion""" 

345 radialTransform = afwGeom.makeRadialTransform([0, 1.01, 1e-8]) 

346 

347 def radialDistortion(x, y): 

348 x, y = radialTransform.applyForward(lsst.geom.Point2D(x, y)) 

349 return (x, y) 

350 for order in (4, 5, 6): 

351 doPrint = order == 5 

352 self.doTest("testRadial", radialDistortion, order=order, doPrint=doPrint) 

353 

354# The test classes inherit from two base classes and differ in the match 

355# class being used. 

356 

357 

358class CreateWcsWithSipTestCaseReferenceMatch(BaseTestCase, SideLoadTestCases, lsst.utils.tests.TestCase): 

359 MatchClass = afwTable.ReferenceMatch 

360 

361 

362class CreateWcsWithSipTestCaseSourceMatch(BaseTestCase, SideLoadTestCases, lsst.utils.tests.TestCase): 

363 MatchClass = afwTable.SourceMatch 

364 

365 

366class MemoryTester(lsst.utils.tests.MemoryTestCase): 

367 pass 

368 

369 

370def setup_module(module): 

371 lsst.utils.tests.init() 

372 

373 

374if __name__ == "__main__": 374 ↛ 375line 374 didn't jump to line 375, because the condition on line 374 was never true

375 lsst.utils.tests.init() 

376 unittest.main()