Coverage for tests/test_jointcal.py : 19%

Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# This file is part of jointcal.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import itertools
23import os.path
24import unittest
25from unittest import mock
27import numpy as np
28import pyarrow.parquet
30import lsst.log
31import lsst.utils
33import lsst.afw.table
34import lsst.daf.persistence
35from lsst.daf.base import DateTime
36import lsst.geom
37from lsst.meas.algorithms import getRefFluxField, LoadIndexedReferenceObjectsTask, DatasetConfig
38import lsst.obs.base
39import lsst.pipe.base
40import lsst.jointcal
41from lsst.jointcal import MinimizeResult
42import lsst.jointcal.chi2
43import lsst.jointcal.testUtils
# Module-level initialisation required by MemoryTestCase (see MemoryTester).
def setup_module(module):
    """Pytest module-level setup hook.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up. Not used; present only because the
        hook signature requires it.
    """
    lsst.utils.tests.init()
def make_fake_refcat(center, flux, filterName):
    """Build a single-record reference catalog for tests.

    Parameters
    ----------
    center : `lsst.geom.SpherePoint`
        Sky position of the one reference object.
    flux : `float`
        Value written to the ``<filterName>_flux`` field; the matching
        ``_fluxErr`` field is set to 10% of it.
    filterName : `str`
        Band name used to build the minimal schema and the flux field names.

    Returns
    -------
    catalog : `lsst.afw.table.SimpleCatalog`
        Catalog containing exactly one record, with proper-motion fields set.
    """
    fluxField = filterName + '_flux'
    errField = filterName + '_fluxErr'

    refSchema = LoadIndexedReferenceObjectsTask.makeMinimalSchema([filterName],
                                                                  addProperMotion=True)
    refCat = lsst.afw.table.SimpleCatalog(refSchema)

    src = refCat.addNew()
    src.setCoord(center)
    src[fluxField] = flux
    src[errField] = flux*0.1
    # Arbitrary, non-zero proper motion and epoch so those fields are populated.
    src['pm_ra'] = lsst.geom.Angle(1)
    src['pm_dec'] = lsst.geom.Angle(2)
    src['epoch'] = 65432.1
    return refCat
def make_fake_wcs():
    """Return two simple SkyWcs objects whose sky origins differ slightly.

    Both share the pixel origin and CD matrix of the cfht_minimal data. Their
    sky origins are placed at RA=+0.01 and RA=-0.01 degrees, so together they
    straddle RA=0.

    Returns
    -------
    wcs1, wcs2 : `lsst.afw.geom.SkyWcs`
        WCS with sky origin (0.01, -0.01) deg and (-0.01, 0.01) deg.
    """
    pixelOrigin = lsst.geom.Point2D(931.517869, 2438.572109)
    cdMatrix = np.array([[5.19513851e-05, -2.81124812e-07],
                         [-3.25186974e-07, -5.19112119e-05]])
    skyOrigins = (lsst.geom.SpherePoint(0.01, -0.01, lsst.geom.degrees),
                  lsst.geom.SpherePoint(-0.01, 0.01, lsst.geom.degrees))
    wcs1, wcs2 = (lsst.afw.geom.makeSkyWcs(pixelOrigin, skyOrigin, cdMatrix)
                  for skyOrigin in skyOrigins)
    return wcs1, wcs2
