Coverage for tests/test_astrometryModel.py: 23%
211 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 03:55 -0700
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 03:55 -0700
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of astrometryModels (simple, constrained).
24Includes tests of producing a Wcs from a model.
25"""
26import itertools
27import os
28import numpy as np
30import unittest
31import lsst.utils.tests
33import lsst.afw.cameraGeom
34import lsst.afw.geom
35import lsst.afw.table
36import lsst.afw.image
37import lsst.afw.image.utils
38import lsst.daf.persistence
39import lsst.geom
40import lsst.jointcal
41from lsst.jointcal import astrometryModels
42import lsst.log
43from lsst.meas.algorithms import astrometrySourceSelector
def getNParametersPolynomial(order):
    """Return the parameter count of a 2D astrometric polynomial of a given order.

    A bivariate polynomial of total degree ``order`` has
    ``(order + 1)*(order + 2)/2`` coefficients. There is one such polynomial
    per output coordinate (x and y), so the factor of 1/2 cancels.

    Parameters
    ----------
    order : `int`
        Total degree of the polynomial.

    Returns
    -------
    nParameters : `int`
        Number of coefficients summed over both output coordinates.
    """
    degreePlusOne = order + 1
    return degreePlusOne*(degreePlusOne + 1)
class AstrometryModelTestBase:
    """Common fixtures and SkyWcs round-trip checks for astrometry model tests.

    Subclasses call ``super().setUp()``, then create ``self.model1`` and
    ``self.model2`` (optionally overriding ``self.inverseMaxDiff1`` /
    ``self.inverseMaxDiff2``), then call ``self._prepModels()`` to build the
    matching ``self.fitter1`` / ``self.fitter2``.
    """

    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: the below is to facilitate testing with hsc test data,
            # using chips far apart on the focal plane. Using it requires
            # having recently-processed singleFrame output in a rerun
            # directory in validation_data_hsc; the cfht-specific input
            # directory and visit/ccd lists in setUp() would also need to change.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

    def setUp(self):
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond (expressed in degrees)
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (the `maxDiff` of assertPairListsAlmostEqual) for
        # round-trip testing of the inverse for models 1 (simpler) and 2 (more
        # complex). Subclasses replace either one for models that don't have
        # as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("slot_CalibFlux")
        self.associations = lsst.jointcal.Associations()
        # Work around the fact that the testdata_jointcal catalogs were produced
        # before DM-13493, and so have a different definition of the interpolated flag.
        sourceSelectorConfig = astrometrySourceSelector.AstrometrySourceSelectorConfig()
        sourceSelectorConfig.badFlags.append("base_PixelFlags_flag_interpolated")
        sourceSelector = astrometrySourceSelector.AstrometrySourceSelectorTask(config=sourceSelectorConfig)

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        inputDir = os.path.join(self.dataDir, 'cfht')
        self.visits = [849375, 850587]
        self.ccds = [12, 13, 14, 21, 22, 23]
        # Ids absent from the data above, used to exercise error paths.
        self.badVisit = -12345
        self.badCcd = 888

        self.butler = lsst.daf.persistence.Butler(inputDir)

        self.catalogs = []
        # NOTE(review): nothing in this setUp appends to ccdImageList, so it
        # stays empty; the CcdImages live in self.associations instead.
        # Confirm whether callers of this attribute expect it to be populated.
        self.ccdImageList = []
        for (visit, ccd) in itertools.product(self.visits, self.ccds):
            dataRef = self.butler.dataRef('calexp', visit=visit, ccd=ccd)

            src = dataRef.get("src", flags=lsst.afw.table.SOURCE_IO_NO_FOOTPRINTS, immediate=True)
            goodSrc = sourceSelector.run(src)
            # Need memory contiguity to do vector-like things on the sourceCat.
            goodSrc = goodSrc.sourceCat.copy(deep=True)

            visitInfo = dataRef.get('calexp_visitInfo')
            detector = dataRef.get('calexp_detector')
            ccdId = detector.getId()
            wcs = dataRef.get('calexp_wcs')
            bbox = dataRef.get('calexp_bbox')
            filt = dataRef.get('calexp_filterLabel')
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.CheckMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.CheckMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def CheckMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after one fitting step.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairListsAlmostEqual`.
        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by taking one fitting step.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairListsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # A num x num grid of pixel positions spanning the detector bbox.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        # Evaluate the (expensive, vectorized) forward SkyWcs once, and reuse
        # it both for the forward comparison and the pixel->sky->pixel round trip.
        spherePoints = skyWcs.pixelToSky(points)
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)
class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit."""

    def setUp(self):
        super().setUp()
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-5
        # NOTE(review): the positional boolean (True here, False for model2)
        # is interpreted by SimpleAstrometryModel itself; presumably it
        # selects whether the mappings are initialized from the input WCS
        # -- confirm against the SimpleAstrometryModel signature.
        self.model1 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(),
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        self.order2 = 5
        # NOTE: because assertPairListsAlmostEqual tests an absolute
        # difference, we need this to be relatively high to avoid spurious
        # incorrect values.
        # Alternately, further increasing the order of the inverse polynomial
        # in astrometryTransform.toAstMap() can improve the quality of the
        # SkyWcs inverse, but that may not be wise for the more general use
        # case due to the inverse then having too many wiggles.
        self.inverseMaxDiff2 = 2e-3
        self.model2 = astrometryModels.SimpleAstrometryModel(self.associations.getCcdImageList(),
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Check that every ccdImage has exactly one polynomial's worth of parameters.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        order : `int`
            Polynomial order the model was constructed with.
        """
        expect = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            result = model.getNpar(ccdImage)
            # Identify the offending ccdImage on failure.
            failMsg = "ccdImage: {}, with order {}".format(ccdImage.getName(), order)
            self.assertEqual(result, expect, msg=failMsg)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """Check the total parameter count: one polynomial per ccdImage.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        order : `int`
            Polynomial order the model was constructed with.
        """
        result = model.getTotalParameters()
        expect = getNParametersPolynomial(order)*len(self.associations.getCcdImageList())
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)
class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """

    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 5e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of parameters for one ccdImage.

        Each transform is a 2D polynomial with (d+1)(d+2)/2 coefficients per
        output coordinate, i.e. (d+1)(d+2) in total; the visit and chip
        contributions are summed. The chip transform is fixed for one chip,
        so only visitOrder contributes when chipOrder is None.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = "ccdImage: %s, with chipOrder %s and visitOrder %s"%(ccdImage.getName(),
                                                                            chipOrder,
                                                                            visitOrder)
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            # The fixed ccd contributes no chip parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = getNParametersPolynomial(chipOrder)*(len(self.ccds) - 1) + \
            getNParametersPolynomial(visitOrder)*len(self.visits)
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.ccds:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on ccd: {}, but should not have.".format(model, ccd))

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = "No such chipId: {} among [{}]".format(self.badCcd, ", ".join(str(ccd) for ccd in self.ccds))
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail("model: {} raised on visit: {}, but should not have.".format(model, visit))

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = "No such visitId: {} among [{}]".format(self.badVisit,
                                                           ", ".join(str(v) for v in self.visits))
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters
        # (2 visits*20 params + (6-1) sensors*6 params = 40 + 30).
        # NOTE(review): self.ccdImageList is never populated by the base-class
        # setUp, so validate() receives an empty list here; confirm whether
        # self.associations.getCcdImageList() was intended.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Inherit every check from `lsst.utils.tests.MemoryTestCase` unchanged."""
def setup_module(module):
    """Pytest module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : `module`
        The test module being set up (not used).
    """
    lsst.utils.tests.init()
# Allow running this file directly, outside of pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()