Coverage for tests/test_astrometryModel.py: 22%

217 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-07 11:35 +0000

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

26import itertools 

27import os 

28import numpy as np 

29import shutil 

30 

31import unittest 

32import lsst.utils.tests 

33 

34import lsst.afw.cameraGeom 

35import lsst.afw.geom 

36import lsst.afw.table 

37import lsst.afw.image 

38import lsst.geom 

39import lsst.log 

40 

41import lsst.jointcal 

42from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns, 

43 extract_detector_catalog_from_visit_catalog) 

44from lsst.jointcal import astrometryModels 

45from jointcalTestBase import importRepository 

46 

47 

48def getNParametersPolynomial(order): 

49 """Number of parameters in an astrometry polynomial model is 2 * (d+1)(d+2)/2.""" 

50 return (order + 1)*(order + 2) 

51 

52 

53class AstrometryModelTestBase: 

54 """Test the jointcal AstrometryModel concrete classes, using CFHT data. 

55 """ 

56 @classmethod 

57 def setUpClass(cls): 

58 try: 

59 cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal') 

60 except LookupError: 

61 raise unittest.SkipTest("testdata_jointcal not setup") 

62 

63 refcatPath = os.path.join(cls.dataDir, "cfht") 

64 refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"), 

65 "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"), 

66 "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")} 

67 # Share one repo, since none of these tests write anything. 

68 cls.repopath = importRepository("lsst.obs.cfht.MegaPrime", 

69 os.path.join(cls.dataDir, 'cfht/repo/'), 

70 os.path.join(cls.dataDir, 'cfht/exports.yaml'), 

71 refcats=refcats, 

72 refcatPath=refcatPath) 

73 

74 @classmethod 

75 def tearDownClass(cls): 

76 shutil.rmtree(cls.repopath, ignore_errors=True) 

77 

78 def setUp(self): 

79 np.random.seed(200) 

80 

81 # DEBUG messages can help track down failures. 

82 logger = lsst.log.Log.getLogger('lsst.jointcal') 

83 logger.setLevel(lsst.log.DEBUG) 

84 

85 # Append `msg` arguments to assert failures. 

86 self.longMessage = True 

87 # absolute tolerance on positional errors of 10 micro-arcsecond 

88 self.atol = 10.0 / (60 * 60 * 1e6) 

89 

90 # Maximum difference (see assertPairsAlmostEqual) for round-trip 

91 # testing of the inverse for models 1 (simpler) and 2 (more complex). 

92 # Replace either one for models that don't have as accurate an inverse. 

93 self.inverseMaxDiff1 = 1e-5 

94 self.inverseMaxDiff2 = 1e-5 

95 

96 self.firstIndex = 0 # for assignIndices 

97 matchCut = 2.0 # arcseconds 

98 minMeasurements = 2 # accept all star pairs. 

99 

100 jointcalControl = lsst.jointcal.JointcalControl("flux") 

101 self.associations = lsst.jointcal.Associations() 

102 config = lsst.jointcal.JointcalConfig() 

103 config.load(os.path.join(os.path.dirname(__file__), "config/config.py")) 

104 sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science']) 

105 

106 # jointcal's cfht test data has 6 ccds and 2 visits. 

107 self.visits = [849375, 850587] 

108 self.detectors = [12, 13, 14, 21, 22, 23] 

109 self.badVisit = -12345 

110 self.badCcd = 888 

111 

112 butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime") 

113 

114 self.catalogs = [] 

115 self.ccdImageList = [] 

116 table = make_schema_table() 

117 inColumns = butler.get("sourceTable_visit", visit=self.visits[0]) 

118 columns, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector) 

119 catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit', 

120 visit=v, 

121 parameters={'columns': columns})) for v in self.visits} 

122 for (visit, detector) in itertools.product(self.visits, self.detectors): 

123 goodSrc = extract_detector_catalog_from_visit_catalog(table, 

124 catalogs[visit].sourceCat, 

125 detector, 

126 ixxColumns, 

127 config.sourceFluxType, 

128 logger) 