class TestJointcalVisitCatalog(lsst.utils.tests.TestCase):
    """Tests of jointcal's unrolling of a sourceTable_visit parquet catalog
    into single-detector afw table catalogs.
    """
    def setUp(self):
        filename = os.path.join(os.path.dirname(__file__),
                                "data/subselected-sourceTable-0034690.parq")
        # read_table opens, reads and releases the file in one call, so no
        # ParquetFile handle outlives this method.
        arrowTable = pyarrow.parquet.read_table(filename, use_pandas_metadata=True)
        self.data = arrowTable.to_pandas()
        config = lsst.jointcal.jointcal.JointcalConfig()
        # TODO DM-29008: Remove this (to use the new gen3 default) before gen2 removal.
        config.sourceFluxType = "ApFlux_12_0"
        # we don't actually need either fitter to run for these tests
        config.doAstrometry = False
        config.doPhotometry = False
        self.jointcal = lsst.jointcal.JointcalTask(config=config)

    def test_make_catalog_schema(self):
        """Check that the slot fields required by CcdImage::loadCatalog are in
        the table returned by _make_schema_table().
        """
        table = self.jointcal._make_schema_table()
        self.assertTrue(table.getCentroidSlot().getMeasKey().isValid())
        self.assertTrue(table.getCentroidSlot().getErrKey().isValid())
        self.assertTrue(table.getShapeSlot().getMeasKey().isValid())

    def test_extract_detector_catalog_from_visit_catalog(self):
        """Spot check a value output by the script that generated the test
        parquet catalog and check that the size of the returned catalog
        is correct for each detector.
        """
        detectorId = 56
        sourceId = 29798723617816629
        table = self.jointcal._make_schema_table()
        catalog = self.jointcal._extract_detector_catalog_from_visit_catalog(table, self.data, detectorId)

        # The test catalog has a number of elements for each detector equal to the detector id.
        self.assertEqual(len(catalog), detectorId)
        self.assertIn(sourceId, catalog['id'])
        matched = catalog[catalog['id'] == sourceId]
        # Require exactly one match, then compare scalars. Comparing a float
        # against a longer array would fail with an ambiguous-truth error
        # instead of a clear assertion message.
        self.assertEqual(len(matched), 1)
        self.assertEqual(1715.734359473175, matched['slot_Centroid_x'][0])
        self.assertEqual(89.06076509964362, matched['slot_Centroid_y'][0])
class JointcalTestBase:
    """Mixin with shared fixtures for the jointcal unit tests.

    It provides two fake CcdImages, good/bad/NaN chi2 statistics, a mocked
    Butler, a mocked Associations, and a default JointcalConfig. The
    subclasses in this file list it before `lsst.utils.tests.TestCase`, so its
    ``setUp`` runs first in the MRO.
    """
    def setUp(self):
        super().setUp()
        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination each time we create a cfht camera below.
        lsst.obs.base.FilterDefinitionCollection.reset()

        struct = lsst.jointcal.testUtils.createTwoFakeCcdImages(100, 100)
        self.ccdImageList = struct.ccdImageList
        # so that countStars() returns nonzero results
        for ccdImage in self.ccdImageList:
            ccdImage.resetCatalogForFit()

        self.goodChi2 = lsst.jointcal.chi2.Chi2Statistic()
        # chi2/ndof == 2.0 should be non-bad
        self.goodChi2.chi2 = 200.0
        self.goodChi2.ndof = 100

        # chi2/ndof == 6.0
        self.badChi2 = lsst.jointcal.chi2.Chi2Statistic()
        self.badChi2.chi2 = 600.0
        self.badChi2.ndof = 100

        self.nanChi2 = lsst.jointcal.chi2.Chi2Statistic()
        self.nanChi2.chi2 = np.nan
        self.nanChi2.ndof = 100

        self.maxSteps = 20
        self.name = "testing"
        self.dataName = "fake"
        self.whatToFit = ""  # unneeded, since we're mocking the fitter

        # Mock a Butler so the refObjLoaders have something to call `get()` on.
        self.butler = mock.Mock(spec=lsst.daf.persistence.Butler)
        self.butler.get.return_value.indexer = DatasetConfig().indexer

        # Mock the association manager and give it access to the ccd list above.
        self.associations = mock.Mock(spec=lsst.jointcal.Associations)
        self.associations.getCcdImageList.return_value = self.ccdImageList

        # a default config to be modified by individual tests
        self.config = lsst.jointcal.jointcal.JointcalConfig()
