Coverage for tests/test_higher_moments.py: 24%
235 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 11:35 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 11:35 +0000
# This file is part of meas_extensions_shapeHSM.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""Unit tests for higher order moments measurement.

These double up as initial estimates of how accurate the measurement is with
different configuration options. The various tolerance levels here are based
on experimentation with the specific datasets used here.
"""
29import unittest
31import galsim
32import lsst.afw.geom
33import lsst.meas.base.tests
34import lsst.meas.extensions.shapeHSM # noqa: F401
35import lsst.utils.tests as tests
36import numpy as np
37from lsst.meas.base import SingleFrameMeasurementConfig, SingleFrameMeasurementTask
38from lsst.pex.config import FieldValidationError
class HigherMomentsBaseTestCase(tests.TestCase):
    """Base test case to test higher order moments.

    ``setUp`` builds a measurement task with the HSM moments plugins and the
    higher order moments plugins enabled, realizes a small two-source dataset
    and stores the results as ``self.task``, ``self.exposure`` and
    ``self.catalog``. Subclasses customize the data through the
    ``create_dataset`` and ``add_mask_bits`` hooks.
    """

    def setUp(self):
        """Create the measurement task, the exposure and the catalog.

        No measurement is run here; tests call ``run_measurement`` themselves
        so that they can tweak the plugin configs first.
        """
        super().setUp()

        # Initialize a config and activate the plugin
        sfmConfig = SingleFrameMeasurementConfig()
        sfmConfig.plugins.names |= [
            "ext_shapeHSM_HsmSourceMoments",
            "ext_shapeHSM_HsmSourceMomentsRound",
            "ext_shapeHSM_HsmPsfMoments",
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        ]
        # The min and max order determine the schema and cannot be changed
        # after the Task is created. So we set it generously here.
        for plugin_name in (
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        ):
            sfmConfig.plugins[plugin_name].max_order = 7
            sfmConfig.plugins[plugin_name].min_order = 0

        # Create a minimal schema (columns)
        self.schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()

        # Create a task
        sfmTask = SingleFrameMeasurementTask(config=sfmConfig, schema=self.schema)

        dataset = self.create_dataset()

        # Get the exposure and catalog. A fixed random seed keeps the
        # realization reproducible between runs.
        exposure, catalog = dataset.realize(0.0, sfmTask.schema, randomSeed=0)

        self.catalog = catalog
        self.exposure = exposure
        self.task = sfmTask

        # Must come last: subclasses' overrides operate on ``self.exposure``.
        self.add_mask_bits()

    @staticmethod
    def add_mask_bits():
        """Add mask bits to the exposure.

        This must go along with the create_dataset method. This is a no-op for
        the base class and subclasses must set mask bits depending on the test.
        """
        pass

    @staticmethod
    def create_dataset():
        """Create the simulated dataset that ``setUp`` realizes.

        Returns
        -------
        dataset : `lsst.meas.base.tests.TestDataset`
            A 100x100 pixel dataset containing one point source and one
            extended source.
        """
        # Create a simple, fake dataset
        bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(100, 100))
        dataset = lsst.meas.base.tests.TestDataset(bbox)
        # Create a point source with Gaussian PSF
        dataset.addSource(100000.0, lsst.geom.Point2D(49.5, 49.5))

        # Create a galaxy with Gaussian PSF
        dataset.addSource(300000.0, lsst.geom.Point2D(76.3, 79.2), lsst.afw.geom.Quadrupole(2.0, 3.0, 0.5))
        return dataset

    def run_measurement(self, **kwargs):
        """Run the measurement task on ``self.catalog`` and ``self.exposure``.

        Parameters
        ----------
        **kwargs
            Forwarded verbatim to ``self.task.run`` (the tests in this module
            use ``beginOrder`` and ``endOrder``).
        """
        self.task.run(self.catalog, self.exposure, **kwargs)

    def check_odd_moments(self, row, plugin_name, atol, orders=(3, 5)):
        """Check that every moment of the given orders is consistent with 0.

        Parameters
        ----------
        row : catalog record
            A row of ``self.catalog``, indexable by column name.
        plugin_name : `str`
            Prefix of the columns to check.
        atol : `float`
            Absolute tolerance of the comparison.
        orders : iterable of `int`, optional
            Orders ``n`` to check; all columns ``{plugin_name}_{p}{n-p}`` with
            ``0 <= p <= n`` are compared against zero.
        """
        for n in orders:
            for p in range(n + 1):
                with self.subTest((p, n - p)):
                    self.assertFloatsAlmostEqual(row[f"{plugin_name}_{p}{n-p}"], 0.0, atol=atol)

    def check_even_moments(self, row, plugin_name, atol):
        """Check the 4th and 6th order moments against their fiducial values.

        The expected values are consistent with a circular Gaussian that has
        a variance of 0.5 along each axis (for example 3 * 0.5**2 = 0.75 for
        the ``40`` moment and 15 * 0.5**3 = 1.875 for the ``60`` moment).

        Parameters
        ----------
        row : catalog record
            A row of ``self.catalog``, indexable by column name.
        plugin_name : `str`
            Prefix of the columns to check.
        atol : `float`
            Absolute tolerance of the comparison.
        """
        expected_moments = (
            # 4th order
            ("40", 0.75),
            ("31", 0.0),
            ("22", 0.25),
            ("13", 0.0),
            ("04", 0.75),
            # 6th order
            ("60", 1.875),
            ("51", 0.0),
            ("42", 0.375),
            ("33", 0.0),
            ("24", 0.375),
            ("15", 0.0),
            ("06", 1.875),
        )
        # Read every column up front so that a missing column is reported
        # before any numerical comparison is attempted.
        measured = {suffix: row[f"{plugin_name}_{suffix}"] for suffix, _ in expected_moments}

        for suffix, expected in expected_moments:
            self.assertFloatsAlmostEqual(measured[suffix], expected, atol=atol)

    def check(self, row, plugin_name, atol):
        """Run both the odd and the even moment checks with one tolerance.

        Parameters
        ----------
        row : catalog record
            A row of ``self.catalog``, indexable by column name.
        plugin_name : `str`
            Prefix of the columns to check.
        atol : `float`
            Absolute tolerance used for every comparison.
        """
        self.check_odd_moments(row, plugin_name, atol)
        self.check_even_moments(row, plugin_name, atol)

    @lsst.utils.tests.methodParameters(
        plugin_name=(
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        )
    )
    def test_validate_config(self, plugin_name):
        """Test that the validation of the configs works as expected."""
        config = self.task.config.plugins[plugin_name]
        config.validate()  # This should not raise any error.

        # Test that the validation fails when the max_order is smaller than the
        # min_order.
        config.max_order = 3
        config.min_order = 4
        with self.assertRaises(FieldValidationError):
            config.validate()

    @lsst.utils.tests.methodParameters(
        plugin_name=(
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        )
    )
    def test_calculate_higher_order_moments(self, plugin_name):
        """Test that the _calculate_higher_order_moments results in the same
        outputs whether or not we take the linear algebra code path.
        """

        # We do not run any of the measurement plugins, but use a rough
        # centroid and an arbitrary 2x2 matrix to test that the code paths
        # result in consistent outputs.
        plugin = self.task.plugins[plugin_name]

        for i, row in enumerate(self.catalog):
            bbox = row.getFootprint().getBBox()
            center = bbox.getCenter()

            # Asymmetric matrix is not realistic, but we don't expect it to
            # break the consistency. It just needs to have determinant > 0.
            # This can be considered as a stress test for any small asymmetry
            # that may arise because of rounding errors in the off-diagonal
            # terms.
            M = np.array([[2.0, 1.0], [0.5, 3.0]])
            image = self.exposure.image[bbox]

            hm1 = plugin._calculate_higher_order_moments(image, center, M, use_linear_algebra=False)
            hm2 = plugin._calculate_higher_order_moments(image, center, M, use_linear_algebra=True)
            for key in hm1:
                # Label the sub-test so that a failure identifies the source
                # and the moment that disagreed.
                with self.subTest(i=i, key=key):
                    self.assertFloatsAlmostEqual(hm1[key], hm2[key], atol=1e-14)
class HigherOrderMomentsTestCase(HigherMomentsBaseTestCase):
    """Test the higher order moments plugins on the unmasked base dataset."""

    @lsst.utils.tests.methodParameters(
        plugin_name=(
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        )
    )
    def test_hsm_source_moments(self, plugin_name):
        """Test that the odd and even higher order moments measured with the
        default configuration match the fiducial values for every source.
        """

        self.run_measurement()

        atol = 8e-6
        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.check(row, plugin_name, atol=atol)

    @lsst.utils.tests.methodParameters(useSourceCentroidOffset=(False, True))
    def test_hsm_psf_lower_moments(self, useSourceCentroidOffset):
        """Test that the PSF moments up to second order match the fiducial
        values for either setting of ``useSourceCentroidOffset``.
        """
        plugin_name = "ext_shapeHSM_HigherOrderMomentsPSF"
        self.task.config.plugins[
            "ext_shapeHSM_HsmPsfMoments"
        ].useSourceCentroidOffset = useSourceCentroidOffset
        self.task.config.plugins[
            "ext_shapeHSM_HigherOrderMomentsPSF"
        ].useSourceCentroidOffset = useSourceCentroidOffset

        self.run_measurement()

        # Results are accurate for either values of useSourceCentroidOffset
        # when looking at lower order moments.
        atol = 2e-8

        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_00"], 1.0, atol=atol)

                self.assertFloatsAlmostEqual(row[f"{plugin_name}_01"], 0.0, atol=atol)
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_10"], 0.0, atol=atol)

                self.assertFloatsAlmostEqual(row[f"{plugin_name}_20"], 0.5, atol=atol)
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_11"], 0.0, atol=atol)
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_02"], 0.5, atol=atol)

    @lsst.utils.tests.methodParameters(useSourceCentroidOffset=(False, True))
    def test_hsm_psf_higher_moments(self, useSourceCentroidOffset):
        """Test that the higher order PSF moments match the fiducial values,
        with a tolerance that depends on ``useSourceCentroidOffset``.
        """

        self.task.config.plugins[
            "ext_shapeHSM_HsmPsfMoments"
        ].useSourceCentroidOffset = useSourceCentroidOffset
        self.task.config.plugins[
            "ext_shapeHSM_HigherOrderMomentsPSF"
        ].useSourceCentroidOffset = useSourceCentroidOffset

        self.run_measurement()

        # useSourceCentroidOffset = False results in more accurate results.
        # Adjust the absolute tolerance accordingly.
        atol = 7e-3 if useSourceCentroidOffset else 8e-6

        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.check(row, "ext_shapeHSM_HigherOrderMomentsPSF", atol=atol)

    @lsst.utils.tests.methodParameters(
        target_plugin_name=(
            "base_SdssShape",
            "ext_shapeHSM_HsmSourceMomentsRound",
            "truth",
        )
    )
    def test_source_consistent_weight(self, target_plugin_name):
        """Test that we get the expected results when we use a different set
        of consistent weights to measure the higher order moments of sources.
        """
        # Pause the execution of the measurement task before the higher order
        # moments plugins.

        pause_order = self.task.plugins["ext_shapeHSM_HigherOrderMomentsSource"].getExecutionOrder()
        self.run_measurement(endOrder=pause_order)

        # Overwrite the centroid and second moments that the higher order
        # moments plugin reads with those of the target plugin.
        for suffix in (
            "x",
            "y",
            "xx",
            "yy",
            "xy",
        ):
            self.catalog[f"ext_shapeHSM_HsmSourceMoments_{suffix}"] = self.catalog[
                f"{target_plugin_name}_{suffix}"
            ]

        # Resume the execution of the measurement task.
        self.run_measurement(beginOrder=pause_order)

        # The round moments are only checked against the first source (see
        # the ``break`` below) and use the tighter tolerance; the other
        # weights are checked against every source with a looser tolerance.
        # NOTE(review): the original comment said HsmSourceMomentsRound has
        # *lower* accuracy, which contradicts the tolerances used here —
        # confirm which was intended.
        atol = 1.2e-7 if target_plugin_name == "ext_shapeHSM_HsmSourceMomentsRound" else 6e-4
        plugin_name = "ext_shapeHSM_HigherOrderMomentsSource"

        for i, row in enumerate(self.catalog):
            with self.subTest((plugin_name, i)):
                self.check(row, plugin_name, atol=atol)
            # The round moments are only accurate for the round sources,
            # which is the first one in the catalog.
            if target_plugin_name == "ext_shapeHSM_HsmSourceMomentsRound":
                break

    @lsst.utils.tests.methodParametersProduct(
        target_plugin_name=(
            "base_SdssShape_psf",
            "truth",
        ),
        useSourceCentroidOffset=(False, True),
    )
    def test_psf_consistent_weight(self, target_plugin_name, useSourceCentroidOffset):
        """Test that we get the expected results when we use a different set
        of consistent weights to measure the higher order moments of PSFs.
        """
        self.task.config.plugins[
            "ext_shapeHSM_HigherOrderMomentsPSF"
        ].useSourceCentroidOffset = useSourceCentroidOffset

        # Pause the execution of the measurement task before the higher order
        # moments plugins.
        pause_order = self.task.plugins["ext_shapeHSM_HigherOrderMomentsPSF"].getExecutionOrder()
        self.run_measurement(endOrder=pause_order)

        # Create a dictionary of PSF moments corresponding to the truth.
        # These are hardcoded in dataset.realize.
        truth_psf = {"xx": 4.0, "yy": 4.0, "xy": 0.0}

        for suffix in (
            "xx",
            "yy",
            "xy",
        ):
            if target_plugin_name == "truth":
                self.catalog[f"ext_shapeHSM_HsmPsfMoments_{suffix}"] = truth_psf[suffix]
            else:
                self.catalog[f"ext_shapeHSM_HsmPsfMoments_{suffix}"] = self.catalog[
                    f"{target_plugin_name}_{suffix}"
                ]

        # Resume the execution of the measurement task.
        self.run_measurement(beginOrder=pause_order)

        # useSourceCentroidOffset = False results in more accurate results.
        # Adjust the absolute tolerance accordingly.
        atol = 1.2e-2 if useSourceCentroidOffset else 8e-6
        plugin_name = "ext_shapeHSM_HigherOrderMomentsPSF"

        for i, row in enumerate(self.catalog):
            with self.subTest((plugin_name, i)):
                self.check(row, plugin_name, atol=atol)
class HigherMomentTestCaseWithMask(HigherMomentsBaseTestCase):
    """A test case to measure higher order moments in the presence of masks.

    The tests serve to check the validity of the algorithm on non-Gaussian
    profiles.
    """

    def add_mask_bits(self):
        # Docstring inherited.
        for position in (
            lsst.geom.Point2I(48, 47),
            lsst.geom.Point2I(76, 79),
        ):
            self.exposure.mask[position] |= self.exposure.mask.getPlaneBitMask("BAD")
        # NOTE(review): these positions are Point2D whereas the BAD positions
        # above are Point2I; presumably integer pixel indices are intended —
        # confirm. Kept as-is to preserve behavior.
        for position in (
            lsst.geom.Point2D(49, 49),
            lsst.geom.Point2D(76, 79),
        ):
            self.exposure.mask[position] |= self.exposure.mask.getPlaneBitMask("SAT")

    def _check_lower_order_moments(self, row, plugin_name, atol):
        """Check the moments up to second order against their fiducial values.

        Parameters
        ----------
        row : catalog record
            A row of ``self.catalog``, indexable by column name.
        plugin_name : `str`
            Prefix of the columns to check.
        atol : `float`
            Absolute tolerance of the comparison.
        """
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_00"], 1.0, atol=atol)

        self.assertFloatsAlmostEqual(row[f"{plugin_name}_01"], 0.0, atol=atol)
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_10"], 0.0, atol=atol)

        self.assertFloatsAlmostEqual(row[f"{plugin_name}_20"], 0.5, atol=atol)
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_11"], 0.0, atol=atol)
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_02"], 0.5, atol=atol)

    def test_lower_order_moments(self, plugin_name="ext_shapeHSM_HigherOrderMomentsSource"):
        """Test that the lower order moments (2nd order or lower) are
        consistent even in the presence of masks.
        """
        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].setMaskedPixelsToZero = True

        self.run_measurement()

        atol = 2e-8
        for row in self.catalog:
            self._check_lower_order_moments(row, plugin_name, atol)

    def test_kurtosis(self):
        """Test the kurtosis measurement against GalSim HSM implementation."""
        # Zero out the masked pixels so that the measurement is comparable
        # with GalSim, which is handed the same pixels via ``badpix`` below.
        # NOTE(review): the original comment read "GalSim does not set masked
        # pixels to zero. So we set them to zero as well", which is
        # self-contradictory — confirm the intended wording.
        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].setMaskedPixelsToZero = True

        self.run_measurement()

        # The bit mask does not depend on the source; compute it only once.
        bitValue = self.exposure.mask.getPlaneBitMask(["BAD", "SAT"])

        delta_rho4s = []
        for i, row in enumerate(self.catalog):
            bbox = row.getFootprint().getBBox()
            im = galsim.Image(self.exposure[bbox].image.array)
            # Work on a copy so that the exposure's mask is left untouched,
            # and keep only the BAD and SAT bits.
            badpix = self.exposure.mask[bbox].array.copy()
            badpix &= bitValue
            badpix = galsim.Image(badpix, copy=False)
            shape = galsim.hsm.FindAdaptiveMom(im, badpix=badpix, strict=False)
            # r^4 = (x^2+y^2)^2 = x^4 + y^4 + 2x^2y^2
            rho4 = sum(
                (
                    row["ext_shapeHSM_HigherOrderMomentsSource_40"],
                    row["ext_shapeHSM_HigherOrderMomentsSource_04"],
                    row["ext_shapeHSM_HigherOrderMomentsSource_22"] * 2,
                )
            )
            delta_rho4s.append(abs(rho4 - 2.0))
            with self.subTest(i=i):
                self.assertFloatsAlmostEqual(shape.moments_rho4, rho4, atol=4e-7)

        # Check that at least one rho4 moment is non-trivial and differs from
        # the fiducial value of 2, by an amount much larger than the precision.
        self.assertTrue((np.array(delta_rho4s) > 1e-2).any(), "Unit test is too weak.")

    def test_hsm_source_higher_moments(self, plugin_name="ext_shapeHSM_HigherOrderMomentsSource"):
        """Test that the source moments stay within a loose tolerance of the
        fiducial values when the BAD and SAT planes are treated as bad mask
        planes and masked pixels are not set to zero.
        """

        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].badMaskPlanes = ["BAD", "SAT"]
        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].setMaskedPixelsToZero = False

        self.run_measurement()

        atol = 3e-1
        for row in self.catalog:
            self._check_lower_order_moments(row, plugin_name, atol)

            self.check(row, plugin_name, atol=atol)
class HigherMomentTestCaseWithSymmetricMask(HigherMomentTestCaseWithMask):
    """A test case with masked pixels placed symmetrically about each source.

    The sources sit at integer or half-integer positions and the BAD and SAT
    pixels of each source mirror each other about its centroid.
    """

    @staticmethod
    def create_dataset():
        """Create the simulated dataset that ``setUp`` realizes.

        Returns
        -------
        dataset : `lsst.meas.base.tests.TestDataset`
            A 100x100 pixel dataset containing one point source and one
            extended source.
        """
        # Create a simple, fake dataset with centroids at integer or
        # half-integer positions to have a definite symmetry.
        bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(100, 100))
        dataset = lsst.meas.base.tests.TestDataset(bbox)
        # Create a point source with Gaussian PSF
        dataset.addSource(100000.0, lsst.geom.Point2D(49.5, 49.5))

        # Create a galaxy with Gaussian PSF
        dataset.addSource(300000.0, lsst.geom.Point2D(76, 79), lsst.afw.geom.Quadrupole(2.0, 3.0, 0.5))
        return dataset

    def add_mask_bits(self):
        # Docstring inherited.
        # (48, 48) and (51, 51) mirror each other about (49.5, 49.5);
        # (73, 79) and (79, 79) mirror each other about (76, 79).
        for position in (
            lsst.geom.Point2I(48, 48),
            lsst.geom.Point2I(73, 79),
        ):
            self.exposure.mask[position] |= self.exposure.mask.getPlaneBitMask("BAD")
        for position in (
            lsst.geom.Point2D(51, 51),
            lsst.geom.Point2D(79, 79),
        ):
            self.exposure.mask[position] |= self.exposure.mask.getPlaneBitMask("SAT")

    @lsst.utils.tests.methodParameters(plugin_name=("ext_shapeHSM_HigherOrderMomentsSource",))
    def test_odd_moments(self, plugin_name):
        """Test that the odd order moments are close to the expected values.

        The odd moments are held to a far tighter tolerance than the even
        ones, which are only checked loosely.
        """

        self.run_measurement()

        for row in self.catalog:
            self.check_odd_moments(row, plugin_name, atol=1e-16)
            self.check_even_moments(row, plugin_name, atol=3e-1)

    @lsst.utils.tests.methodParameters(useSourceCentroidOffset=(False, True))
    def test_hsm_psf_higher_moments(self, useSourceCentroidOffset):
        """Test that the higher order PSF moments are close to the expected
        values when the masks are symmetric.
        """

        self.task.config.plugins[
            "ext_shapeHSM_HsmPsfMoments"
        ].useSourceCentroidOffset = useSourceCentroidOffset
        self.task.config.plugins[
            "ext_shapeHSM_HigherOrderMomentsPSF"
        ].useSourceCentroidOffset = useSourceCentroidOffset

        self.run_measurement()

        # useSourceCentroidOffset = False results in more accurate results.
        # Adjust the absolute tolerance accordingly.
        atol = 4e-3 if useSourceCentroidOffset else 8e-6

        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.check(row, "ext_shapeHSM_HigherOrderMomentsPSF", atol=atol)
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`."""

    pass
def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module : module
        The test module being set up; unused.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Initialize the LSST test utilities before handing over to unittest.
    lsst.utils.tests.init()
    unittest.main()