129 dataId = {"detector": detector, "visit": visit} 

130 visitInfo = butler.get('calexp.visitInfo', dataId=dataId) 

131 detector = butler.get('calexp.detector', dataId=dataId) 

132 ccdId = detector.getId() 

133 wcs = butler.get('calexp.wcs', dataId=dataId) 

134 bbox = butler.get('calexp.bbox', dataId=dataId) 

135 filt = butler.get('calexp.filter', dataId=dataId) 

136 filterName = filt.physicalLabel 

137 photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0) 

138 

139 self.catalogs.append(goodSrc) 

140 self.associations.createCcdImage(goodSrc, 

141 wcs, 

142 visitInfo, 

143 bbox, 

144 filterName, 

145 photoCalib, 

146 detector, 

147 visit, 

148 ccdId, 

149 jointcalControl) 

150 

151 # Have to set the common tangent point so projectionHandler can use skyToCTP. 

152 self.associations.computeCommonTangentPoint() 

153 

154 self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList()) 

155 

156 self.associations.associateCatalogs(matchCut) 

157 self.associations.prepareFittedStars(minMeasurements) 

158 self.associations.deprojectFittedStars() 

159 

160 def _prepModels(self): 

161 """Call this after model1 and model2 are created, to call assignIndices, 

162 and instantiate the fitters. 

163 """ 

164 posError = 0.02 # in pixels 

165 # have to call this once or offsetParams will fail because the transform indices aren't defined 

166 self.model1.assignIndices("Distortions", self.firstIndex) 

167 self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError) 

168 

169 # have to call this once or offsetParams will fail because the transform indices aren't defined 

170 self.model2.assignIndices("Distortions", self.firstIndex) 

171 self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError) 

172 

173 def testMakeSkyWcsModel1(self): 

174 self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1) 

175 

176 def testMakeSkyWcsModel2(self): 

177 self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2) 

178 

179 def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff): 

180 """Test producing a SkyWcs on a model for every ccdImage, 

181 both post-initialization and after one fitting step. 

182 

183 Parameters 

184 ---------- 

185 model : `lsst.jointcal.AstrometryModel` 

186 The model to test. 

187 fitter : `lsst.jointcal.FitterBase` 

188 The fitter to use to step the model to test with new (reasonable) parameters. 

189 inverseMaxDiff : `float` 

190 Required accuracy on inverse transform. 

191 See `lsst.afw.geom.utils.assertPairsAlmostEqual`. 

192 

193 """ 

194 # first test on as-initialized models 

195 for ccdImage in self.associations.getCcdImageList(): 

196 self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff) 

197 

198 # now shift the models to non-default, but more reasonable, values by taking one fitting step. 

199 fitter.minimize("DistortionsVisit") 

200 fitter.minimize("Distortions") 

201 for ccdImage in self.associations.getCcdImageList(): 

202 self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff) 

203 

204 def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff): 

205 """Test converting the model of one ccdImage to a SkyWcs by comparing 

206 to the original transform at the tangent plane. 

207 

208 Parameters 

209 ---------- 

210 model : `lsst.jointcal.AstrometryModel` 

211 The model to test. 

212 ccdImage : `lsst.jointcal.CcdImage` 

213 The ccdImage to extract from the model and test. 

214 inverseMaxDiff : `float` 

215 Required accuracy on inverse transform. 

216 See `lsst.afw.geom.utils.assertPairsAlmostEqual`. 

217 """ 

218 skyWcs = model.makeSkyWcs(ccdImage) 

219 skyToTangentPlane = model.getSkyToTangentPlane(ccdImage) 

220 mapping = model.getMapping(ccdImage) 

221 

222 bbox = ccdImage.getDetector().getBBox() 

223 num = 200 

224 xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num) 

225 yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num) 

226 points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)] 

227 

228 expects = [] 

229 forwards = [] 

230 inverses = [] 

231 spherePoints = skyWcs.pixelToSky(points) 

232 inverses = skyWcs.skyToPixel(skyWcs.pixelToSky(points)) 

