Coverage for tests/test_astrometryModel.py: 23%

216 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 10:17 +0000

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

26import itertools 

27import os 

28import numpy as np 

29import shutil 

30 

31import unittest 

32import lsst.utils.tests 

33 

34import lsst.afw.cameraGeom 

35import lsst.afw.geom 

36import lsst.afw.table 

37import lsst.afw.image 

38import lsst.geom 

39import lsst.log 

40 

41import lsst.jointcal 

42from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns, 

43 extract_detector_catalog_from_visit_catalog) 

44from jointcalTestBase import importRepository 

45 

46 

def getNParametersPolynomial(order):
    """Return the number of free parameters of a 2-d astrometry polynomial.

    A polynomial of degree ``d`` in (x, y) has (d+1)(d+2)/2 coefficients, and
    the astrometric model needs one such polynomial per output coordinate,
    giving 2 * (d+1)(d+2)/2 = (d+1)(d+2) parameters in total.
    """
    nCoeffsPerAxisLine = order + 1
    return nCoeffsPerAxisLine * (nCoeffsPerAxisLine + 1)

50 

51 

class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Concrete subclasses create ``self.model1`` and ``self.model2`` in their
    own ``setUp`` (after calling ``super().setUp()``). They then call
    `_prepModels` to assign indices and build the fitters used by the
    ``testMakeSkyWcsModel*`` tests.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # Absolute tolerance on positional errors of 10 micro-arcseconds,
        # expressed in degrees.
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        # NOTE(review): lsst.daf.butler is not imported explicitly by this
        # module; it is presumably loaded as a side effect of another lsst
        # import. Consider importing it explicitly.
        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        # NOTE(review): this list is never populated in this class, so any
        # test passing it to ``model.validate`` passes an empty list.
        # Confirm this is intended.
        self.ccdImageList = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                    visit=v,
                                                    parameters={'columns': columns})) for v in self.visits}
        for (visit, detectorId) in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            # The Detector object from the butler. It is kept separate from the
            # integer ``detectorId`` loop variable so the two are not confused.
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filter', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # A num x num grid of pixel positions spanning the detector bbox.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        # Transform pixel->sky once and reuse the result for both the forward
        # comparison and the sky->pixel round trip (previously it was
        # recomputed for the round trip).
        spherePoints = skyWcs.pixelToSky(points)
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.BaseStar(spherePoint.getLongitude().asDegrees(),
                                           spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

251 

252 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit."""

    def setUp(self):
        super().setUp()
        ccdImageList = self.associations.getCcdImageList()

        # NOTE(review): the third positional constructor argument is True for
        # model1 and False for model2. Presumably it controls whether the model
        # is initialized from the input WCS; confirm against the
        # SimpleAstrometryModel signature.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = lsst.jointcal.SimpleAstrometryModel(ccdImageList, self.projectionHandler,
                                                          True, order=self.order1)

        self.order2 = 5
        # NOTE: assertPairListsAlmostEqual checks an absolute difference, so
        # this tolerance must be fairly loose to avoid spurious failures.
        # Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would also tighten the SkyWcs
        # inverse. However, a higher-order inverse may wiggle too much for
        # general use, so the tolerance is relaxed here instead.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = lsst.jointcal.SimpleAstrometryModel(ccdImageList, self.projectionHandler,
                                                          False, order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage must carry exactly one polynomial's worth of parameters."""
        expected = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial per ccdImage, with nothing held fixed."""
        nCcdImages = len(self.associations.getCcdImageList())
        self.assertEqual(model.getTotalParameters(), getNParametersPolynomial(order)*nCcdImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

300 

301 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = lsst.jointcal.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                               self.projectionHandler,
                                                               chipOrder=self.chipOrder1,
                                                               visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = lsst.jointcal.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                               self.projectionHandler,
                                                               chipOrder=self.chipOrder2,
                                                               visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of fitted parameters for one ccdImage.

        A degree-d polynomial has (d+1)(d+2)/2 coefficients per output
        coordinate, i.e. (d+1)(d+2) for both coordinates (see
        `getNParametersPolynomial`). A ccdImage's mapping contributes its
        visit polynomial plus its chip polynomial. The chip transform is held
        fixed for one chip; pass ``chipOrder=None`` for that chip so that only
        ``visitOrder`` contributes.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {chipOrder} "
                       f"and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed ccd has no fitted chip polynomial.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters:
        # 2 visits * 20 params (visitOrder=3) + (6-1) sensors * 6 params (chipOrder=1).
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))

420 

421 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Module-level memory test case.

    All behavior is inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`; this subclass only exists so the test
    runner collects it for this module.
    """

424 

425 

def setup_module(module):
    """Module-level setup hook (pytest xunit-style ``setup_module`` name).

    Initializes the lsst test utilities once before the tests in this module
    run; ``module`` is accepted for the hook signature but not used.
    """
    lsst.utils.tests.init()

428 

429 

if __name__ == "__main__":
    # Running the file directly: initialize the lsst test utilities, then
    # hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()