Coverage for tests/test_astrometryModel.py: 25%
219 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-03 02:04 -0700
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-03 02:04 -0700
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of astrometryModels (simple, constrained).
24Includes tests of producing a Wcs from a model.
25"""
import itertools
import os
import shutil
import unittest

import numpy as np

import lsst.utils.tests
import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.butler
import lsst.geom
import lsst.log
import lsst.pex.exceptions

import lsst.jointcal
from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
                                    extract_detector_catalog_from_visit_catalog)
from lsst.jointcal import astrometryModels
from jointcalTestBase import importRepository
def getNParametersPolynomial(order):
    """Return the parameter count of a 2-d astrometric polynomial of ``order``.

    A polynomial of total degree ``d`` in (x, y) has (d+1)(d+2)/2
    coefficients. One such polynomial is needed per output coordinate, so
    the two coordinates together carry (d+1)(d+2) parameters.

    Parameters
    ----------
    order : `int`
        Total degree of the polynomial.

    Returns
    -------
    nParameters : `int`
        Number of free parameters, ``2 * (order+1)(order+2)/2``.
    """
    # The factor of 2 (two output axes) cancels the /2 of the triangular
    # number, so no division is needed.
    nParameters = (order + 2)*(order + 1)
    return nParameters
class AstrometryModelTestBase:
    """Shared fixtures and checks for the jointcal AstrometryModel concrete
    classes, using the CFHT data in ``testdata_jointcal``.

    Subclasses mix this into a `lsst.utils.tests.TestCase`. In their own
    ``setUp`` they must call ``super().setUp()``, create ``self.model1`` and
    ``self.model2``, and then call ``self._prepModels()``.
    """
    @classmethod
    def setUpClass(cls):
        """Import the CFHT test repository once for the whole class.

        Raises
        ------
        unittest.SkipTest
            If the ``testdata_jointcal`` package is not set up.
        """
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        """Delete the repository created by `setUpClass`."""
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Load the source catalogs for every (visit, detector) pair, build
        the jointcal `Associations`, and prepare the fitted stars.
        """
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairListsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Subclasses replace either one for models that don't have as
        # accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        # NOTE(review): this list is never filled in below, so any code that
        # reads it sees an empty list — confirm that is intended.
        self.ccdImageList = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns})) for v in self.visits}
        # `detectorId` is the integer id; `detector` below is the Detector
        # object read from the butler (kept distinct to avoid rebinding the
        # loop variable mid-iteration).
        for (visit, detectorId) in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filterLabel', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs from a model for every ccdImage, both
        post-initialization and after the fitter has taken minimization steps.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See ``assertPairListsAlmostEqual`` (lsst.afw.geom.utils).
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by
        # minimizing first the per-visit distortions, then all distortions.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See ``assertPairListsAlmostEqual`` (lsst.afw.geom.utils).
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        expects = []
        forwards = []
        spherePoints = skyWcs.pixelToSky(points)
        # Round trip pixel->sky->pixel, reusing the sky positions computed above
        # instead of evaluating pixelToSky a second time.
        inverses = skyWcs.skyToPixel(spherePoints)
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)
class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Exercise `SimpleAstrometryModel`, which holds one independent mapping
    for every (ccd, visit) pair.
    """
    def setUp(self):
        super().setUp()

        # Model 1: order-3 polynomials, third constructor argument True.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(
            self.associations.getCcdImageList(),
            self.projectionHandler,
            True,
            order=self.order1,
        )

        # Model 2: order-5 polynomials, third constructor argument False.
        self.order2 = 5
        # NOTE: assertPairListsAlmostEqual checks an absolute difference, so
        # this tolerance has to be fairly loose to avoid spurious failures.
        # Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would tighten the SkyWcs inverse,
        # but that is not obviously wise in general, since the inverse would
        # then pick up too many wiggles.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(
            self.associations.getCcdImageList(),
            self.projectionHandler,
            False,
            order=self.order2,
        )

        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage should carry exactly one polynomial's parameters."""
        expected = getNParametersPolynomial(order)
        for image in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(image), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial's worth of parameters per ccdImage."""
        total = model.getTotalParameters()
        nImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, nImages*getNParametersPolynomial(order))

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)
class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so its
        # chip transform is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of parameters for one ccdImage.

        This is the visit polynomial's parameter count plus, unless
        ``chipOrder`` is `None`, the chip polynomial's parameter count. Each
        polynomial of order d contributes (d+1)(d+2)/2 coefficients per output
        dimension, i.e. (d+1)(d+2) in total (see `getNParametersPolynomial`).

        Parameters
        ----------
        chipOrder : `int` or `None`
            Order of the chip polynomial, or `None` for the chip whose
            transform is held fixed (so only ``visitOrder`` matters).
        visitOrder : `int`
            Order of the visit polynomial.

        Returns
        -------
        params : `int`
            Expected number of parameters.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        """Check getNpar for every ccdImage; the fixed ccd has no chip terms."""
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {chipOrder} "
                       f"and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        """Check the total: one chip polynomial per fitted ccd plus one visit
        polynomial per visit.
        """
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = (getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1)
                  + getNParametersPolynomial(visitOrder)*len(self.visits))
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        """Check that every known ccd has a chip transform, and that an unknown
        ccd raises InvalidParameterError listing the known ccds.
        """
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        """Check that every known visit has a visit transform, and that an
        unknown visit raises InvalidParameterError listing the known visits.
        """
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters
        # (2 visits*20 params + (6-1) sensors*6 params, for visitOrder=3, chipOrder=1).
        # NOTE(review): self.ccdImageList is created empty by the base-class
        # setUp and never filled, so only the ndof argument varies here.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks from `lsst.utils.tests.MemoryTestCase` unchanged."""
def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    lsst.utils.tests.init()
# Script entry point: initialize the lsst test utilities, then run unittest.
# (The coverage-report annotation that was fused onto this line made it a
# syntax error and has been removed.)
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()