Coverage for tests/test_astrometryModel.py: 25%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Tests of astrometryModels (simple, constrained).
24Includes tests of producing a Wcs from a model.
25"""
import itertools
import os
import shutil
import unittest

import numpy as np

import lsst.utils.tests

import lsst.afw.cameraGeom
import lsst.afw.geom
import lsst.afw.table
import lsst.afw.image
import lsst.afw.image.utils
import lsst.daf.butler
import lsst.geom
import lsst.log
import lsst.pex.exceptions

import lsst.jointcal
from lsst.jointcal.jointcal import (make_schema_table, get_sourceTable_visit_columns,
                                    extract_detector_catalog_from_visit_catalog)
from lsst.jointcal import astrometryModels
from jointcalTestBase import importRepository
def getNParametersPolynomial(order):
    """Return the parameter count of an astrometry polynomial of degree ``order``.

    Each of the two output coordinates is a polynomial with (d+1)(d+2)/2
    coefficients, so the total is 2 * (d+1)(d+2)/2 = (d+1)(d+2).
    """
    twiceTermsPerAxis = (order + 2) * (order + 1)
    return twiceTermsPerAxis
class AstrometryModelTestBase:
    """Test the jointcal AstrometryModel concrete classes, using CFHT data.

    Subclasses are expected to create ``self.model1`` and ``self.model2``
    (and may override ``self.inverseMaxDiff1``/``self.inverseMaxDiff2``) in
    their own ``setUp`` after calling ``super().setUp()``, then call
    ``self._prepModels()`` to assign indices and build the fitters.
    """
    @classmethod
    def setUpClass(cls):
        try:
            cls.dataDir = lsst.utils.getPackageDir('testdata_jointcal')
            # NOTE: the below is to facilitate testing with hsc test data,
            # using chips far apart on the focal plane. See the note in setUp()
            # below for details. Using it requires having recently-processed
            # singleFrame output in a rerun directory in validation_data_hsc.
            # NOTE(review): the setUp() note referenced above is not present
            # in this file; confirm where it lives.
            # cls.dataDir = lsst.utils.getPackageDir('validation_data_hsc')
        except LookupError:
            raise unittest.SkipTest("testdata_jointcal not setup")

        refcatPath = os.path.join(cls.dataDir, "cfht")
        refcats = {"gaia_dr2_20200414": os.path.join(refcatPath, "gaia_dr2_20200414.ecsv"),
                   "ps1_pv3_3pi_20170110": os.path.join(refcatPath, "ps1_pv3_3pi_20170110.ecsv"),
                   "sdss_dr9_fink_v5b": os.path.join(refcatPath, "sdss-dr9-fink-v5b.ecsv")}
        # Share one repo, since none of these tests write anything.
        cls.repopath = importRepository("lsst.obs.cfht.MegaPrime",
                                        os.path.join(cls.dataDir, 'cfht/repo/'),
                                        os.path.join(cls.dataDir, 'cfht/exports.yaml'),
                                        refcats=refcats,
                                        refcatPath=refcatPath)

    @classmethod
    def tearDownClass(cls):
        """Remove the shared repository created in setUpClass."""
        shutil.rmtree(cls.repopath, ignore_errors=True)

    def setUp(self):
        """Load the CFHT catalogs, build the CcdImages and associate stars."""
        np.random.seed(200)

        # DEBUG messages can help track down failures.
        logger = lsst.log.Log.getLogger('lsst.jointcal')
        logger.setLevel(lsst.log.DEBUG)

        # Append `msg` arguments to assert failures.
        self.longMessage = True
        # absolute tolerance on positional errors of 10 micro-arcsecond
        self.atol = 10.0 / (60 * 60 * 1e6)

        # Maximum difference (see assertPairListsAlmostEqual) for round-trip
        # testing of the inverse for models 1 (simpler) and 2 (more complex).
        # Replace either one for models that don't have as accurate an inverse.
        self.inverseMaxDiff1 = 1e-5
        self.inverseMaxDiff2 = 1e-5

        self.firstIndex = 0  # for assignIndices
        matchCut = 2.0  # arcseconds
        minMeasurements = 2  # accept all star pairs.

        jointcalControl = lsst.jointcal.JointcalControl("flux")
        self.associations = lsst.jointcal.Associations()
        config = lsst.jointcal.JointcalConfig()
        config.load(os.path.join(os.path.dirname(__file__), "config/config-gen3.py"))
        sourceSelector = config.sourceSelector.target(config=config.sourceSelector['science'])

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.afw.image.utils.resetFilters()

        # jointcal's cfht test data has 6 ccds and 2 visits.
        self.visits = [849375, 850587]
        self.detectors = [12, 13, 14, 21, 22, 23]
        self.badVisit = -12345
        self.badCcd = 888

        butler = lsst.daf.butler.Butler(self.repopath, collections='singleFrame', instrument="MegaPrime")

        self.catalogs = []
        table = make_schema_table()
        inColumns = butler.get("sourceTable_visit", visit=self.visits[0])
        columns, detColumn, ixxColumns = get_sourceTable_visit_columns(inColumns, config, sourceSelector)
        catalogs = {v: sourceSelector.run(butler.get('sourceTable_visit',
                                                     visit=v,
                                                     parameters={'columns': columns}))
                    for v in self.visits}
        for visit, detectorId in itertools.product(self.visits, self.detectors):
            goodSrc = extract_detector_catalog_from_visit_catalog(table,
                                                                  catalogs[visit].sourceCat,
                                                                  detectorId,
                                                                  detColumn,
                                                                  ixxColumns,
                                                                  config.sourceFluxType,
                                                                  logger)
            dataId = {"detector": detectorId, "visit": visit}
            visitInfo = butler.get('calexp.visitInfo', dataId=dataId)
            # The Detector object (not the integer id) is what createCcdImage takes.
            detector = butler.get('calexp.detector', dataId=dataId)
            ccdId = detector.getId()
            wcs = butler.get('calexp.wcs', dataId=dataId)
            bbox = butler.get('calexp.bbox', dataId=dataId)
            filterName = butler.get('calexp.filterLabel', dataId=dataId).physicalLabel
            photoCalib = lsst.afw.image.PhotoCalib(100.0, 1.0)

            self.catalogs.append(goodSrc)
            self.associations.createCcdImage(goodSrc,
                                             wcs,
                                             visitInfo,
                                             bbox,
                                             filterName,
                                             photoCalib,
                                             detector,
                                             visit,
                                             ccdId,
                                             jointcalControl)

        # Have to set the common tangent point so projectionHandler can use skyToCTP.
        self.associations.computeCommonTangentPoint()

        # Previously this attribute was left as an empty list; populate it so
        # tests that take a CcdImageList (e.g. validate()) see the real data.
        self.ccdImageList = self.associations.getCcdImageList()
        self.projectionHandler = lsst.jointcal.OneTPPerVisitHandler(self.ccdImageList)

        self.associations.associateCatalogs(matchCut)
        self.associations.prepareFittedStars(minMeasurements)
        self.associations.deprojectFittedStars()

    def _prepModels(self):
        """Call this after model1 and model2 are created, to call assignIndices,
        and instantiate the fitters.
        """
        posError = 0.02  # in pixels
        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model1.assignIndices("Distortions", self.firstIndex)
        self.fitter1 = lsst.jointcal.AstrometryFit(self.associations, self.model1, posError)

        # have to call this once or offsetParams will fail because the transform indices aren't defined
        self.model2.assignIndices("Distortions", self.firstIndex)
        self.fitter2 = lsst.jointcal.AstrometryFit(self.associations, self.model2, posError)

    def testMakeSkyWcsModel1(self):
        self.checkMakeSkyWcsModel(self.model1, self.fitter1, self.inverseMaxDiff1)

    def testMakeSkyWcsModel2(self):
        self.checkMakeSkyWcsModel(self.model2, self.fitter2, self.inverseMaxDiff2)

    def checkMakeSkyWcsModel(self, model, fitter, inverseMaxDiff):
        """Test producing a SkyWcs on a model for every ccdImage,
        both post-initialization and after fitting.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        fitter : `lsst.jointcal.FitterBase`
            The fitter to use to step the model to test with new (reasonable) parameters.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.

        """
        # first test on as-initialized models
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

        # Now shift the models to non-default, but more reasonable, values by
        # minimizing the visit distortions and then all distortions.
        fitter.minimize("DistortionsVisit")
        fitter.minimize("Distortions")
        for ccdImage in self.associations.getCcdImageList():
            self.checkMakeSkyWcsOneCcdImage(model, ccdImage, inverseMaxDiff)

    def checkMakeSkyWcsOneCcdImage(self, model, ccdImage, inverseMaxDiff):
        """Test converting the model of one ccdImage to a SkyWcs by comparing
        to the original transform at the tangent plane.

        Parameters
        ----------
        model : `lsst.jointcal.AstrometryModel`
            The model to test.
        ccdImage : `lsst.jointcal.CcdImage`
            The ccdImage to extract from the model and test.
        inverseMaxDiff : `float`
            Required accuracy on inverse transform.
            See `lsst.afw.geom.utils.assertPairsAlmostEqual`.
        """
        skyWcs = model.makeSkyWcs(ccdImage)
        skyToTangentPlane = model.getSkyToTangentPlane(ccdImage)
        mapping = model.getMapping(ccdImage)

        # A num x num grid of pixel positions spanning the detector bbox.
        bbox = ccdImage.getDetector().getBBox()
        num = 200
        xx = np.linspace(bbox.getMinX(), bbox.getMaxX(), num)
        yy = np.linspace(bbox.getMinY(), bbox.getMaxY(), num)
        points = [lsst.geom.Point2D(*xy) for xy in itertools.product(xx, yy)]

        spherePoints = skyWcs.pixelToSky(points)
        # Round trip pixel->sky->pixel, reusing the forward result above
        # instead of evaluating pixelToSky a second time.
        inverses = skyWcs.skyToPixel(spherePoints)

        expects = []
        forwards = []
        for point, spherePoint in zip(points, spherePoints):
            # TODO: Fix these "Point"s once DM-4044 is done.

            # jointcal's pixel->tangent-plane mapping
            star = lsst.jointcal.star.BaseStar(point.getX(), point.getY(), 0, 0)
            tpExpect = mapping.transformPosAndErrors(star)
            expects.append(lsst.geom.Point2D(tpExpect.x, tpExpect.y))

            # skywcs takes pixel->sky, and we then have to go sky->tangent-plane
            onSky = lsst.jointcal.star.BaseStar(spherePoint.getLongitude().asDegrees(),
                                                spherePoint.getLatitude().asDegrees(), 0, 0)
            result = skyToTangentPlane.apply(onSky)
            forwards.append(lsst.geom.Point2D(result.x, result.y))

        self.assertPairListsAlmostEqual(forwards, expects)
        # NOTE: assertPairListsAlmostEqual() compares absolute, not relative,
        # values so the points along the ccd edge may exceed maxDiff while still
        # being "close enough": set `inverseMaxDiff` accordingly.
        self.assertPairListsAlmostEqual(inverses, points, maxDiff=inverseMaxDiff)
class SimpleAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Exercise `SimpleAstrometryModel`, which has one mapping per ccd per visit."""
    def setUp(self):
        super().setUp()
        ccdImages = self.associations.getCcdImageList()

        # Model 1: cubic polynomials, built with True as the third argument.
        self.order1 = 3
        self.inverseMaxDiff1 = 2e-4
        self.model1 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             True,
                                                             order=self.order1)

        # Model 2: quintic polynomials, built with False as the third argument.
        self.order2 = 5
        # NOTE: because assertPairListsAlmostEqual tests an absolute
        # difference, we need this to be relatively high to avoid spurious
        # incorrect values.
        # Alternately, further increasing the order of the inverse polynomial
        # in astrometryTransform.toAstMap() can improve the quality of the
        # SkyWcs inverse, but that may not be wise for the more general use
        # case due to the inverse then having too many wiggles.
        self.inverseMaxDiff2 = 2e-2
        self.model2 = astrometryModels.SimpleAstrometryModel(ccdImages,
                                                             self.projectionHandler,
                                                             False,
                                                             order=self.order2)
        self._prepModels()

    def _testGetNpar(self, model, order):
        """Every ccdImage should carry exactly one polynomial's worth of parameters."""
        perImage = getNParametersPolynomial(order)
        for ccdImage in self.associations.getCcdImageList():
            self.assertEqual(model.getNpar(ccdImage), perImage)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.order1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.order2)

    def _testGetTotalParameters(self, model, order):
        """The total is one polynomial per ccdImage."""
        nImages = len(self.associations.getCcdImageList())
        self.assertEqual(model.getTotalParameters(), getNParametersPolynomial(order)*nImages)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.order1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.order2)
