Coverage for tests/test_astrometryModel.py: 25%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

219 statements  

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

26import itertools 

27import os 

28import numpy as np 

29import shutil 

30 

31import unittest 

32import lsst.utils.tests 

33 

34import lsst.afw.cameraGeom 

35import lsst.afw.geom 

36import lsst.afw.table 

37import lsst.afw.image 

38import lsst.afw.image.utils 

39import lsst.geom 

40import lsst.log 

41 

42import lsst.jointcal 

43from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns, 

44 extract_detector_catalog_from_visit_catalog) 

45from lsst.jointcal import astrometryModels 

46from jointcalTestBase import importRepository 

47 

48 

def getNParametersPolynomial(order):
    """Return the number of parameters of a 2-d astrometric polynomial model.

    A polynomial in two variables of total degree ``d`` has (d+1)(d+2)/2
    coefficients, and the model carries one such polynomial per output
    coordinate, so the total is 2 * (d+1)(d+2)/2, i.e. (d+1)(d+2).

    Parameters
    ----------
    order : `int`
        Total degree ``d`` of the polynomial.

    Returns
    -------
    nParams : `int`
        Number of free parameters, (d+1)(d+2).
    """
    # The factor of 2 (x and y) cancels the /2 of the coefficient count.
    upper = order + 2
    lower = order + 1
    return lower * upper

52 

53 

class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Subclasses create ``self.model1`` and ``self.model2`` in their ``setUp``
    (after calling ``super().setUp()``), then call ``self._prepModels()`` to
    assign indices and build the fitters used by the shared tests below.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: the below is to facilitate testing with hsc test data,
            # using chips far apart on the focal plane. Using it requires
            # having recently-processed singleFrame output in a rerun
            # directory in validation_data_hsc.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        # Delete the repository directory produced by importRepository in setUpClass.
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Load the CFHT catalogs, build ccdImages and associations, and
        create the projection handler the models are constructed with.
        """
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config-gen3.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        # NOTE(review): this list is never populated in setUp, so callers such
        # as testValidate receive an empty list; confirm that is intended.
        self.ccdImageList = []
        table = make_schema_table()
        # The first visit's table is read to determine which columns to load.
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns}))
                    for v in self.visits}
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            # Keep the integer id (loop variable) distinct from the Detector object.
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filterLabel', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # Sample a num x num grid of pixel positions spanning the detector bbox.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        expects = []
        forwards = []
        spherePoints = skyWcs.pixelToSky(points)
        # Round trip pixel->sky->pixel, reusing the forward result computed above
        # instead of recomputing pixelToSky on the whole grid.
        inverses = skyWcs.skyToPixel(spherePoints)
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

263 

264 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit."""
    def setUp(self):
        super().setUp()
        # NOTE(review): the third positional argument (True here, False for
        # model2) presumably selects initialization from the input WCS;
        # confirm against the SimpleAstrometryModel signature.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(),
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        self.order2 = 5
        # NOTE: because assertPairListsAlmostEqual tests an absolute
        # difference, we need this to be relatively high to avoid spurious
        # incorrect values.
        # Alternately, further increasing the order of the inverse polynomial
        # in astrometryTransform.toAstMap() can improve the quality of the
        # SkyWcs inverse, but that may not be wise for the more general use
        # case due to the inverse then having too many wiggles.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(),
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Check that every ccdImage carries one polynomial of ``order`` per
        output coordinate, i.e. ``getNParametersPolynomial(order)`` parameters.
        """
        expect = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            result = model.getNpar(ccdImage)
            # Identify the failing ccdImage, as the constrained-model test does.
            failMsg = f"ccdImage: {ccdImage.getName()}, with order {order}"
            self.assertEqual(result, expect, msg=failMsg)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """Every ccdImage has its own independent polynomial, so the total is
        the per-ccdImage count times the number of ccdImages.
        """
        result = model.getTotalParameters()
        expect = getNParametersPolynomial(order)*len(self.associations.getCcdImageList())
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

312 

313 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Number of parameters per polynomial is (d+1)(d+2)/2, summed over
        the chip and visit polynomials, times 2 (one polynomial per output
        coordinate).
        The chip transform is fixed for one chip, so only visitOrder matters
        if chipOrder is None.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {chipOrder} "
                       f"and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed chip contributes no chip parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters
        # (2 visits*20 params + (6-1) sensors*6 params).
        # NOTE(review): self.ccdImageList is left empty by the base setUp, so
        # only the degrees-of-freedom argument is exercised here; confirm.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))

432 

433 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook this module into `lsst.utils.tests.MemoryTestCase`; all checks are inherited."""

436 

437 

def setup_module(module):
    """pytest module-level setup hook: call `lsst.utils.tests.init`.

    Parameters
    ----------
    module : `module`
        This test module; unused, but required by the hook signature.
    """
    lsst.utils.tests.init()

440 

441 

if __name__ == "__main__":
    # When run directly (not via pytest), perform the same initialization
    # that setup_module does, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()