class TestJointcalIterateFit(JointcalTestBase, lsst.utils.tests.TestCase):
    """Tests of JointcalTask._iterate_fit and _logChi2AndValidate using a
    mocked fitter and model.
    """
    def setUp(self):
        super().setUp()
        # Mock the fitter and model, so we can force particular
        # return values/exceptions. Default to "good" return values.
        self.fitter = mock.Mock(spec=lsst.jointcal.PhotometryFit)
        self.fitter.computeChi2.return_value = self.goodChi2
        self.fitter.minimize.return_value = MinimizeResult.Converged
        self.model = mock.Mock(spec=lsst.jointcal.SimpleFluxModel)

        self.jointcal = lsst.jointcal.JointcalTask(config=self.config, butler=self.butler)

    @staticmethod
    def _makeChi2(chi2, ndof):
        """Return a Chi2Statistic with the given chi2 and ndof."""
        statistic = lsst.jointcal.chi2.Chi2Statistic()
        statistic.chi2 = chi2
        statistic.ndof = ndof
        return statistic

    def test_iterateFit_success(self):
        chi2 = self.jointcal._iterate_fit(self.associations, self.fitter,
                                          self.maxSteps, self.name, self.whatToFit)
        self.assertEqual(chi2, self.goodChi2)
        # Once for the for loop, the second time for the rank update.
        self.assertEqual(self.fitter.minimize.call_count, 2)

    def test_iterateFit_writeChi2Outer(self):
        chi2 = self.jointcal._iterate_fit(self.associations, self.fitter,
                                          self.maxSteps, self.name, self.whatToFit,
                                          dataName=self.dataName)
        self.assertEqual(chi2, self.goodChi2)
        # Once for the for loop, the second time for the rank update.
        self.assertEqual(self.fitter.minimize.call_count, 2)
        # Default config should not call saveChi2Contributions
        self.fitter.saveChi2Contributions.assert_not_called()

    def test_iterateFit_failed(self):
        self.fitter.minimize.return_value = MinimizeResult.Failed

        with self.assertRaises(RuntimeError):
            self.jointcal._iterate_fit(self.associations, self.fitter,
                                       self.maxSteps, self.name, self.whatToFit)
        # Checked after the `with` block: anything placed after the raising
        # call inside it would never execute.
        self.assertEqual(self.fitter.minimize.call_count, 1)

    def test_iterateFit_badFinalChi2(self):
        log = mock.Mock(spec=lsst.log.Log)
        self.jointcal.log = log
        self.fitter.computeChi2.return_value = self.badChi2

        chi2 = self.jointcal._iterate_fit(self.associations, self.fitter,
                                          self.maxSteps, self.name, self.whatToFit)
        self.assertEqual(chi2, self.badChi2)
        log.info.assert_called_with("%s %s", "Fit completed", self.badChi2)
        log.error.assert_called_with("Potentially bad fit: High chi-squared/ndof.")

    def test_iterateFit_exceedMaxSteps(self):
        log = mock.Mock(spec=lsst.log.Log)
        self.jointcal.log = log
        self.fitter.minimize.return_value = MinimizeResult.Chi2Increased
        maxSteps = 3

        chi2 = self.jointcal._iterate_fit(self.associations, self.fitter,
                                          maxSteps, self.name, self.whatToFit)
        self.assertEqual(chi2, self.goodChi2)
        self.assertEqual(self.fitter.minimize.call_count, maxSteps)
        # Derive the expected message from self.name rather than repeating it.
        log.error.assert_called_with("%s failed to converge after %s steps" % (self.name, maxSteps))

    def test_moderate_chi2_increase(self):
        """DM-25159: warn, but don't fail, on moderate chi2 increases between
        steps.
        """
        chi2_1 = self._makeChi2(100.0, 100)
        chi2_2 = self._makeChi2(300.0, 100)

        chi2s = [self.goodChi2, chi2_1, chi2_2, self.goodChi2, self.goodChi2]
        self.fitter.computeChi2.side_effect = chi2s
        self.fitter.minimize.side_effect = [MinimizeResult.Chi2Increased,
                                            MinimizeResult.Chi2Increased,
                                            MinimizeResult.Chi2Increased,
                                            MinimizeResult.Converged,
                                            MinimizeResult.Converged]
        with lsst.log.UsePythonLogging():  # so that assertLogs works with lsst.log
            with self.assertLogs("jointcal", level="WARN") as logger:
                self.jointcal._iterate_fit(self.associations, self.fitter,
                                           self.maxSteps, self.name, self.whatToFit)
        msg = "WARNING:jointcal:Significant chi2 increase by a factor of 300 / 100 = 3"
        self.assertIn(msg, logger.output)

    def test_large_chi2_increase_fails(self):
        """DM-25159: fail on large chi2 increases between steps."""
        chi2_1 = self._makeChi2(1e11, 100)
        chi2_2 = self._makeChi2(1.123456e13, 100)  # to check floating point formatting

        chi2s = [chi2_1, chi2_1, chi2_2]
        self.fitter.computeChi2.side_effect = chi2s
        self.fitter.minimize.return_value = MinimizeResult.Chi2Increased
        with lsst.log.UsePythonLogging():  # so that assertLogs works with lsst.log
            with self.assertLogs("jointcal", level="WARN") as logger:
                with self.assertRaisesRegex(RuntimeError, "Large chi2 increase"):
                    self.jointcal._iterate_fit(self.associations, self.fitter,
                                               self.maxSteps, self.name, self.whatToFit)
        msg = "WARNING:jointcal:Significant chi2 increase by a factor of 1.123e+13 / 1e+11 = 112.3"
        self.assertIn(msg, logger.output)

    def test_invalid_model(self):
        self.model.validate.return_value = False
        with self.assertRaises(ValueError):
            self.jointcal._logChi2AndValidate(self.associations, self.fitter, self.model, "invalid")

    def test_nonfinite_chi2(self):
        self.fitter.computeChi2.return_value = self.nanChi2
        with self.assertRaises(FloatingPointError):
            self.jointcal._logChi2AndValidate(self.associations, self.fitter, self.model, "nonfinite")

    def test_writeChi2(self):
        filename = "somefile"
        self.jointcal._logChi2AndValidate(self.associations, self.fitter, self.model, "writeCh2",
                                          writeChi2Name=filename)
        # logChi2AndValidate prepends `config.debugOutputPath` to the filename
        self.fitter.saveChi2Contributions.assert_called_with("./"+filename+"{type}")
