Coverage for tests/test_astrometryModel.py: 26%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

220 statements  

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

import itertools
import os
import shutil
import unittest

import numpy as np

import lsst.utils.tests

import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.butler
import lsst.daf.persistence
import lsst.geom
import lsst.log
import lsst.pex.exceptions

import lsst.jointcal
from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
                                    extract_detector_catalog_from_visit_catalog)
from lsst.jointcal import astrometryModels
from jointcalTestBase import importRepository

48 

49 

def getNParametersPolynomial(order):
    """Return the number of free parameters of a 2-d astrometry polynomial.

    A 2-d polynomial of total degree ``order`` has (order+1)(order+2)/2
    coefficients. The model carries one such polynomial per output
    coordinate (x and y), so the total is 2 * (d+1)(d+2)/2 = (d+1)(d+2).

    Parameters
    ----------
    order : `int`
        Total degree of the polynomial.

    Returns
    -------
    nParameters : `int`
        Number of parameters for both coordinates combined.
    """
    lowFactor = order + 1
    return lowFactor * (lowFactor + 1)

53 

54 

class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Subclasses must create ``self.model1`` and ``self.model2`` in their
    ``setUp`` and then call ``self._prepModels()``, which creates
    ``self.fitter1`` and ``self.fitter2``. Subclasses may also override
    ``self.inverseMaxDiff1`` / ``self.inverseMaxDiff2``.
    """
    @classmethod
    def setUpClass(cls):
        """Locate the test data and import it into one butler repository
        shared by every test (removed again in `tearDownClass`).

        Raises
        ------
        unittest.SkipTest
            Raised if the ``testdata_jointcal`` package is not set up.
        """
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: the below is to facilitate testing with hsc test data,
            # using chips far apart on the focal plane. Using it requires
            # having recently-processed singleFrame output in a rerun
            # directory in validation_data_hsc.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        """Remove the shared repository created in `setUpClass`."""
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Load the CFHT source catalogs, create a CcdImage for every
        (visit, detector) pair, and associate the sources into fitted stars.
        """
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Subclasses may replace either one for models that don't have as
        # accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config-gen3.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        # NOTE(review): nothing in this class appends to ccdImageList; the
        # CcdImages are created inside self.associations instead.
        self.ccdImageList = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns}))
                    for v in self.visits}
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            # Keep the integer detector id (loop variable) distinct from the
            # Detector object read back from the butler.
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filterLabel', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        """makeSkyWcs on model1 must agree with its own mapping."""
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        """makeSkyWcs on model2 must agree with its own mapping."""
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        A 200x200 grid of pixel positions spanning the detector bbox is
        pushed through both the SkyWcs (pixel -> sky -> tangent plane) and
        the model's own mapping (pixel -> tangent plane); the results must
        agree. The SkyWcs round trip (pixel -> sky -> pixel) must also
        recover the input pixels to within ``inverseMaxDiff``.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        spherePoints = skyWcs.pixelToSky(points)
        # Round trip pixel->sky->pixel; reuse the forward result instead of
        # recomputing pixelToSky on the same points.
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

264 

265 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit."""

    def setUp(self):
        """Create two SimpleAstrometryModels of polynomial order 3 and 5."""
        super().setUp()
        ccdImageList = self.associations.getCcdImageList()

        # NOTE(review): the third positional argument (True for model1, False
        # for model2) is presumably `initFromWcs`; confirm against the
        # SimpleAstrometryModel signature.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(
            ccdImageList, self.projectionHandler, True, order=self.order1)

        self.order2 = 5
        # assertPairListsAlmostEqual tests an *absolute* difference, so this
        # tolerance has to be fairly loose to avoid spurious failures.
        # Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would tighten the SkyWcs inverse,
        # but could add too many wiggles for general use.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(
            ccdImageList, self.projectionHandler, False, order=self.order2)

        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every CcdImage should carry exactly one polynomial of ``order``."""
        expected = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial of ``order`` per CcdImage."""
        nCcdImages = len(self.associations.getCcdImageList())
        total = model.getTotalParameters()
        self.assertEqual(total, getNParametersPolynomial(order)*nCcdImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

313 

314 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        """Create two ConstrainedAstrometryModels with different chip and
        visit polynomial orders.
        """
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of parameters for one CcdImage.

        Number of parameters per polynomial is (d+1)(d+2)/2, times 2
        polynomials (one per output coordinate), summed over the visit and
        chip polynomials. The chip transform is fixed for one chip, so only
        ``visitOrder`` matters if ``chipOrder`` is None.

        Parameters
        ----------
        chipOrder : `int` or `None`
            Order of the chip polynomial, or None for the fixed chip.
        visitOrder : `int`
            Order of the visit polynomial.

        Returns
        -------
        params : `int`
            Expected parameter count.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        """Check ``model.getNpar`` for every CcdImage, treating the fixed
        chip as having no chip parameters.
        """
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, "
                       f"with chipOrder {chipOrder} and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed chip's transform is not fit, so it contributes no chip parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        """Total = chip polynomials for all but the fixed sensor, plus one
        visit polynomial per visit.
        """
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        """Every known ccd must have a chip transform; an unknown one must
        raise InvalidParameterError with a message listing the valid ids.
        """
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        """Every known visit must have a visit transform; an unknown one must
        raise InvalidParameterError with a message listing the valid ids.
        """
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters
        # (2 visits*20 params + (6-1) sensors*6 params).
        # NOTE(review): self.ccdImageList is set to an empty list in the base
        # class setUp and never filled, so validate() receives no CcdImages
        # here; confirm whether self.associations.getCcdImageList() was intended.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))

433 

434 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook in `lsst.utils.tests.MemoryTestCase`; all behavior is inherited."""

437 

438 

def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The module being set up (not used).
    """
    lsst.utils.tests.init()

441 

442 

if __name__ == "__main__":
    # Initialize the lsst.utils.tests machinery, then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()