Coverage for tests/test_astrometryModel.py: 25%

217 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-03 13:26 +0000

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

26import itertools 

27import os 

28import numpy as np 

29import shutil 

30 

31import unittest 

32import lsst.utils.tests 

33 

34import lsst.afw.cameraGeom 

35import lsst.afw.geom 

36import lsst.afw.table 

37import lsst.afw.image 

38import lsst.geom 

39import lsst.log 

40 

41import lsst.jointcal 

42from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns, 

43 extract_detector_catalog_from_visit_catalog) 

44from lsst.jointcal import astrometryModels 

45from jointcalTestBase import importRepository 

46 

47 

def getNParametersPolynomial(order):
    """Return how many free parameters a 2-D polynomial transform of degree
    ``order`` carries.

    A single polynomial of degree d in (x, y) has (d+1)(d+2)/2 coefficients,
    and the transform needs one such polynomial per output coordinate, giving
    2 * (d+1)(d+2)/2 = (d+1)(d+2) in total.
    """
    highest = order + 2
    return highest*(order + 1)

51 

52 

class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Concrete subclasses must, in their own ``setUp`` (after calling
    ``super().setUp()``), create ``self.model1`` and ``self.model2`` and then
    call ``self._prepModels()`` so that ``self.fitter1``/``self.fitter2``
    exist for the shared ``testMakeSkyWcsModel*`` tests below.
    """
    @classmethod
    def setUpClass(cls):
        """Import the CFHT test repository (with its refcats) once per class.

        Raises
        ------
        unittest.SkipTest
            Raised if the ``testdata_jointcal`` package is not set up.
        """
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
        except LookupError:
            # The lookup failure itself is not interesting; just skip.
            raise unittest.SkipTest("testdata_jointcal not setup") from None

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        """Remove the repository created in `setUpClass`."""
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Load the per-detector source catalogs, build the jointcal
        Associations (ccdImages, fitted stars) and the projection handler
        shared by every model under test.
        """
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # Absolute tolerance on positional errors of 10 micro-arcseconds,
        # expressed in degrees.
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        self.ccdImageList = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns})) for v in self.visits}
        # `detectorId` is the integer id from self.detectors; `detector` below
        # is the camera geometry object fetched from the butler.
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filter', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # Regular num x num grid of pixel positions covering the detector.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        # Push the grid through the SkyWcs once and reuse the sky positions for
        # both the forward comparison and the pixel->sky->pixel round trip.
        spherePoints = skyWcs.pixelToSky(points)
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

253 

254 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Exercise `SimpleAstrometryModel`, which has one mapping per ccd per visit.

    ``model1`` uses a polynomial of order 3, ``model2`` one of order 5.
    """
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        # Lower-order model.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        # Higher-order model.
        self.order2 = 5
        # NOTE: because assertPairListsAlmostEqual tests an absolute
        # difference, we need this to be relatively high to avoid spurious
        # incorrect values.
        # Alternately, further increasing the order of the inverse polynomial
        # in astrometryTransform.toAstMap() can improve the quality of the
        # SkyWcs inverse, but that may not be wise for the more general use
        # case due to the inverse then having too many wiggles.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage carries one full polynomial of the given order."""
        expected = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The model's total is one polynomial per ccdImage."""
        total = model.getTotalParameters()
        nCcdImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, getNParametersPolynomial(order)*nCcdImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

302 

303 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Expected number of fit parameters for one ccdImage.

        A 2-D polynomial transform of order d has 2 * (d+1)(d+2)/2 parameters
        (one polynomial per output coordinate); the visit and chip transforms
        are summed. The chip transform of the fixed ccd is not fit, so pass
        ``chipOrder=None`` for it to count only the visit transform.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {chipOrder} "
                       f"and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed ccd contributes no chip-level parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters
        # (2 visits*20 params + (6-1) sensors*6 params, see testGetTotalParametersModel1).
        # NOTE(review): self.ccdImageList is initialized to [] in the base setUp
        # and never filled there, so only the degrees-of-freedom argument is
        # exercised here; consider self.associations.getCcdImageList().
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))

421 self.assertFalse(self.model1.validate(self.ccdImageList, -1)) 

422 

423 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the stock `lsst.utils.tests.MemoryTestCase` checks for this module;
    no additional test methods are needed here.
    """

426 

427 

def setup_module(module):
    """pytest module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The test module being set up (not used).
    """
    lsst.utils.tests.init()

430 

431 

if __name__ == "__main__":
    # Allow running this file directly, outside of pytest.
    lsst.utils.tests.init()
    unittest.main()