Coverage for tests/test_astrometryModel.py: 22%
217 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 17:33 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 17:33 +0000
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of astrometryModels (simple, constrained).
24Includes tests of producing a Wcs from a model.
25"""
26import itertools
27import os
28import numpy as np
29import shutil
31import unittest
32import lsst.utils.tests
34import lsst.afw.cameraGeom
35import lsst.afw.geom
36import lsst.afw.table
37import lsst.afw.image
38import lsst.geom
39import lsst.log
41import lsst.jointcal
42from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
43 extract_detector_catalog_from_visit_catalog)
44from lsst.jointcal import astrometryModels
45from jointcalTestBase import importRepository
def getNParametersPolynomial(order):
    """Return the number of coefficients in a 2-d astrometric polynomial.

    A polynomial of total degree ``order`` has (order+1)(order+2)/2 terms,
    and the model carries one such polynomial per output coordinate (x and
    y). The total is therefore 2 * (order+1)(order+2)/2, and the two factors
    of 2 cancel.
    """
    return (order + 2) * (order + 1)
class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    This is a mixin. Concrete test cases also inherit from
    `lsst.utils.tests.TestCase`. In ``setUp`` they create ``self.model1`` and
    ``self.model2``, then call ``_prepModels`` to build the matching fitters.
    """

    @classmethod
    def setUpClass(cls):
        """Import the CFHT test data into a butler repository shared by every
        test in the class.

        Raises
        ------
        unittest.SkipTest
            Raised if the ``testdata_jointcal`` package is not set up.
        """
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup") from None

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        """Delete the repository created in `setUpClass`."""
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Read the source catalogs for every (visit, detector) pair.

        Builds one ccdImage per pair in ``self.associations``, then associates
        the catalogs into fitted stars.
        """
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # Absolute tolerance on positional errors of 10 micro-arcseconds,
        # expressed in degrees.
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Subclasses replace either one for models that don't have as
        # accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        # NOTE(review): nothing in this setUp appends to this list, so any
        # test that passes it on (e.g. to validate()) passes an empty list.
        self.ccdImageList = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns})) for v in self.visits}
        # `detectorId` is an integer id from self.detectors. `detector` below
        # is the object the butler returns for that id; keeping the names
        # distinct avoids shadowing the loop variable.
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filter', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs from a model for every ccdImage.

        The check runs twice: once on the as-initialized model, and again
        after a fitting pass (one ``minimize`` over "DistortionsVisit", then
        one over "Distortions").

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by fitting.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Two comparisons are made on a 200x200 grid of points spanning the
        detector bounding box:

        - The forward check compares pixel -> sky (SkyWcs) -> tangent plane
          against the model's own pixel -> tangent-plane mapping.
        - The inverse check compares the pixel -> sky -> pixel round trip
          through the SkyWcs against the input pixels.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        # Compute pixel -> sky once. The result feeds both the forward check
        # (below) and the sky -> pixel round trip.
        spherePoints = skyWcs.pixelToSky(points)
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)
class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit.

    Both models use a single polynomial per ccdImage: ``model1`` has order 3
    and ``model2`` has order 5.
    """
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        # Lower-order model. The third positional argument (True here, False
        # for model2) presumably selects initialization from the input WCS
        # (TODO confirm against the SimpleAstrometryModel signature).
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        # Higher-order model.
        self.order2 = 5
        # NOTE: assertPairListsAlmostEqual tests an absolute difference, so
        # this tolerance has to be fairly loose to avoid spurious failures.
        # Raising the order of the inverse polynomial in
        # astrometryTransform.toAstMap() would also tighten the SkyWcs
        # inverse. That may be unwise in general, though, because the inverse
        # would then have too many wiggles.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Each ccdImage must own exactly one polynomial of ``order``."""
        expectedPerCcd = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expectedPerCcd)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The model total is one polynomial of ``order`` per ccdImage."""
        total = model.getTotalParameters()
        nCcdImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, getNParametersPolynomial(order)*nCcdImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)
class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the number of fit parameters attached to one ccdImage.

        That is the visit polynomial's parameters plus the chip polynomial's,
        each counted by `getNParametersPolynomial`: (d+1)(d+2)/2 coefficients
        per output coordinate, times 2 coordinates. Pass ``chipOrder=None``
        for the chip whose transform is held fixed; only ``visitOrder`` then
        contributes.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        """Check getNpar for every ccdImage; the fixed ccd only contributes
        visit parameters.
        """
        for ccdImage in self.associations.getCcdImageList():
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {realChipOrder}"
                       f" and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(realChipOrder, visitOrder), msg=failMsg)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        """The total is one chip polynomial per fitted sensor plus one visit
        polynomial per visit.
        """
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        """Every ccd in ``self.detectors`` must return a chip transform.

        An unknown ccd must raise InvalidParameterError, and the error message
        must list the known ids.
        """
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        knownIds = ", ".join(str(d) for d in self.detectors)
        errMsg = f"No such chipId: {self.badCcd} among [{knownIds}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        """Every visit in ``self.visits`` must return a visit transform.

        An unknown visit must raise InvalidParameterError, and the error
        message must list the known ids.
        """
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        knownIds = ", ".join(str(v) for v in self.visits)
        errMsg = f"No such visitId: {self.badVisit} among [{knownIds}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters:
        # 2 visits * 20 params (visitOrder=3) + (6-1) sensors * 6 params (chipOrder=1).
        # NOTE(review): self.ccdImageList is the empty list created in the base
        # class setUp, so this only exercises the degrees-of-freedom check.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Reuse the checks from `lsst.utils.tests.MemoryTestCase`, unchanged,
    for this test module."""
def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes the LSST test utilities before this module's tests run.
    The ``module`` argument is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Run as a script: initialize the LSST test utilities first, then let
    # unittest discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()