233 for point, spherePoint in zip(points, spherePoints): 

234 # TODO: Fix these "Point"s once DM-4044 is done. 

235 

236 # jointcal's pixel->tangent-plane mapping 

237 star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0) 

238 tpExpect = mapping.transformPosAndErrors(star) 

239 expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y)) 

240 

241 # skywcs takes pixel->sky, and we then have to go sky->tangent-plane 

242 onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(), 

243 spherePoint.getLatitude().asDegrees(), 0, 0) 

244 result = skyToTangentPlane.apply(onSky) 

245 forwards.append(lsst.geom.Point2D(result.x, result.y)) 

246 

247 self.assertPairListsAlmostEqual(forwards, expects) 

248 # NOTE: assertPairListsAlmostEqual() compares absolute, not relative, 

249 # values so the points along the ccd edge may exceed maxDiff while still 

250 # being "close enough": set `inverseMaxDiff` accordingly. 

251 self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff) 

252 

253 

254class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase): 

255 """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit.""" 

256 def setUp(self): 

257 super().setUp() 

258 self.order1 = 3 

259 self.inverseMaxDiff1 = 2e-4 

260 self.model1 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(), 

261 self.projectionHandler, 

262 True, 

263 order=self.order1) 

264 

265 self.order2 = 5 

266 # NOTE: because assertPairListsAlmostEqual tests an absolute 

267 # difference, we need this to be relatively high to avoid spurious 

268 # incorrect values. 

269 # Alternately, further increasing the order of the inverse polynomial 

270 # in astrometryTransform.toAstMap() can improve the quality of the 

271 # SkyWcs inverse, but that may not be wise for the more general use 

272 # case due to the inverse then having too many wiggles. 

273 self.inverseMaxDiff2 = 2e-2 

274 self.model2 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(), 

275 self.projectionHandler, 

276 False, 

277 order=self.order2) 

278 self._prepModels() 

279 

280 def _testGetNpar(self, model, order): 

281 for ccdImage in self.associations.getCcdImageList(): 

282 result = model.getNpar(ccdImage) 

283 self.assertEqual(result, getNParametersPolynomial(order)) 

284 

285 def testGetNpar1(self): 

286 self._testGetNpar(self.model1, self.order1) 

287 

288 def testGetNpar2(self): 

289 self._testGetNpar(self.model2, self.order2) 

290 

291 def _testGetTotalParameters(self, model, order): 

292 result = model.getTotalParameters() 

293 expect = getNParametersPolynomial(order)*len(self.associations.getCcdImageList()) 

294 self.assertEqual(result, expect) 

295 

296 def testGetTotalParametersModel1(self): 

297 self._testGetTotalParameters(self.model1, self.order1) 

298 

299 def testGetTotalParametersModel2(self): 

300 self._testGetTotalParameters(self.model2, self.order2) 

301 

302 

303class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase): 

304 """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one 

305 mapping per visit. 

306 """ 

307 def setUp(self): 

308 super().setUp() 

309 self.visitOrder1 = 3 

310 self.chipOrder1 = 1 

311 self.inverseMaxDiff1 = 1e-5 

312 self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(), 

313 self.projectionHandler, 

314 chipOrder=self.chipOrder1, 

315 visitOrder=self.visitOrder1) 

316 

317 self.visitOrder2 = 5 

318 self.chipOrder2 = 2 

319 self.inverseMaxDiff2 = 8e-5 

320 self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(), 

321 self.projectionHandler, 

322 chipOrder=self.chipOrder2, 

323 visitOrder=self.visitOrder2) 

324 self._prepModels() 

325 

326 # 22 is closest to the center of the focal plane in this data, so it is not fit. 

327 self.fixedCcd = 22 

328 

329 def _polyParams(self, chipOrder, visitOrder): 

330 """Number of parameters per polynomial is (d+1)(d+2)/2, summed over 

331 polynomials, times 2 polynomials per dimension. 

332 The chip transform is fixed for one chip, so only visitOrder matters 

333 if chipOrder is None. 

334 """ 

