Coverage for tests/test_astrometryModel.py: 25%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

211 statements  

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

import itertools
import os
import unittest

import numpy as np

import lsst.utils.tests

import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.persistence
import lsst.geom
import lsst.jointcal
from lsst.jointcal import astrometryModels
import lsst.log
from lsst.meas.algorithms import astrometrySourceSelector
import lsst.pex.exceptions

44 

45 

def getNParametersPolynomial(order):
    """Return the parameter count of an astrometric polynomial transform.

    A 2-d polynomial of total degree ``order`` has ``(order+1)(order+2)/2``
    coefficients, and the transform carries two such polynomials, so the
    total is ``2 * (order+1)(order+2)/2 == (order+1)(order+2)``.

    Parameters
    ----------
    order : `int`
        Total degree of the polynomial.

    Returns
    -------
    nParameters : `int`
        ``(order + 1)*(order + 2)``.
    """
    nTerms = order + 1
    return nTerms*(order + 2)

49 

50 

class AstrometryModelTestBase:
    """Shared fixtures and SkyWcs checks for the astrometry-model test cases.

    Subclasses must, in their ``setUp``, call ``super().setUp()``, create
    ``self.model1`` and ``self.model2``, and then call ``self._prepModels()``
    so that ``self.fitter1`` and ``self.fitter2`` exist.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: to test with HSC data (using chips far apart on the focal
            # plane), swap in the commented-out line below. Using it requires
            # having recently-processed singleFrame output in a rerun
            # directory in validation_data_hsc.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

    def setUp(self):
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # Absolute tolerance on positional errors of 10 micro-arcseconds,
        # expressed in degrees.
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("slot_CalibFlux")
        self.associations = lsst.jointcal.Associations()
        # Work around the fact that the testdata_jointcal catalogs were produced
        # before DM-13493, and so have a different definition of the interpolated flag.
        sourceSelectorConfig = astrometrySourceSelector.AstrometrySourceSelectorConfig()
        sourceSelectorConfig.badFlags.append("base_PixelFlags_flag_interpolated")
        sourceSelector = astrometrySourceSelector.AstrometrySourceSelectorTask(config=sourceSelectorConfig)

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        inputDir = os.path.join(self.dataDir, 'cfht')
        self.visits = [849375, 850587]
        self.ccds = [12, 13, 14, 21, 22, 23]
        # Ids that are deliberately absent from the data, for error-path tests.
        self.badVisit = -12345
        self.badCcd = 888

        self.butler = lsst.daf.persistence.Butler(inputDir)

        self.catalogs = []
        # Nothing in this setUp appends to ccdImageList: the ccdImages created
        # below are held by self.associations instead.
        self.ccdImageList = []
        for (visit, ccd) in itertools.product(self.visits, self.ccds):
            dataRef = self.butler.dataRef('calexp', visit=visit, ccd=ccd)

            src = dataRef.get("src", flags=lsst.afw.table.SOURCE_IO_NO_FOOTPRINTS, immediate=True)
            goodSrc = sourceSelector.run(src)
            # Need memory contiguity to do vector-like things on the sourceCat.
            goodSrc = goodSrc.sourceCat.copy(deep=True)

            visitInfo = dataRef.get('calexp_visitInfo')
            detector = dataRef.get('calexp_detector')
            ccdId = detector.getId()
            wcs = dataRef.get('calexp_wcs')
            bbox = dataRef.get('calexp_bbox')
            filt = dataRef.get('calexp_filterLabel')
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.CheckMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.CheckMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def CheckMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage, both
        post-initialization and after minimizing with the fitter.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # Now shift the models to non-default, but more reasonable, values by
        # minimizing first the per-visit distortions and then all distortions.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Evaluates a ``num`` x ``num`` grid of pixel positions spanning the
        detector bounding box and checks that:

        - pixel->sky (SkyWcs) followed by sky->tangent-plane agrees with the
          model's own pixel->tangent-plane mapping, and
        - pixel->sky->pixel through the SkyWcs round-trips to within
          ``inverseMaxDiff``.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        # Evaluate pixel->sky once and reuse it for both the forward check and
        # the sky->pixel round trip (the grid has num**2 points).
        spherePoints = skyWcs.pixelToSky(points)
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

239 

240 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Exercise `SimpleAstrometryModel`, which holds one mapping for every
    (ccd, visit) combination.
    """
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        # Model 1: lower-order polynomial.
        # NOTE(review): the third positional argument is True here and False
        # for model 2; confirm its meaning against SimpleAstrometryModel.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-5
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages, self.projectionHandler, True,
                                                              order=self.order1)

        # Model 2: higher-order polynomial.
        # The round-trip tolerance is loose because assertPairListsAlmostEqual
        # checks an absolute difference, so a tight bound gives spurious
        # failures. Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would tighten the SkyWcs inverse, but
        # could add unwanted wiggles for more general use.
        self.order2 = 5
        self.inverseMaxDiff2 = 2e-3
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages, self.projectionHandler, False,
                                                              order=self.order2)

        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage must carry exactly one polynomial's worth of parameters."""
        expected = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The model total is one polynomial per ccdImage."""
        actual = model.getTotalParameters()
        nCcdImages = len(self.associations.getCcdImageList())
        self.assertEqual(actual, nCcdImages*getNParametersPolynomial(order))

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

288 

289 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 5e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of fit parameters for one ccdImage.

        A 2-d polynomial of degree d has (d+1)(d+2)/2 coefficients, and each
        transform has two such polynomials (see `getNParametersPolynomial`).
        A ccdImage carries the parameters of its visit transform plus those of
        its chip transform. The chip transform is fixed for one chip, so pass
        ``chipOrder=None`` for that chip and only ``visitOrder`` contributes.

        Parameters
        ----------
        chipOrder : `int` or `None`
            Polynomial order of the chip transform, or `None` if it is fixed.
        visitOrder : `int`
            Polynomial order of the visit transform.

        Returns
        -------
        nParameters : `int`
            Expected number of parameters for the ccdImage.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = "ccdImage: %s, with chipOrder %s and visitOrder %s"%(ccdImage.getName(),
                                                                            chipOrder,
                                                                            visitOrder)
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed chip contributes no chip-transform parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.ccds) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.ccds:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on ccd: {}, but should not have.".format(model, ccd))

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = "No such chipId: {} among [{}]".format(self.badCcd, ", ".join(str(ccd) for ccd in self.ccds))
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on visit: {}, but should not have.".format(model, visit))

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = "No such visitId: {} among [{}]".format(self.badVisit,
                                                         ", ".join(str(v) for v in self.visits))
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters:
        # 2 visits * 20 params (visitOrder 3) + (6-1) sensors * 6 params (chipOrder 1).
        # NOTE: the base-class setUp leaves self.ccdImageList empty, so no
        # ccdImages are passed to validate() here; only the ndof argument varies.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))

408 

409 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pick up the standard `lsst.utils.tests.MemoryTestCase` checks for this module."""

412 

413 

def setup_module(module):
    """pytest-style module setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The module being set up (unused).
    """
    lsst.utils.tests.init()

416 

417 

# Allow running this file directly, outside pytest: initialize the LSST test
# utilities, then hand off to unittest's runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()