Coverage for tests/test_astrometryModel.py: 26%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of astrometryModels (simple, constrained).
24Includes tests of producing a Wcs from a model.
25"""
import itertools
import os
import shutil
import unittest

import numpy as np

import lsst.utils.tests
import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.butler
import lsst.daf.persistence
import lsst.geom
import lsst.log
import lsst.pex.exceptions

import lsst.jointcal
from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
                                    extract_detector_catalog_from_visit_catalog)
from lsst.jointcal import astrometryModels

from jointcalTestBase import importRepository
def getNParametersPolynomial(order):
    """Return how many free parameters an astrometry polynomial of degree
    ``order`` has.

    Each output coordinate (x and y) is a bivariate polynomial with
    (d+1)(d+2)/2 coefficients, so the total is 2 * (d+1)(d+2)/2, which
    simplifies to (d+1)(d+2).
    """
    nCoefficients = (order + 2) * (order + 1)
    return nCoefficients
class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Subclasses are expected to call ``super().setUp()``, then create
    ``self.model1`` and ``self.model2`` and call `_prepModels`, which builds
    ``self.fitter1`` and ``self.fitter2`` used by the tests defined here.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: the below is to facilitate testing with hsc test data,
            # using chips far apart on the focal plane. Using it requires
            # having recently-processed singleFrame output in a rerun
            # directory in validation_data_hsc.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config-gen3.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns})) for v in self.visits}
        # `detectorId` is the integer id from self.detectors; `detector` below
        # is the butler-provided Detector object (kept distinct for clarity).
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filt = butler.get('calexp.filterLabel', dataId=dataId)
            filterName = filt.physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Every CcdImage has been created now; keep the list so tests that
        # take a ccdImageList (e.g. model.validate()) see real data instead
        # of an empty list.
        self.ccdImageList = self.associations.getCcdImageList()

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.associations.getCcdImageList())

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after fitting steps.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.

        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # now shift the models to non-default, but more reasonable, values by
        # taking a fitting step on the visit terms and then on all distortions.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # A num x num grid of pixel positions covering the detector bbox.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        expects = []
        forwards = []
        spherePoints = skyWcs.pixelToSky(points)
        # Round trip pixel->sky->pixel, reusing the pixel->sky result above
        # rather than recomputing it.
        inverses = skyWcs.skyToPixel(spherePoints)
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)
class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `SimpleAstrometryModel`, with one mapping per ccd per visit."""
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        # Model 1: order-3 polynomial; third positional argument is True.
        # NOTE(review): that boolean is presumably "initialize from the WCS";
        # confirm against the SimpleAstrometryModel signature.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages, self.projectionHandler,
                                                             True, order=self.order1)

        # Model 2: order-5 polynomial, third positional argument False.
        # The inverse tolerance is loose because assertPairListsAlmostEqual
        # compares absolute differences; raising the inverse polynomial order
        # in astrometryTransform.toAstMap() would tighten the SkyWcs inverse,
        # but may add unwanted wiggles for general use.
        self.order2 = 5
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages, self.projectionHandler,
                                                             False, order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage should carry one full polynomial of ``order``."""
        expected = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), expected)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial of ``order`` per ccdImage."""
        total = model.getTotalParameters()
        nCcdImages = len(self.associations.getCcdImageList())
        self.assertEqual(total, getNParametersPolynomial(order)*nCcdImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)
class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of parameters for one ccdImage.

        A polynomial of degree d has (d+1)(d+2)/2 terms per output
        coordinate, times 2 coordinates (x and y). The visit polynomial
        always contributes; the chip polynomial contributes unless
        ``chipOrder`` is None, which callers pass for the fixed chip whose
        chip transform is not fit.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        for ccdImage in self.associations.getCcdImageList():
            # The fixed chip has no free chip-transform parameters.
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            result = model.getNpar(ccdImage)
            failMsg = (f"ccdImage: {ccdImage.getName()}, with chipOrder {realChipOrder} "
                       f"and visitOrder {visitOrder}")
            self.assertEqual(result, self._polyParams(realChipOrder, visitOrder), msg=failMsg)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = (getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1)
                  + getNParametersPolynomial(visitOrder)*len(self.visits))
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # Use the real CcdImages so validate() also checks each one's mapping,
        # not just the degrees-of-freedom argument.
        ccdImageList = self.associations.getCcdImageList()
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        # Note: model1 has 70 total parameters (2 visits*20 params + (6-1) sensors*6 params)
        self.assertTrue(self.model1.validate(ccdImageList, 0))
        self.assertFalse(self.model1.validate(ccdImageList, -1))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Pick up the standard `lsst.utils.tests.MemoryTestCase` checks for this module."""
def setup_module(module):
    """pytest module-level setup hook: initialize the lsst.utils.tests machinery.

    ``module`` is supplied by pytest and is not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Support running this file directly, outside of pytest.
    lsst.utils.tests.init()
    unittest.main()