335 params = getNParametersPolynomial(visitOrder) 

336 if chipOrder is not None: 

337 params += getNParametersPolynomial(chipOrder) 

338 return params 

339 

340 def _testGetNpar(self, model, chipOrder, visitOrder): 

341 def checkParams(ccdImage, model, chipOrder, visitOrder): 

342 result = model.getNpar(ccdImage) 

343 failMsg = "ccdImage: %s, with chipOrder %s and visitOrder %s"%(ccdImage.getName(), 

344 chipOrder, 

345 visitOrder) 

346 self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg) 

347 

348 for ccdImage in self.associations.getCcdImageList(): 

349 realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder 

350 checkParams(ccdImage, model, realChipOrder, visitOrder) 

351 

352 def testGetNpar1(self): 

353 self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1) 

354 

355 def testGetNpar2(self): 

356 self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2) 

357 

358 def _testGetTotalParameters(self, model, chipOrder, visitOrder): 

359 result = model.getTotalParameters() 

360 # one sensor is held fixed, hence len(ccds)-1 

361 expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \ 

362 getNParametersPolynomial(visitOrder)*len(self.visits) 

363 self.assertEqual(result, expect) 

364 

365 def testGetTotalParametersModel1(self): 

366 self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1) 

367 

368 def testGetTotalParametersModel2(self): 

369 self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2) 

370 

371 def checkGetChipTransform(self, model): 

372 # Check valid ccds 

373 for ccd in self.detectors: 

374 try: 

375 model.getChipTransform(ccd) 

376 except lsst.pex.exceptions.wrappers.InvalidParameterError: 

377 self.fail("model: {} raised on ccd: {}, but should not have.".format(model, ccd)) 

378 

379 # Check an invalid ccd 

380 with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm: 

381 model.getChipTransform(self.badCcd) 

382 errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]" 

383 self.assertIn(errMsg, str(cm.exception)) 

384 

385 def testGetChipTransform(self): 

386 """getChipTransform should get each known transform, and raise with an 

387 appropriate message otherwise. 

388 """ 

389 self.checkGetChipTransform(self.model1) 

390 self.checkGetChipTransform(self.model2) 

391 

392 def checkGetVisitTransform(self, model): 

393 # Check valid visits 

394 for visit in self.visits: 

395 try: 

396 model.getVisitTransform(visit) 

397 except lsst.pex.exceptions.wrappers.InvalidParameterError: 

398 self.fail("model: {} raised on visit: {}, but should not have.".format(model, visit)) 

399 

400 # Check an invalid visit 

401 with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm: 

402 model.getVisitTransform(self.badVisit) 

403 errMsg = "No such visitId: {} among [{}]".format(self.badVisit, 

404 ", ".join(str(v) for v in self.visits)) 

405 self.assertIn(errMsg, str(cm.exception)) 

406 

407 def testGetVisitTransform(self): 

408 """getVisitTransform should get each known transform, and raise with an 

409 appropriate message otherwise. 

410 """ 

411 self.checkGetVisitTransform(self.model1) 

412 self.checkGetVisitTransform(self.model2) 

413 

414 def testValidate(self): 

415 """Test that invalid models fail validate(), and that valid ones pass. 

416 """ 

417 # We need at least 0 degrees of freedom (data - parameters) for the model to be valid. 

418 # Note: model1 has 70 total parameters (2 visits*20 params + (6-1) sensors*6 params) 

419 self.assertTrue(self.model1.validate(self.ccdImageList, 0)) 

420 self.assertFalse(self.model1.validate(self.ccdImageList, -1)) 

421 

422 

423class MemoryTester(lsst.utils.tests.MemoryTestCase): 

424 pass 

425 

426 

427def setup_module(module): 

428 lsst.utils.tests.init() 

429 

430 

431if __name__ == "__main__": [431 ↛ 432: line 431 didn't jump to line 432, because the condition on line 431 was never true]

432 lsst.utils.tests.init() 

433 unittest.main()