Coverage for tests/test_astrometryModel.py: 25%

219 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-05-26 12:38 +0000

# This file is part of jointcal.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

import itertools
import os
import shutil
import unittest

import numpy as np

import lsst.utils.tests
import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.butler
import lsst.geom
import lsst.log
import lsst.pex.exceptions

import lsst.jointcal
from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
                                   extract_detector_catalog_from_visit_catalog)
from lsst.jointcal import astrometryModels
from jointcalTestBase import importRepository

47 

48 

def getNParametersPolynomial(order):
    """Return the number of free parameters of a 2-D astrometric polynomial.

    A 2-D polynomial of total degree ``order`` has ``(order+1)(order+2)/2``
    coefficients, and an astrometric transform needs one such polynomial per
    output coordinate (x and y), giving ``(order+1)(order+2)`` in total.

    Parameters
    ----------
    order : `int`
        Total degree of the polynomial.

    Returns
    -------
    nParams : `int`
        Number of coefficients for both output coordinates combined.
    """
    lowFactor = order + 1
    highFactor = order + 2
    return lowFactor*highFactor

52 

53 

class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    This is a mixin: a concrete subclass must also derive from a TestCase,
    and its ``setUp`` must call ``super().setUp()``, create ``self.model1``
    and ``self.model2``, and then call ``self._prepModels()``.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup") from None

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Load the CFHT catalogs, build one ccdImage per (visit, detector),
        and associate/deproject the fitted stars.
        """
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # Absolute tolerance on positional errors of 10 micro-arcseconds, in degrees.
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Subclasses replace either one for models that don't have as
        # accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        table = make_schema_table()
        # The first visit's table is used here only as input to
        # get_sourceTable_visit_columns, to pick the columns loaded below.
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns})) for v in self.visits}
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            # Keep the integer id (loop variable) and the detector object
            # under distinct names instead of rebinding the loop variable.
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filterLabel', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # All ccdImages have now been created; expose the full list so tests
        # (e.g. model validation) operate on real data rather than an empty list.
        self.ccdImageList = self.associations.getCcdImageList()

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # A num x num grid of pixel positions spanning the detector bbox.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        spherePoints = skyWcs.pixelToSky(points)
        # Round trip pixel->sky->pixel, reusing the sky positions computed above.
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

258 

259 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Exercise `SimpleAstrometryModel`: one independent mapping per ccd per visit.

    ``model1`` uses a 3rd-order polynomial and ``model2`` a 5th-order one.
    """
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        # NOTE(review): the positional boolean (True here, False for model2) is
        # forwarded to the SimpleAstrometryModel binding; presumably it selects
        # initialization from the input WCS -- confirm against the constructor.
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        self.order2 = 5
        # assertPairListsAlmostEqual checks an *absolute* difference, so the
        # tolerance must be fairly loose here to avoid spurious failures.
        # Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would tighten the SkyWcs inverse,
        # but may add unwanted wiggles in the general use case.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage must carry exactly one polynomial's worth of parameters."""
        expected = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial per ccdImage, with nothing held fixed."""
        total = model.getTotalParameters()
        nCcdImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, getNParametersPolynomial(order)*nCcdImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

307 

308 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of parameters for one ccdImage.

        A 2-D polynomial of degree d has (d+1)(d+2)/2 coefficients per output
        coordinate, so (d+1)(d+2) for both coordinates. A ccdImage's mapping
        combines its visit polynomial and its chip polynomial. The chip
        transform is held fixed for one chip; pass ``chipOrder=None`` for that
        chip so that only ``visitOrder`` contributes.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = "ccdImage: %s, with chipOrder %s and visitOrder %s"%(ccdImage.getName(),
                                                                           chipOrder,
                                                                           visitOrder)
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed chip contributes no chip-level parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on ccd: {}, but should not have.".format(model, ccd))

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on visit: {}, but should not have.".format(model, visit))

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # Validate against the real ccdImages the model was built from, not
        # an empty list.
        ccdImageList = self.associations.getCcdImageList()
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # NOTE: model1 has 70 total parameters:
        # 2 visits * 20 params (visitOrder=3) + (6-1) sensors * 6 params (chipOrder=1).
        self.assertTrue(self.model1.validate(ccdImageList, 0))
        self.assertFalse(self.model1.validate(ccdImageList, -1))

426 self.assertFalse(self.model1.validate(self.ccdImageList, -1)) 

427 

428 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the inherited `lsst.utils.tests.MemoryTestCase` checks unchanged."""

431 

432 

def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests` once.

    Parameters
    ----------
    module : `module`
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

435 

436 

if __name__ == "__main__":
    # Allow running this file directly, outside of pytest.
    lsst.utils.tests.init()
    unittest.main()