class TestJointcalLoadRefCat(JointcalTestBase, lsst.utils.tests.TestCase):
    """Tests of JointcalTask._load_reference_catalog with a mocked loader."""

    def _make_fake_refcat(self):
        """Mock a fake reference catalog and the bits necessary to use it.

        Returns
        -------
        refObjLoader : `unittest.mock.Mock`
            Mock loader whose ``loadSkyCircle`` returns the fake catalog.
        center : `lsst.geom.SpherePoint`
            Center of the fake catalog.
        radius : `lsst.geom.Angle`
            Search radius to pass to the loader.
        filterLabel : `lsst.afw.image.FilterLabel`
            Filter whose band label names the flux fields.
        fakeRefCat : `lsst.afw.table.SimpleCatalog`
            The one-record fake reference catalog.
        """
        center = lsst.geom.SpherePoint(30, -30, lsst.geom.degrees)
        flux = 10
        radius = 1 * lsst.geom.degrees
        # Named `filterLabel` (not `filter`) to avoid shadowing the builtin.
        filterLabel = lsst.afw.image.FilterLabel(band='fake', physical="fake-filter")

        fakeRefCat = make_fake_refcat(center, flux, filterLabel.bandLabel)
        fluxField = getRefFluxField(fakeRefCat.schema, filterLabel.bandLabel)
        returnStruct = lsst.pipe.base.Struct(refCat=fakeRefCat, fluxField=fluxField)
        refObjLoader = mock.Mock(spec=LoadIndexedReferenceObjectsTask)
        refObjLoader.loadSkyCircle.return_value = returnStruct

        return refObjLoader, center, radius, filterLabel, fakeRefCat

    def test_load_reference_catalog(self):
        refObjLoader, center, radius, filterLabel, fakeRefCat = self._make_fake_refcat()

        config = lsst.jointcal.jointcal.JointcalConfig()
        config.astrometryReferenceErr = 0.1  # our test refcats don't have coord errors
        jointcal = lsst.jointcal.JointcalTask(config=config, butler=self.butler)

        # NOTE: we cannot test application of proper motion here, because we
        # mock the refObjLoader, so the real loader is never called.
        refCat, _ = jointcal._load_reference_catalog(refObjLoader,
                                                     jointcal.astrometryReferenceSelector,
                                                     center,
                                                     radius,
                                                     filterLabel)
        # operator== isn't implemented for Catalogs, so we have to check like
        # this, in case the records are copied during load.
        self.assertEqual(len(refCat), len(fakeRefCat))
        for r1, r2 in zip(refCat, fakeRefCat):
            self.assertEqual(r1, r2)

    def test_load_reference_catalog_subselect(self):
        """Test that we can select out the one source in the fake refcat
        with a ridiculous S/N cut.
        """
        refObjLoader, center, radius, filterLabel, fakeRefCat = self._make_fake_refcat()

        config = lsst.jointcal.jointcal.JointcalConfig()
        config.astrometryReferenceErr = 0.1  # our test refcats don't have coord errors
        config.astrometryReferenceSelector.doSignalToNoise = True
        config.astrometryReferenceSelector.signalToNoise.minimum = 1e10
        config.astrometryReferenceSelector.signalToNoise.fluxField = "fake_flux"
        config.astrometryReferenceSelector.signalToNoise.errField = "fake_fluxErr"
        jointcal = lsst.jointcal.JointcalTask(config=config, butler=self.butler)

        refCat, _ = jointcal._load_reference_catalog(refObjLoader,
                                                     jointcal.astrometryReferenceSelector,
                                                     center,
                                                     radius,
                                                     filterLabel)
        self.assertEqual(len(refCat), 0)
