Coverage for tests/test_higher_moments.py: 24%

235 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-04 10:18 +0000

1# This file is part of meas_extensions_shapeHSM. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Unit tests for higher order moments measurement. 

23 

24These double up as initial estimates of how accurate the measurement is with 

25different configuration options. The various tolerance levels here are based 

26on experimentation with the specific datasets used here. 

27""" 

28 

import unittest

import galsim
import lsst.afw.geom
import lsst.geom
import lsst.meas.base.tests
import lsst.meas.extensions.shapeHSM  # noqa: F401
import lsst.utils.tests as tests
import numpy as np
from lsst.meas.base import SingleFrameMeasurementConfig, SingleFrameMeasurementTask
from lsst.pex.config import FieldValidationError

39 

40 

class HigherMomentsBaseTestCase(tests.TestCase):
    """Base test case to test higher order moments.

    ``setUp`` builds a `SingleFrameMeasurementTask` with the HSM and higher
    order moments plugins enabled, realizes the dataset returned by
    `create_dataset` into ``self.exposure`` and ``self.catalog``, and then
    calls `add_mask_bits`. Subclasses customize the scene by overriding those
    two hooks.
    """

    def setUp(self):
        """Create the measurement task, the exposure and the catalog.

        Measurement is not run here; tests call `run_measurement` themselves
        after adjusting the plugin configs as needed.
        """
        super().setUp()

        # Initialize a config and activate the plugin
        sfmConfig = SingleFrameMeasurementConfig()
        sfmConfig.plugins.names |= [
            "ext_shapeHSM_HsmSourceMoments",
            "ext_shapeHSM_HsmSourceMomentsRound",
            "ext_shapeHSM_HsmPsfMoments",
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        ]
        # The min and max order determine the schema and cannot be changed
        # after the Task is created. So we set it generously here.
        for plugin_name in (
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        ):
            sfmConfig.plugins[plugin_name].max_order = 7
            sfmConfig.plugins[plugin_name].min_order = 0

        # Create a minimal schema (columns)
        self.schema = lsst.meas.base.tests.TestDataset.makeMinimalSchema()

        # Create a task
        sfmTask = SingleFrameMeasurementTask(config=sfmConfig, schema=self.schema)

        dataset = self.create_dataset()

        # Get the exposure and catalog.
        exposure, catalog = dataset.realize(0.0, sfmTask.schema, randomSeed=0)

        self.catalog = catalog
        self.exposure = exposure
        self.task = sfmTask

        # Must come last: mask bits are set on the freshly realized exposure.
        self.add_mask_bits()

    @staticmethod
    def add_mask_bits():
        """Add mask bits to the exposure.

        This must go along with the create_dataset method. This is a no-op for
        the base class and subclasses must set mask bits depending on the test.
        Subclasses may override this with an instance method; ``setUp`` calls
        it as ``self.add_mask_bits()``.
        """
        pass

    @staticmethod
    def create_dataset():
        """Return the `lsst.meas.base.tests.TestDataset` to be realized.

        The dataset is a 100x100 image with a point source at (49.5, 49.5)
        and an elliptical source at (76.3, 79.2).
        """
        # Create a simple, fake dataset
        bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(100, 100))
        dataset = lsst.meas.base.tests.TestDataset(bbox)
        # Create a point source with Gaussian PSF
        dataset.addSource(100000.0, lsst.geom.Point2D(49.5, 49.5))

        # Create a galaxy with Gaussian PSF
        dataset.addSource(300000.0, lsst.geom.Point2D(76.3, 79.2), lsst.afw.geom.Quadrupole(2.0, 3.0, 0.5))
        return dataset

    def run_measurement(self, **kwargs):
        """Run measurement on the source and the PSF.

        Keyword arguments are forwarded unchanged to ``self.task.run``.
        """
        self.task.run(self.catalog, self.exposure, **kwargs)

    @staticmethod
    def _expected_gaussian_moment(p, q):
        """Return the reference value of the standardized moment M_pq.

        The reference values used by these tests (M_20 = M_02 = 0.5,
        M_40 = 0.75, M_22 = 0.25, M_60 = 1.875, M_42 = 0.375, ...) follow
        ``M_pq = (p-1)!! (q-1)!! / 2**((p+q)/2)`` when both ``p`` and ``q``
        are even, and ``M_pq = 0`` when either of them is odd.
        For the orders used here the result is an exact binary fraction.
        """
        if p % 2 or q % 2:
            return 0.0

        def odd_double_factorial(k):
            # (k - 1)!! for even k, i.e. 1 * 3 * ... * (k - 1); equals 1 for k = 0.
            result = 1
            for j in range(1, k, 2):
                result *= j
            return result

        return odd_double_factorial(p) * odd_double_factorial(q) / 2 ** ((p + q) // 2)

    def check_odd_moments(self, row, plugin_name, atol, orders=(3, 5)):
        """Check that every moment of each order in ``orders`` vanishes.

        Parameters
        ----------
        row : catalog record
            Record holding the ``{plugin_name}_{p}{q}`` columns.
        plugin_name : `str`
            Prefix of the moment columns.
        atol : `float`
            Absolute tolerance.
        orders : iterable of `int`, optional
            Total orders ``p + q`` to check.
        """
        for n in orders:
            for p in range(n + 1):
                with self.subTest((p, n - p)):
                    self.assertFloatsAlmostEqual(row[f"{plugin_name}_{p}{n-p}"], 0.0, atol=atol)

    def check_even_moments(self, row, plugin_name, atol, orders=(4, 6)):
        """Check every moment of each order in ``orders`` against its
        reference value.

        Parameters
        ----------
        row : catalog record
            Record holding the ``{plugin_name}_{p}{q}`` columns.
        plugin_name : `str`
            Prefix of the moment columns.
        atol : `float`
            Absolute tolerance.
        orders : iterable of `int`, optional
            Total orders ``p + q`` to check. The default checks all 4th and
            6th order moments, including the cross terms expected to vanish.
        """
        for n in orders:
            for p in range(n + 1):
                q = n - p
                with self.subTest((p, q)):
                    self.assertFloatsAlmostEqual(
                        row[f"{plugin_name}_{p}{q}"],
                        self._expected_gaussian_moment(p, q),
                        atol=atol,
                    )

    def check(self, row, plugin_name, atol):
        """Check both the odd (3, 5) and even (4, 6) order moments."""
        self.check_odd_moments(row, plugin_name, atol)
        self.check_even_moments(row, plugin_name, atol)

    @lsst.utils.tests.methodParameters(
        plugin_name=(
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        )
    )
    def test_validate_config(self, plugin_name):
        """Test that the validation of the configs works as expected."""
        config = self.task.config.plugins[plugin_name]
        config.validate()  # This should not raise any error.

        # Test that the validation fails when the max_order is smaller than the
        # min_order.
        config.max_order = 3
        config.min_order = 4
        with self.assertRaises(FieldValidationError):
            config.validate()

    @lsst.utils.tests.methodParameters(
        plugin_name=(
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        )
    )
    def test_calculate_higher_order_moments(self, plugin_name):
        """Test that the _calculate_higher_order_moments results in the same
        outputs whether or not we take the linear algebra code path.
        """

        # We do not run any of the measurement plugins, but use a rough
        # centroid and an arbitrary 2x2 matrix to test that the code paths
        # result in consistent outputs.
        plugin = self.task.plugins[plugin_name]

        for i, row in enumerate(self.catalog):
            bbox = row.getFootprint().getBBox()
            center = bbox.getCenter()

            # Asymmetric matrix is not realistic, but we don't expect it to
            # break the consistency. It just needs to have determinant > 0.
            # This can be considered as a stress test for any small asymmetry
            # that may arise because of rounding errors in the off-diagonal
            # terms. A fresh matrix is built per row so that no row can see
            # another row's inputs.
            M = np.array([[2.0, 1.0], [0.5, 3.0]])
            image = self.exposure.image[bbox]

            hm1 = plugin._calculate_higher_order_moments(image, center, M, use_linear_algebra=False)
            hm2 = plugin._calculate_higher_order_moments(image, center, M, use_linear_algebra=True)
            for key in hm1:
                # Label each sub-test so a failure identifies the row and moment.
                with self.subTest(i=i, key=key):
                    self.assertFloatsAlmostEqual(hm1[key], hm2[key], atol=1e-14)

197 self.assertFloatsAlmostEqual(hm1[key], hm2[key], atol=1e-14) 

198 

199 

class HigherOrderMomentsTestCase(HigherMomentsBaseTestCase):
    """Test the higher order moments of the unmasked sources and PSFs."""

    @lsst.utils.tests.methodParameters(
        plugin_name=(
            "ext_shapeHSM_HigherOrderMomentsSource",
            "ext_shapeHSM_HigherOrderMomentsPSF",
        )
    )
    def test_hsm_source_moments(self, plugin_name):
        """Test the higher order moments of every source (or its PSF) against
        the reference values used by ``check``.
        """

        self.run_measurement()

        atol = 8e-6
        for row in self.catalog:
            self.check(row, plugin_name, atol=atol)

    @lsst.utils.tests.methodParameters(useSourceCentroidOffset=(False, True))
    def test_hsm_psf_lower_moments(self, useSourceCentroidOffset):
        """Test the PSF moments up to second order, with and without
        ``useSourceCentroidOffset``.
        """
        plugin_name = "ext_shapeHSM_HigherOrderMomentsPSF"
        self.task.config.plugins["ext_shapeHSM_HsmPsfMoments"].useSourceCentroidOffset = (
            useSourceCentroidOffset
        )
        self.task.config.plugins[plugin_name].useSourceCentroidOffset = useSourceCentroidOffset

        self.run_measurement()

        # Results are accurate for either value of useSourceCentroidOffset
        # when looking at lower order moments.
        atol = 2e-8

        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_00"], 1.0, atol=atol)

                self.assertFloatsAlmostEqual(row[f"{plugin_name}_01"], 0.0, atol=atol)
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_10"], 0.0, atol=atol)

                self.assertFloatsAlmostEqual(row[f"{plugin_name}_20"], 0.5, atol=atol)
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_11"], 0.0, atol=atol)
                self.assertFloatsAlmostEqual(row[f"{plugin_name}_02"], 0.5, atol=atol)

    @lsst.utils.tests.methodParameters(useSourceCentroidOffset=(False, True))
    def test_hsm_psf_higher_moments(self, useSourceCentroidOffset):
        """Test the PSF moments of orders 3 to 6, with and without
        ``useSourceCentroidOffset``.
        """
        plugin_name = "ext_shapeHSM_HigherOrderMomentsPSF"
        self.task.config.plugins["ext_shapeHSM_HsmPsfMoments"].useSourceCentroidOffset = (
            useSourceCentroidOffset
        )
        self.task.config.plugins[plugin_name].useSourceCentroidOffset = useSourceCentroidOffset

        self.run_measurement()

        # useSourceCentroidOffset = False results in more accurate results.
        # Adjust the absolute tolerance accordingly.
        atol = 7e-3 if useSourceCentroidOffset else 8e-6

        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.check(row, plugin_name, atol=atol)

    @lsst.utils.tests.methodParameters(
        target_plugin_name=(
            "base_SdssShape",
            "ext_shapeHSM_HsmSourceMomentsRound",
            "truth",
        )
    )
    def test_source_consistent_weight(self, target_plugin_name):
        """Test that we get the expected results when using a different set of
        consistent weights to measure the higher order moments of sources.

        The centroid and second moments written by
        ``ext_shapeHSM_HsmSourceMoments`` are overwritten with those of
        ``target_plugin_name`` before the higher order moments are measured.
        """
        plugin_name = "ext_shapeHSM_HigherOrderMomentsSource"

        # Pause the execution of the measurement task before the higher order
        # moments plugins.
        pause_order = self.task.plugins[plugin_name].getExecutionOrder()
        self.run_measurement(endOrder=pause_order)

        for suffix in ("x", "y", "xx", "yy", "xy"):
            self.catalog[f"ext_shapeHSM_HsmSourceMoments_{suffix}"] = self.catalog[
                f"{target_plugin_name}_{suffix}"
            ]

        # Resume the execution of the measurement task.
        self.run_measurement(beginOrder=pause_order)

        # With ext_shapeHSM_HsmSourceMomentsRound only the first (round)
        # source is checked (see the break below), and a much tighter
        # tolerance holds. The other targets are checked on all sources,
        # including the elliptical one, and need a looser tolerance.
        atol = 1.2e-7 if target_plugin_name == "ext_shapeHSM_HsmSourceMomentsRound" else 6e-4

        for i, row in enumerate(self.catalog):
            with self.subTest((plugin_name, i)):
                self.check(row, plugin_name, atol=atol)
            # The round moments are only accurate for the round sources,
            # which is the first one in the catalog.
            if target_plugin_name == "ext_shapeHSM_HsmSourceMomentsRound":
                break

    @lsst.utils.tests.methodParametersProduct(
        target_plugin_name=(
            "base_SdssShape_psf",
            "truth",
        ),
        useSourceCentroidOffset=(False, True),
    )
    def test_psf_consistent_weight(self, target_plugin_name, useSourceCentroidOffset):
        """Test that we get the expected results when using a different set of
        consistent weights to measure the higher order moments of PSFs.

        The PSF second moments written by ``ext_shapeHSM_HsmPsfMoments`` are
        overwritten with those of ``target_plugin_name`` (or with hardcoded
        truth values) before the higher order moments are measured.
        """
        plugin_name = "ext_shapeHSM_HigherOrderMomentsPSF"
        self.task.config.plugins[plugin_name].useSourceCentroidOffset = useSourceCentroidOffset

        # Pause the execution of the measurement task before the higher order
        # moments plugins.
        pause_order = self.task.plugins[plugin_name].getExecutionOrder()
        self.run_measurement(endOrder=pause_order)

        # Create a dictionary of PSF moments corresponding to the truth.
        # These are hardcoded in dataset.realize.
        truth_psf = {"xx": 4.0, "yy": 4.0, "xy": 0.0}

        for suffix in ("xx", "yy", "xy"):
            if target_plugin_name == "truth":
                self.catalog[f"ext_shapeHSM_HsmPsfMoments_{suffix}"] = truth_psf[suffix]
            else:
                self.catalog[f"ext_shapeHSM_HsmPsfMoments_{suffix}"] = self.catalog[
                    f"{target_plugin_name}_{suffix}"
                ]

        # Resume the execution of the measurement task.
        self.run_measurement(beginOrder=pause_order)

        # useSourceCentroidOffset = False results in more accurate results.
        # Adjust the absolute tolerance accordingly.
        atol = 1.2e-2 if useSourceCentroidOffset else 8e-6

        for i, row in enumerate(self.catalog):
            with self.subTest((plugin_name, i)):
                self.check(row, plugin_name, atol=atol)

356 

357 

class HigherMomentTestCaseWithMask(HigherMomentsBaseTestCase):
    """A test case to measure higher order moments in the presence of masks.

    The tests serve to check the validity of the algorithm on non-Gaussian
    profiles.
    """

    def add_mask_bits(self):
        # Docstring inherited.
        # Mask pixels are addressed by integer pixel positions, so both planes
        # use Point2I.
        badBit = self.exposure.mask.getPlaneBitMask("BAD")
        for position in (
            lsst.geom.Point2I(48, 47),
            lsst.geom.Point2I(76, 79),
        ):
            self.exposure.mask[position] |= badBit

        satBit = self.exposure.mask.getPlaneBitMask("SAT")
        for position in (
            lsst.geom.Point2I(49, 49),
            lsst.geom.Point2I(76, 79),
        ):
            self.exposure.mask[position] |= satBit

    def _check_lower_order_moments(self, row, plugin_name, atol):
        """Check the moments up to second order against their reference
        values: 1 for M_00, 0.5 for M_20 and M_02, and 0 for M_01, M_10 and
        M_11.
        """
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_00"], 1.0, atol=atol)

        self.assertFloatsAlmostEqual(row[f"{plugin_name}_01"], 0.0, atol=atol)
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_10"], 0.0, atol=atol)

        self.assertFloatsAlmostEqual(row[f"{plugin_name}_20"], 0.5, atol=atol)
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_11"], 0.0, atol=atol)
        self.assertFloatsAlmostEqual(row[f"{plugin_name}_02"], 0.5, atol=atol)

    def test_lower_order_moments(self, plugin_name="ext_shapeHSM_HigherOrderMomentsSource"):
        """Test that the lower order moments (2nd order or lower) are
        consistent even in the presence of masks.
        """
        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].setMaskedPixelsToZero = True

        self.run_measurement()

        atol = 2e-8
        for row in self.catalog:
            self._check_lower_order_moments(row, plugin_name, atol)

    def test_kurtosis(self):
        """Test the kurtosis measurement against GalSim HSM implementation."""
        # GalSim excludes the pixels flagged in ``badpix`` from the moments
        # calculation, which amounts to setting them to zero. So we set the
        # masked pixels to zero as well for the comparison.
        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].setMaskedPixelsToZero = True

        self.run_measurement()

        # Loop-invariant: the mask planes treated as bad pixels by GalSim.
        bitValue = self.exposure.mask.getPlaneBitMask(["BAD", "SAT"])

        delta_rho4s = []
        for i, row in enumerate(self.catalog):
            bbox = row.getFootprint().getBBox()
            im = galsim.Image(self.exposure[bbox].image.array)
            # Copy so that the in-place AND below does not modify the exposure mask.
            badpix = self.exposure.mask[bbox].array.copy()
            badpix &= bitValue
            badpix = galsim.Image(badpix, copy=False)
            shape = galsim.hsm.FindAdaptiveMom(im, badpix=badpix, strict=False)
            # r^4 = (x^2+y^2)^2 = x^4 + y^4 + 2x^2y^2
            rho4 = sum(
                (
                    row["ext_shapeHSM_HigherOrderMomentsSource_40"],
                    row["ext_shapeHSM_HigherOrderMomentsSource_04"],
                    row["ext_shapeHSM_HigherOrderMomentsSource_22"] * 2,
                )
            )
            delta_rho4s.append(abs(rho4 - 2.0))
            with self.subTest(i=i):
                self.assertFloatsAlmostEqual(shape.moments_rho4, rho4, atol=4e-7)

        # Check that at least one rho4 moment is non-trivial and differs from
        # the fiducial value of 2, by an amount much larger than the precision.
        self.assertTrue((np.array(delta_rho4s) > 1e-2).any(), "Unit test is too weak.")

    def test_hsm_source_higher_moments(self, plugin_name="ext_shapeHSM_HigherOrderMomentsSource"):
        """Test the source moments with BAD and SAT pixels excluded (rather
        than set to zero), using a loose tolerance.
        """

        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].badMaskPlanes = ["BAD", "SAT"]
        self.task.config.plugins["ext_shapeHSM_HigherOrderMomentsSource"].setMaskedPixelsToZero = False

        self.run_measurement()

        atol = 3e-1
        for row in self.catalog:
            self._check_lower_order_moments(row, plugin_name, atol)
            self.check(row, plugin_name, atol=atol)

450 

451 

class HigherMomentTestCaseWithSymmetricMask(HigherMomentTestCaseWithMask):
    """A masked test case whose sources sit at integer or half-integer
    positions, with the masked pixels placed symmetrically about each source.

    The point source at (49.5, 49.5) has masked pixels at (48, 48) and
    (51, 51); the galaxy at (76, 79) has masked pixels at (73, 79) and
    (79, 79).
    """

    @staticmethod
    def create_dataset():
        # Create a simple, fake dataset with centroids at integer or
        # half-integer positions to have a definite symmetry.
        bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(100, 100))
        dataset = lsst.meas.base.tests.TestDataset(bbox)
        # Create a point source with Gaussian PSF
        dataset.addSource(100000.0, lsst.geom.Point2D(49.5, 49.5))

        # Create a galaxy with Gaussian PSF
        dataset.addSource(300000.0, lsst.geom.Point2D(76, 79), lsst.afw.geom.Quadrupole(2.0, 3.0, 0.5))
        return dataset

    def add_mask_bits(self):
        # Docstring inherited.
        # Mask pixels are addressed by integer pixel positions, so both planes
        # use Point2I.
        badBit = self.exposure.mask.getPlaneBitMask("BAD")
        for position in (
            lsst.geom.Point2I(48, 48),
            lsst.geom.Point2I(73, 79),
        ):
            self.exposure.mask[position] |= badBit

        satBit = self.exposure.mask.getPlaneBitMask("SAT")
        for position in (
            lsst.geom.Point2I(51, 51),
            lsst.geom.Point2I(79, 79),
        ):
            self.exposure.mask[position] |= satBit

    @lsst.utils.tests.methodParameters(plugin_name=("ext_shapeHSM_HigherOrderMomentsSource",))
    def test_odd_moments(self, plugin_name):
        """Test that the odd order moments vanish (to within 1e-16) with
        symmetric masks, while the even order moments stay close to their
        expected values.
        """

        self.run_measurement()

        for row in self.catalog:
            self.check_odd_moments(row, plugin_name, atol=1e-16)
            self.check_even_moments(row, plugin_name, atol=3e-1)

    @lsst.utils.tests.methodParameters(useSourceCentroidOffset=(False, True))
    def test_hsm_psf_higher_moments(self, useSourceCentroidOffset):
        """Test that the higher order PSF moments are closer to the expected
        values when the masks are symmetric.
        """
        plugin_name = "ext_shapeHSM_HigherOrderMomentsPSF"
        self.task.config.plugins["ext_shapeHSM_HsmPsfMoments"].useSourceCentroidOffset = (
            useSourceCentroidOffset
        )
        self.task.config.plugins[plugin_name].useSourceCentroidOffset = useSourceCentroidOffset

        self.run_measurement()

        # useSourceCentroidOffset = False results in more accurate results.
        # Adjust the absolute tolerance accordingly.
        atol = 4e-3 if useSourceCentroidOffset else 8e-6

        for i, row in enumerate(self.catalog):
            with self.subTest(i=i):
                self.check(row, plugin_name, atol=atol)

511 

512 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Memory test case from `lsst.utils.tests`; all behavior is inherited."""

515 

516 

def setup_module(module):
    """Module-level setup hook: initialize `lsst.utils.tests`.

    Parameters
    ----------
    module : module
        The test module passed in by the test runner; not used.
    """
    lsst.utils.tests.init()

519 

520 

# Allow running this file directly as a script; the stray coverage-report
# annotation that had been fused onto the ``if`` line is removed.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()