Coverage for tests/test_astrometryModel.py: 22%
217 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 11:22 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 11:22 +0000
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of astrometryModels (simple, constrained).
24Includes tests of producing a Wcs from a model.
25"""
26import itertools
27import os
28import numpy as np
29import shutil
31import unittest
32import lsst.utils.tests
34import lsst.afw.cameraGeom
35import lsst.afw.geom
36import lsst.afw.table
37import lsst.afw.image
38import lsst.geom
39import lsst.log
41import lsst.jointcal
42from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
43 extract_detector_catalog_from_visit_catalog)
44from lsst.jointcal import astrometryModels
45from jointcalTestBase import importRepository
def getNParametersPolynomial(order):
    """Return how many free parameters a 2-d astrometric polynomial has.

    A 2-d polynomial of total degree ``order`` has ``(order+1)(order+2)/2``
    coefficients, and an astrometric mapping needs one such polynomial for
    each of its two output coordinates, giving ``(order+1)(order+2)`` total.

    Parameters
    ----------
    order : `int`
        Total degree of the polynomial.

    Returns
    -------
    nParameters : `int`
        Number of coefficients across both output coordinates.
    """
    bothCoordinates = (order + 2)*(order + 1)
    return bothCoordinates
class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Subclasses are expected to create ``self.model1`` and ``self.model2`` in
    their own ``setUp`` (after calling ``super().setUp()``) and then call
    ``self._prepModels()`` so that ``self.fitter1`` and ``self.fitter2`` exist.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        self.ccdImageList = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns})) for v in self.visits}
        # Use a distinct name for the integer detector id so it is not
        # shadowed by the Detector object read from the butler below.
        for (visit, detectorId) in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filter', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after a pass of fitting (visit
        distortions, then all distortions).

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.

        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by
        # minimizing first the per-visit distortions, then all distortions.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        expects = []
        forwards = []
        spherePoints = skyWcs.pixelToSky(points)
        # Round-trip pixel->sky->pixel, reusing the already-computed sky
        # positions instead of transforming all points a second time.
        inverses = skyWcs.skyToPixel(spherePoints)
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)
class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Exercise `SimpleAstrometryModel`, which carries an independent mapping
    for every ccd of every visit.
    """
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        # Model 1: a low-order model (third positional argument True).
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        # Model 2: a higher-order model (third positional argument False).
        # The inverse tolerance is loose on purpose: assertPairListsAlmostEqual
        # checks an absolute difference, so a tighter bound would give
        # spurious failures. Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would also sharpen the SkyWcs
        # inverse, but a higher-order inverse tends to wiggle too much for
        # general use.
        self.order2 = 5
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        # Every ccdImage has its own polynomial of the same order.
        wanted = getNParametersPolynomial(order)
        for image in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(image), wanted)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        # One full polynomial per ccdImage, none held fixed.
        total = model.getTotalParameters()
        nImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, nImages*getNParametersPolynomial(order))

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)
class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of fit parameters for one ccdImage.

        A polynomial of degree d has (d+1)(d+2)/2 coefficients per output
        coordinate, i.e. ``getNParametersPolynomial(d)`` in total for both
        coordinates. The visit and chip contributions are summed. The chip
        transform is fixed for one chip, so only visitOrder matters if
        chipOrder is None.

        Parameters
        ----------
        chipOrder : `int` or `None`
            Order of the chip polynomial, or `None` for the chip that is not fit.
        visitOrder : `int`
            Order of the visit polynomial.

        Returns
        -------
        params : `int`
            Expected number of parameters.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {chipOrder} "
                       f"and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed ccd contributes no chip parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # Validate against the ccdImages the model was actually built from;
        # the base-class `self.ccdImageList` is never populated, so passing it
        # would make any per-ccdImage checks in validate() vacuous.
        ccdImageList = self.associations.getCcdImageList()
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters (2 visits*20 params + (6-1) sensors*6 params)
        self.assertTrue(self.model1.validate(ccdImageList, 0))
        self.assertFalse(self.model1.validate(ccdImageList, -1))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the shared ``lsst.utils.tests.MemoryTestCase`` checks, inherited
    unchanged, against this test module."""
def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities before this module's tests run.
    The ``module`` argument is accepted for the hook signature and is unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed directly rather than collected by pytest: initialize the
    # LSST test utilities, then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()