class TestJointcalFitModel(JointcalTestBase, lsst.utils.tests.TestCase):
    """Tests that the photometry and astrometry fit drivers pass the expected
    filename prefixes to saveChi2Contributions.
    """
    def test_fit_photometry_writeChi2(self):
        """Test that we are calling saveChi2 with appropriate file prefixes."""
        self.config.photometryModel = "constrainedFlux"
        self.config.writeChi2FilesOuterLoop = True
        jointcal = lsst.jointcal.JointcalTask(config=self.config, butler=self.butler)
        jointcal.focalPlaneBBox = lsst.geom.Box2D()

        # Mock the fitter, so we can pretend it found a good fit. With
        # `autospec=True` the mock only allows attributes that PhotometryFit
        # actually has.
        with mock.patch("lsst.jointcal.PhotometryFit", autospec=True) as fitPatch:
            fitPatch.return_value.computeChi2.return_value = self.goodChi2
            fitPatch.return_value.minimize.return_value = MinimizeResult.Converged

            # config.debugOutputPath is prepended to the filenames that go into saveChi2Contributions
            expected = ["./photometry_init-ModelVisit_chi2", "./photometry_init-Model_chi2",
                        "./photometry_init-Fluxes_chi2", "./photometry_init-ModelFluxes_chi2"]
            expected = [mock.call(x+"-fake{type}") for x in expected]
            jointcal._fit_photometry(self.associations, dataName=self.dataName)
            fitPatch.return_value.saveChi2Contributions.assert_has_calls(expected)

    def test_fit_astrometry_writeChi2(self):
        """Test that we are calling saveChi2 with appropriate file prefixes."""
        self.config.astrometryModel = "constrained"
        self.config.writeChi2FilesOuterLoop = True
        jointcal = lsst.jointcal.JointcalTask(config=self.config, butler=self.butler)
        jointcal.focalPlaneBBox = lsst.geom.Box2D()

        # Mock the fitter, so we can pretend it found a good fit
        fitPatch = mock.patch("lsst.jointcal.AstrometryFit")
        # Mock the projection handler so we don't segfault due to not-fully initialized ccdImages
        projectorPatch = mock.patch("lsst.jointcal.OneTPPerVisitHandler")
        with fitPatch as fit, projectorPatch as projector:
            fit.return_value.computeChi2.return_value = self.goodChi2
            fit.return_value.minimize.return_value = MinimizeResult.Converged
            # return a real ProjectionHandler to keep ConstrainedAstrometryModel() happy
            projector.return_value = lsst.jointcal.IdentityProjectionHandler()

            # config.debugOutputPath is prepended to the filenames that go into saveChi2Contributions
            expected = ["./astrometry_init-DistortionsVisit_chi2", "./astrometry_init-Distortions_chi2",
                        "./astrometry_init-Positions_chi2", "./astrometry_init-DistortionsPositions_chi2"]
            expected = [mock.call(x+"-fake{type}") for x in expected]
            jointcal._fit_astrometry(self.associations, dataName=self.dataName)
            fit.return_value.saveChi2Contributions.assert_has_calls(expected)
