Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of jointcal. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Tests of astrometryModels (simple, constrained). 

23 

24Includes tests of producing a Wcs from a model. 

25""" 

import itertools
import os
import unittest

import numpy as np

import lsst.utils.tests

import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.persistence
import lsst.geom
import lsst.jointcal
import lsst.pex.exceptions
from lsst.jointcal import astrometryModels
from lsst.meas.algorithms import astrometrySourceSelector

43 

44 

def getNParametersPolynomial(order):
    """Return the parameter count of an astrometry polynomial of degree ``order``.

    A 2-d polynomial of total degree d has (d+1)(d+2)/2 coefficients, and the
    model needs one such polynomial per output coordinate, so the total is
    2 * (d+1)(d+2)/2 == (d+1)(d+2).
    """
    degreePlusOne = order + 1
    degreePlusTwo = order + 2
    return degreePlusOne*degreePlusTwo

48 

49 

class AstrometryModelTestBase:
    """Common fixtures and SkyWcs checks shared by the astrometry model tests.

    Used as a mixin together with `lsst.utils.tests.TestCase`: subclasses call
    ``super().setUp()``, create ``self.model1`` and ``self.model2``, and then
    call ``self._prepModels()``.
    """
    @classmethod
    def setUpClass(cls):
        """Locate the testdata_jointcal package, or skip the whole test case.

        Raises
        ------
        unittest.SkipTest
            Raised if ``testdata_jointcal`` is not set up.
        """
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: the below is to facilitate testing with hsc test data,
            # using chips far apart on the focal plane. See the note in setUp()
            # below for details. Using it requires having recently-processed
            # singleFrame output in a rerun directory in validation_data_hsc.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

    def setUp(self):
        """Load the cfht test catalogs and build the jointcal associations.

        One CcdImage is created per (visit, ccd) pair. After that, the common
        tangent point is computed, a one-tangent-point-per-visit projection
        handler is built, and the fitted stars are associated, prepared and
        deprojected.
        """
        np.random.seed(200)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("slot_CalibFlux")
        self.associations = lsst.jointcal.Associations()
        # Work around the fact that the testdata_jointcal catalogs were produced
        # before DM-13493, and so have a different definition of the interpolated flag.
        sourceSelectorConfig = astrometrySourceSelector.AstrometrySourceSelectorConfig()
        sourceSelectorConfig.badFlags.append("base_PixelFlags_flag_interpolated")
        sourceSelector = astrometrySourceSelector.AstrometrySourceSelectorTask(config=sourceSelectorConfig)

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        inputDir = os.path.join(self.dataDir, 'cfht')
        self.visits = [849375, 850587]
        self.ccds = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        # NOTE: the below block is to facilitate testing with validation_data_hsc.
        # NOTE: You need to have recently-processed singleFrame output available in inputDir.
        # ccd 50 is near the center, 3 is in the SW corner, 62 is on the East side.
        # ccds 101 and 103 are rotated by 90deg compared with the above.
        # inputDir = os.path.join(self.dataDir, 'DATA/rerun/20160805')
        # self.visits = [903982, 904828]  # Only need two visits for this test.
        # self.ccds = [50, 3, 62, 101, 103]

        self.butler = lsst.daf.persistence.Butler(inputDir)

        self.catalogs = []
        # NOTE(review): nothing in this setUp appends to ccdImageList, so it
        # stays empty. The CcdImages built below are reachable through
        # self.associations.getCcdImageList() instead.
        self.ccdImageList = []
        for (visit, ccd) in itertools.product(self.visits, self.ccds):
            dataRef = self.butler.dataRef('calexp', visit=visit, ccd=ccd)

            src = dataRef.get("src", flags=lsst.afw.table.SOURCE_IO_NO_FOOTPRINTS, immediate=True)
            goodSrc = sourceSelector.run(src)
            # Need memory contiguity to do vector-like things on the sourceCat.
            goodSrc = goodSrc.sourceCat.copy(deep=True)

            visitInfo = dataRef.get('calexp_visitInfo')
            detector = dataRef.get('calexp_detector')
            ccdId = detector.getId()
            wcs = dataRef.get('calexp_wcs')
            bbox = dataRef.get('calexp_bbox')
            filt = dataRef.get('calexp_filter')
            filterName = filt.getName()
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.

        Sets ``self.fitter1`` and ``self.fitter2``.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.CheckMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.CheckMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def CheckMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        A 200x200 grid spanning the detector bounding box is pushed through
        both the model's pixel->tangent-plane mapping and SkyWcs
        pixel->sky->tangent-plane. The two results are compared, and the
        SkyWcs sky->pixel inverse is round-tripped back to the input pixels.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        # Evaluate pixel->sky once and reuse it both for the forward
        # comparison and for the sky->pixel round trip.
        spherePoints = skyWcs.pixelToSky(points)
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)

239 

240 

class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit."""
    def setUp(self):
        super().setUp()
        # Polynomial orders and inverse tolerances for the two models under test.
        self.order1, self.inverseMaxDiff1 = 3, 2e-5
        self.order2, self.inverseMaxDiff2 = 5, 5e-4

        # The third positional argument differs between the two models
        # (True for model1, False for model2).
        # NOTE(review): presumably the "initialize from WCS" flag; confirm
        # against the SimpleAstrometryModel binding.
        self.model1 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(),
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)
        self.model2 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(),
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage should carry one full polynomial of degree ``order``."""
        expected = getNParametersPolynomial(order)
        for image in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(image), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial's parameters per ccdImage."""
        total = model.getTotalParameters()
        nImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, getNParametersPolynomial(order)*nImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)

281 

282 

class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 5e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Number of parameters per polynomial is (d+1)(d+2)/2, summed over
        polynomials, times 2 polynomials per dimension.
        The chip transform is fixed for one chip, so only visitOrder matters
        if chipOrder is None.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        """Check getNpar() for every ccdImage; the fixed ccd contributes no
        chip parameters, only visit parameters.
        """
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = "ccdImage: %s, with chipOrder %s and visitOrder %s"%(ccdImage.getName(),
                                                                          chipOrder,
                                                                          visitOrder)
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        """Total = chip polynomials for every fitted ccd + visit polynomials for every visit."""
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.ccds) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        """Every known ccd must have a chip transform; an unknown ccd must
        raise InvalidParameterError listing the known ccds.
        """
        # Check valid ccds
        for ccd in self.ccds:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on ccd: {}, but should not have.".format(model, ccd))

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = "No such chipId: {} among [{}]".format(self.badCcd, ", ".join(str(ccd) for ccd in self.ccds))
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        """Every known visit must have a visit transform; an unknown visit must
        raise InvalidParameterError listing the known visits.
        """
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on visit: {}, but should not have.".format(model, visit))

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = "No such visitId: {} among [{}]".format(self.badVisit,
                                                         ", ".join(str(v) for v in self.visits))
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters:
        #   2 visits * 20 params (visitOrder=3) + (6-1) sensors * 6 params (chipOrder=1).
        # NOTE(review): the base-class setUp never populates self.ccdImageList,
        # so validate() receives an empty list here; confirm that is intended.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))

401 

402 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the inherited lsst.utils MemoryTestCase checks for this module; no extra tests."""

405 

406 

def setup_module(module):
    """Module-level setup hook: initialize lsst.utils.tests before any test runs.

    ``module`` is accepted to match the ``setup_module`` hook signature but is
    not used.
    """
    lsst.utils.tests.init()

409 

410 

# Allow running this test file directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()