class ConstrainedAstrometryModelTestCase(AstrometryModelTestBase, lsst.utils.tests.TestCase):
    """Test the `ConstrainedAstrometryModel`, with one mapping per ccd and one
    mapping per visit.
    """
    def setUp(self):
        super().setUp()
        self.visitOrder1 = 3
        self.chipOrder1 = 1
        self.inverseMaxDiff1 = 1e-5
        self.model1 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder1,
                                                                  visitOrder=self.visitOrder1)

        self.visitOrder2 = 5
        self.chipOrder2 = 2
        self.inverseMaxDiff2 = 8e-5
        self.model2 = astrometryModels.ConstrainedAstrometryModel(self.associations.getCcdImageList(),
                                                                  self.projectionHandler,
                                                                  chipOrder=self.chipOrder2,
                                                                  visitOrder=self.visitOrder2)
        self._prepModels()

        # 22 is closest to the center of the focal plane in this data, so it is not fit.
        self.fixedCcd = 22

    def _polyParams(self, chipOrder, visitOrder):
        """Return the expected number of parameters for one ccdImage.

        Each polynomial has 2 * (d+1)(d+2)/2 coefficients (see
        `getNParametersPolynomial`). The visit polynomial always contributes;
        the chip polynomial contributes unless ``chipOrder`` is None, which is
        used for the chip whose transform is held fixed.
        """
        params = getNParametersPolynomial(visitOrder)
        if chipOrder is not None:
            params += getNParametersPolynomial(chipOrder)
        return params

    def _testGetNpar(self, model, chipOrder, visitOrder):
        """Check getNpar for every ccdImage, accounting for the fixed chip."""
        def checkParams(ccdImage, model, chipOrder, visitOrder):
            result = model.getNpar(ccdImage)
            failMsg = f"ccdImage: {ccdImage.getName()}, with chipOrder {chipOrder} and visitOrder {visitOrder}"
            self.assertEqual(result, self._polyParams(chipOrder, visitOrder), msg=failMsg)

        for ccdImage in self.associations.getCcdImageList():
            realChipOrder = None if ccdImage.getCcdId() == self.fixedCcd else chipOrder
            checkParams(ccdImage, model, realChipOrder, visitOrder)

    def testGetNpar1(self):
        self._testGetNpar(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetNpar2(self):
        self._testGetNpar(self.model2, self.chipOrder2, self.visitOrder2)

    def _testGetTotalParameters(self, model, chipOrder, visitOrder):
        result = model.getTotalParameters()
        # one sensor is held fixed, hence len(ccds)-1
        expect = (getNParametersPolynomial(chipOrder)*(len(self.detectors) - 1)
                  + getNParametersPolynomial(visitOrder)*len(self.visits))
        self.assertEqual(result, expect)

    def testGetTotalParametersModel1(self):
        self._testGetTotalParameters(self.model1, self.chipOrder1, self.visitOrder1)

    def testGetTotalParametersModel2(self):
        self._testGetTotalParameters(self.model2, self.chipOrder2, self.visitOrder2)

    def checkGetChipTransform(self, model):
        """Known ccds must not raise; an unknown ccd must raise with a clear message."""
        # Check valid ccds
        for ccd in self.detectors:
            try:
                model.getChipTransform(ccd)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on ccd: {ccd}, but should not have.")

        # Check an invalid ccd
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getChipTransform(self.badCcd)
        errMsg = f"No such chipId: {self.badCcd} among [{', '.join(str(d) for d in self.detectors)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetChipTransform(self):
        """getChipTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetChipTransform(self.model1)
        self.checkGetChipTransform(self.model2)

    def checkGetVisitTransform(self, model):
        """Known visits must not raise; an unknown visit must raise with a clear message."""
        # Check valid visits
        for visit in self.visits:
            try:
                model.getVisitTransform(visit)
            except lsst.pex.exceptions.wrappers.InvalidParameterError:
                self.fail(f"model: {model} raised on visit: {visit}, but should not have.")

        # Check an invalid visit
        with self.assertRaises(lsst.pex.exceptions.wrappers.InvalidParameterError) as cm:
            model.getVisitTransform(self.badVisit)
        errMsg = f"No such visitId: {self.badVisit} among [{', '.join(str(v) for v in self.visits)}]"
        self.assertIn(errMsg, str(cm.exception))

    def testGetVisitTransform(self):
        """getVisitTransform should get each known transform, and raise with an
        appropriate message otherwise.
        """
        self.checkGetVisitTransform(self.model1)
        self.checkGetVisitTransform(self.model2)

    def testValidate(self):
        """Test that invalid models fail validate(), and that valid ones pass.
        """
        # model1 has 70 total parameters:
        # 2 visits * 20 params (visitOrder 3) + (6-1) sensors * 6 params (chipOrder 1).
        self.assertEqual(self.model1.getTotalParameters(), 70)
        # We need at least 0 degrees of freedom (data - parameters) for the model to be valid.
        self.assertTrue(self.model1.validate(self.ccdImageList, 0))
        self.assertFalse(self.model1.validate(self.ccdImageList, -1))
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the inherited `lsst.utils.tests.MemoryTestCase` checks for this module.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    The ``module`` argument is accepted but not used.
    """
    lsst.utils.tests.init()
# Allow running this test file directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()