class TestComputeBoundingCircle(lsst.utils.tests.TestCase):
    """Tests of Associations.computeBoundingCircle()"""
    def _checkPointsInCircle(self, points, center, radius):
        """Check that all points are within the (center, radius) circle.

        The test is whether the max(points - center) separation is equal to
        (or slightly less than) radius.
        """
        maxSeparation = max((center.separation(point) for point in points),
                            default=0*lsst.geom.degrees)
        self.assertAnglesAlmostEqual(maxSeparation, radius, maxDiff=3*lsst.geom.arcseconds)
        self.assertLess(maxSeparation, radius)

    def _testPoints(self, ccdImage1, ccdImage2, skyWcs1, skyWcs2, bbox):
        """Fill an Associations object and test that it computes the correct
        bounding circle for the input data.

        Parameters
        ----------
        ccdImage1, ccdImage2 : `lsst.jointcal.CcdImage`
            The CcdImages to add to the Associations object.
        skyWcs1, skyWcs2 : `lsst.afw.geom.SkyWcs`
            The WCS of each of the above images.
        bbox : `lsst.geom.Box2D`
            The ccd bounding box of both images.
        """
        # NOTE(review): this changes the global 'jointcal' logger level and
        # does not restore it afterwards.
        lsst.log.setLevel('jointcal', lsst.log.DEBUG)
        associations = lsst.jointcal.Associations()
        associations.addCcdImage(ccdImage1)
        associations.addCcdImage(ccdImage2)
        associations.computeCommonTangentPoint()

        circle = associations.computeBoundingCircle()
        center = lsst.geom.SpherePoint(circle.getCenter())
        radius = lsst.geom.Angle(circle.getOpeningAngle().asRadians(), lsst.geom.radians)
        # Sky positions of all bbox corners: first image's corners, then the second's.
        corners = bbox.getCorners()
        points = [lsst.geom.SpherePoint(skyWcs.pixelToSky(lsst.geom.Point2D(corner)))
                  for skyWcs in (skyWcs1, skyWcs2)
                  for corner in corners]
        self._checkPointsInCircle(points, center, radius)

    def testPoints(self):
        """Test for points in an "easy" area, far from RA=0 or the poles."""
        struct = lsst.jointcal.testUtils.createTwoFakeCcdImages()
        self._testPoints(struct.ccdImageList[0], struct.ccdImageList[1],
                         struct.skyWcs[0], struct.skyWcs[1], struct.bbox)

    def testPointsRA0(self):
        """Test for CcdImages crossing RA=0; this demonstrates a fix for
        the bug described in DM-19802.
        """
        wcs1, wcs2 = make_fake_wcs()

        # Put each visit's boresight at its own WCS sky origin, so the two
        # boresights also straddle RA=0.
        visitInfo1 = lsst.afw.image.VisitInfo(exposureId=30577512,
                                              date=DateTime(65321.1),
                                              boresightRaDec=wcs1.getSkyOrigin())
        visitInfo2 = lsst.afw.image.VisitInfo(exposureId=30621144,
                                              date=DateTime(65322.1),
                                              boresightRaDec=wcs2.getSkyOrigin())

        struct = lsst.jointcal.testUtils.createTwoFakeCcdImages(fakeWcses=[wcs1, wcs2],
                                                                fakeVisitInfos=[visitInfo1, visitInfo2])
        self._testPoints(struct.ccdImageList[0], struct.ccdImageList[1],
                         struct.skyWcs[0], struct.skyWcs[1], struct.bbox)
class TestJointcalComputePMDate(JointcalTestBase, lsst.utils.tests.TestCase):
    """Tests of jointcal._compute_proper_motion_epoch()"""
    def test_compute_proper_motion_epoch(self):
        mjds = np.array((65432.1, 66666, 65555, 64322.2))

        wcs1, wcs2 = make_fake_wcs()
        # Visits alternate between the two WCS origins.
        exposureIds = (30577512, 30621144, 30577513, 30621145)
        boresights = (wcs1.getSkyOrigin(), wcs2.getSkyOrigin()) * 2
        visitInfos = [lsst.afw.image.VisitInfo(exposureId=exposureId,
                                               date=DateTime(mjd),
                                               boresightRaDec=boresight)
                      for exposureId, mjd, boresight in zip(exposureIds, mjds, boresights)]

        struct1 = lsst.jointcal.testUtils.createTwoFakeCcdImages(fakeWcses=[wcs1, wcs2],
                                                                 fakeVisitInfos=visitInfos[:2])
        struct2 = lsst.jointcal.testUtils.createTwoFakeCcdImages(fakeWcses=[wcs1, wcs2],
                                                                 fakeVisitInfos=visitInfos[2:])
        ccdImageList = list(itertools.chain(struct1.ccdImageList, struct2.ccdImageList))
        associations = lsst.jointcal.Associations()
        for ccdImage in ccdImageList:
            associations.addCcdImage(ccdImage)
        associations.computeCommonTangentPoint()

        jointcal = lsst.jointcal.JointcalTask(config=self.config, butler=self.butler)
        result = jointcal._compute_proper_motion_epoch(ccdImageList)
        # The MJDs round-trip through DateTime and back before being averaged,
        # so compare with a tolerance rather than exact float equality.
        self.assertAlmostEqual(result.mjd, mjds.mean())
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Runs only the checks inherited from MemoryTestCase; defines no tests of its own."""
# Allow running this file directly as a script, outside pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()