Coverage for tests/test_doubleShapeletPsfApprox.py: 16%

263 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-01 02:44 -0800

1# 

2# LSST Data Management System 

3# 

4# Copyright 2008-2016 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23import os 

24import unittest 

25import numpy 

26from io import StringIO 

27import warnings 

28 

29import lsst.utils.tests 

30import lsst.afw.detection 

31import lsst.afw.image 

32import lsst.geom 

33import lsst.afw.geom 

34import lsst.afw.geom.ellipses 

35import lsst.log 

36import lsst.utils.logging 

37import lsst.meas.modelfit 

38import lsst.meas.algorithms 

39 

40# Set trace to 0-5 to view debug messages. Level 5 enables all traces. 

41lsst.utils.logging.trace_set_at("lsst.meas.modelfit.optimizer.Optimizer", -1) 

42lsst.utils.logging.trace_set_at("lsst.meas.modelfit.optimizer.solveTrustRegion", -1) 

43 

44 

45class DoubleShapeletPsfApproxTestMixin: 

46 

47 Algorithm = lsst.meas.modelfit.DoubleShapeletPsfApproxAlgorithm 

48 

49 def initialize(self, psf, ctrl=None, atol=1E-4, **kwds): 

50 if not isinstance(psf, lsst.afw.detection.Psf): 

51 kernel = lsst.afw.math.FixedKernel(psf) 

52 psf = lsst.meas.algorithms.KernelPsf(kernel) 

53 self.psf = psf 

54 self.atol = atol 

55 if ctrl is None: 

56 ctrl = lsst.meas.modelfit.DoubleShapeletPsfApproxControl() 

57 self.ctrl = ctrl 

58 for name, value in kwds.items(): 

59 setattr(self.ctrl, name, value) 

60 self.exposure = lsst.afw.image.ExposureF(1, 1) 

61 scale = 5.0e-5 * lsst.geom.degrees 

62 wcs = lsst.afw.geom.makeSkyWcs(crpix=lsst.geom.Point2D(0.0, 0.0), 

63 crval=lsst.geom.SpherePoint(45, 45, lsst.geom.degrees), 

64 cdMatrix=lsst.afw.geom.makeCdMatrix(scale=scale)) 

65 self.exposure.setWcs(wcs) 

66 self.exposure.setPsf(self.psf) 

67 

68 def tearDown(self): 

69 del self.exposure 

70 del self.psf 

71 del self.ctrl 

72 del self.atol 

73 

74 def setupTaskConfig(self, config): 

75 config.slots.shape = None 

76 config.slots.psfFlux = None 

77 config.slots.apFlux = None 

78 config.slots.gaussianFlux = None 

79 config.slots.modelFlux = None 

80 config.slots.calibFlux = None 

81 config.doReplaceWithNoise = False 

82 config.plugins.names = ["modelfit_DoubleShapeletPsfApprox"] 

83 config.plugins["modelfit_DoubleShapeletPsfApprox"].readControl(self.ctrl) 

84 

85 def checkBounds(self, msf): 

86 """Check that the bounds specified in the control object are met by a MultiShapeletFunction. 

87 

88 These requirements must be true after a call to any fit method or measure(). 

89 """ 

90 pos = self.psf.getAveragePosition() 

91 self.assertEqual(len(msf.getComponents()), 2) 

92 self.assertEqual( 

93 lsst.shapelet.computeSize(self.ctrl.innerOrder), 

94 len(msf.getComponents()[0].getCoefficients()) 

95 ) 

96 self.assertEqual( 

97 lsst.shapelet.computeSize(self.ctrl.outerOrder), 

98 len(msf.getComponents()[1].getCoefficients()) 

99 ) 

100 self.assertGreater( 

101 self.ctrl.maxRadiusBoxFraction * (self.psf.computeKernelImage(pos).getBBox().getArea())**0.5, 

102 lsst.afw.geom.ellipses.Axes(msf.getComponents()[0].getEllipse().getCore()).getA() 

103 ) 

104 self.assertGreater( 

105 self.ctrl.maxRadiusBoxFraction * (self.psf.computeKernelImage(pos).getBBox().getArea())**0.5, 

106 lsst.afw.geom.ellipses.Axes(msf.getComponents()[1].getEllipse().getCore()).getA() 

107 ) 

108 self.assertLess( 

109 self.ctrl.minRadius, 

110 lsst.afw.geom.ellipses.Axes(msf.getComponents()[0].getEllipse().getCore()).getB() 

111 ) 

112 self.assertLess( 

113 self.ctrl.minRadius, 

114 lsst.afw.geom.ellipses.Axes(msf.getComponents()[1].getEllipse().getCore()).getB() 

115 ) 

116 self.assertLess( 

117 self.ctrl.minRadiusDiff, 

118 (msf.getComponents()[1].getEllipse().getCore().getDeterminantRadius() 

119 - msf.getComponents()[0].getEllipse().getCore().getDeterminantRadius()) 

120 ) 

121 

122 def checkRatios(self, msf): 

123 """Check that the ratios specified in the control object are met by a MultiShapeletFunction. 

124 

125 These requirements must be true after initializeResult and fitMoments, but will are relaxed 

126 in later stages of the fit. 

127 """ 

128 inner = msf.getComponents()[0] 

129 outer = msf.getComponents()[1] 

130 position = msf.getComponents()[0].getEllipse().getCenter() 

131 self.assertFloatsAlmostEqual(position.getX(), msf.getComponents()[1].getEllipse().getCenter().getX()) 

132 self.assertFloatsAlmostEqual(position.getY(), msf.getComponents()[1].getEllipse().getCenter().getY()) 

133 self.assertFloatsAlmostEqual(outer.evaluate()(position), 

134 inner.evaluate()(position)*self.ctrl.peakRatio) 

135 self.assertFloatsAlmostEqual( 

136 outer.getEllipse().getCore().getDeterminantRadius(), 

137 inner.getEllipse().getCore().getDeterminantRadius() * self.ctrl.radiusRatio 

138 ) 

139 

140 def makeImages(self, msf): 

141 """Return an Image of the data and an Image of the model for comparison. 

142 """ 

143 pos = self.exposure.getPsf().getAveragePosition() 

144 dataImage = self.exposure.getPsf().computeKernelImage(pos) 

145 modelImage = dataImage.Factory(dataImage.getBBox()) 

146 msf.evaluate().addToImage(modelImage) 

147 return dataImage, modelImage 

148 

149 def checkFitQuality(self, msf): 

150 """Check the quality of the fit by comparing to the PSF image. 

151 """ 

152 dataImage, modelImage = self.makeImages(msf) 

153 self.assertFloatsAlmostEqual(dataImage.getArray(), modelImage.getArray(), atol=self.atol, 

154 plotOnFailure=True) 

155 

156 def testSingleFramePlugin(self): 

157 """Run the algorithm as a single-frame plugin and check the quality of the fit. 

158 """ 

159 with warnings.catch_warnings(): 

160 warnings.filterwarnings("ignore", message="ignoreSlotPluginChecks", category=FutureWarning) 

161 config = lsst.meas.base.SingleFrameMeasurementTask.ConfigClass(ignoreSlotPluginChecks=True) 

162 self.setupTaskConfig(config) 

163 config.slots.centroid = "centroid" 

164 schema = lsst.afw.table.SourceTable.makeMinimalSchema() 

165 centroidKey = lsst.afw.table.Point2DKey.addFields(schema, "centroid", "centroid", "pixel") 

166 task = lsst.meas.base.SingleFrameMeasurementTask(config=config, schema=schema) 

167 measCat = lsst.afw.table.SourceCatalog(schema) 

168 measRecord = measCat.addNew() 

169 measRecord.set(centroidKey, lsst.geom.Point2D(0.0, 0.0)) 

170 task.run(measCat, self.exposure) 

171 self.assertFalse(measRecord.get("modelfit_DoubleShapeletPsfApprox_flag")) 

172 key = lsst.shapelet.MultiShapeletFunctionKey(schema["modelfit"]["DoubleShapeletPsfApprox"]) 

173 msf = measRecord.get(key) 

174 self.checkBounds(msf) 

175 self.checkFitQuality(msf) 

176 

177 def testForcedPlugin(self): 

178 """Run the algorithm as a forced plugin and check the quality of the fit. 

179 """ 

180 config = lsst.meas.base.ForcedMeasurementTask.ConfigClass() 

181 config.copyColumns = {"id": "objectId", "parent": "parentObjectId"} 

182 self.setupTaskConfig(config) 

183 config.slots.centroid = "base_TransformedCentroid" 

184 config.plugins.names |= ["base_TransformedCentroid"] 

185 refSchema = lsst.afw.table.SourceTable.makeMinimalSchema() 

186 refCentroidKey = lsst.afw.table.Point2DKey.addFields(refSchema, "centroid", "centroid", "pixel") 

187 refSchema.getAliasMap().set("slot_Centroid", "centroid") 

188 refCat = lsst.afw.table.SourceCatalog(refSchema) 

189 refRecord = refCat.addNew() 

190 refRecord.set(refCentroidKey, lsst.geom.Point2D(0.0, 0.0)) 

191 refWcs = self.exposure.getWcs() # same as measurement Wcs 

192 task = lsst.meas.base.ForcedMeasurementTask(config=config, refSchema=refSchema) 

193 measCat = task.generateMeasCat(self.exposure, refCat, refWcs) 

194 task.run(measCat, self.exposure, refCat, refWcs) 

195 measRecord = measCat[0] 

196 self.assertFalse(measRecord.get("modelfit_DoubleShapeletPsfApprox_flag")) 

197 measSchema = measCat.schema 

198 key = lsst.shapelet.MultiShapeletFunctionKey(measSchema["modelfit"]["DoubleShapeletPsfApprox"]) 

199 msf = measRecord.get(key) 

200 self.checkBounds(msf) 

201 self.checkFitQuality(msf) 

202 

203 def testInitializeResult(self): 

204 """Test that initializeResult() returns a unit-flux, unit-circle MultiShapeletFunction 

205 with the right peakRatio and radiusRatio. 

206 """ 

207 msf = self.Algorithm.initializeResult(self.ctrl) 

208 self.assertFloatsAlmostEqual(msf.evaluate().integrate(), 1.0) 

209 moments = msf.evaluate().computeMoments() 

210 axes = lsst.afw.geom.ellipses.Axes(moments.getCore()) 

211 self.assertFloatsAlmostEqual(moments.getCenter().getX(), 0.0) 

212 self.assertFloatsAlmostEqual(moments.getCenter().getY(), 0.0) 

213 self.assertFloatsAlmostEqual(axes.getA(), 1.0) 

214 self.assertFloatsAlmostEqual(axes.getB(), 1.0) 

215 self.assertEqual(len(msf.getComponents()), 2) 

216 self.checkRatios(msf) 

217 

218 def testFitMoments(self): 

219 """Test that fitMoments() preserves peakRatio and radiusRatio while setting moments 

220 correctly. 

221 """ 

222 MOMENTS_RTOL = 1E-13 

223 image = self.psf.computeKernelImage(self.psf.getAveragePosition()) 

224 array = image.getArray() 

225 bbox = image.getBBox() 

226 x, y = numpy.meshgrid( 

227 numpy.arange(bbox.getBeginX(), bbox.getEndX()), 

228 numpy.arange(bbox.getBeginY(), bbox.getEndY()) 

229 ) 

230 msf = self.Algorithm.initializeResult(self.ctrl) 

231 self.Algorithm.fitMoments(msf, self.ctrl, image) 

232 self.assertFloatsAlmostEqual(msf.evaluate().integrate(), array.sum(), rtol=MOMENTS_RTOL) 

233 moments = msf.evaluate().computeMoments() 

234 q = lsst.afw.geom.ellipses.Quadrupole(moments.getCore()) 

235 cx = (x*array).sum()/array.sum() 

236 cy = (y*array).sum()/array.sum() 

237 self.assertFloatsAlmostEqual(moments.getCenter().getX(), cx, rtol=MOMENTS_RTOL) 

238 self.assertFloatsAlmostEqual(moments.getCenter().getY(), cy, rtol=MOMENTS_RTOL) 

239 self.assertFloatsAlmostEqual(q.getIxx(), ((x - cx)**2 * array).sum()/array.sum(), rtol=MOMENTS_RTOL) 

240 self.assertFloatsAlmostEqual(q.getIyy(), ((y - cy)**2 * array).sum()/array.sum(), rtol=MOMENTS_RTOL) 

241 self.assertFloatsAlmostEqual(q.getIxy(), ((x - cx)*(y - cy)*array).sum()/array.sum(), 

242 rtol=MOMENTS_RTOL) 

243 self.assertEqual(len(msf.getComponents()), 2) 

244 self.checkRatios(msf) 

245 self.checkBounds(msf) 

246 

247 def testObjective(self): 

248 """Test that model evaluation agrees with derivative evaluation in the objective object. 

249 """ 

250 image = self.psf.computeKernelImage(self.psf.getAveragePosition()) 

251 msf = self.Algorithm.initializeResult(self.ctrl) 

252 self.Algorithm.fitMoments(msf, self.ctrl, image) 

253 moments = msf.evaluate().computeMoments() 

254 r0 = moments.getCore().getDeterminantRadius() 

255 objective = self.Algorithm.makeObjective(moments, self.ctrl, image) 

256 image, model = self.makeImages(msf) 

257 parameters = numpy.zeros(4, dtype=float) 

258 parameters[0] = msf.getComponents()[0].getCoefficients()[0] 

259 parameters[1] = msf.getComponents()[1].getCoefficients()[0] 

260 parameters[2] = msf.getComponents()[0].getEllipse().getCore().getDeterminantRadius() / r0 

261 parameters[3] = msf.getComponents()[1].getEllipse().getCore().getDeterminantRadius() / r0 

262 residuals = numpy.zeros(image.getArray().size, dtype=float) 

263 objective.computeResiduals(parameters, residuals) 

264 self.assertFloatsAlmostEqual( 

265 residuals.reshape(image.getHeight(), image.getWidth()), 

266 image.getArray() - model.getArray() 

267 ) 

268 step = 1E-6 

269 derivatives = numpy.zeros((parameters.size, residuals.size), dtype=float).transpose() 

270 objective.differentiateResiduals(parameters, derivatives) 

271 for i in range(parameters.size): 

272 original = parameters[i] 

273 r1 = numpy.zeros(residuals.size, dtype=float) 

274 r2 = numpy.zeros(residuals.size, dtype=float) 

275 parameters[i] = original + step 

276 objective.computeResiduals(parameters, r1) 

277 parameters[i] = original - step 

278 objective.computeResiduals(parameters, r2) 

279 parameters[i] = original 

280 d = (r1 - r2)/(2.0*step) 

281 self.assertFloatsAlmostEqual( 

282 d.reshape(image.getHeight(), image.getWidth()), 

283 derivatives[:, i].reshape(image.getHeight(), image.getWidth()), 

284 atol=1E-11 

285 ) 

286 

287 def testFitProfile(self): 

288 """Test that fitProfile() does not modify the ellipticity, that it improves the fit, and 

289 that small perturbations to the zeroth-order amplitudes and radii do not improve the fit. 

290 """ 

291 image = self.psf.computeKernelImage(self.psf.getAveragePosition()) 

292 msf = self.Algorithm.initializeResult(self.ctrl) 

293 self.Algorithm.fitMoments(msf, self.ctrl, image) 

294 prev = lsst.shapelet.MultiShapeletFunction(msf) 

295 self.Algorithm.fitProfile(msf, self.ctrl, image) 

296 

297 def getEllipticity(m, c): 

298 s = lsst.afw.geom.ellipses.SeparableDistortionDeterminantRadius( 

299 m.getComponents()[c].getEllipse().getCore() 

300 ) 

301 return numpy.array([s.getE1(), s.getE2()]) 

302 self.assertFloatsAlmostEqual(getEllipticity(prev, 0), getEllipticity(msf, 0), rtol=1E-13) 

303 self.assertFloatsAlmostEqual(getEllipticity(prev, 1), getEllipticity(msf, 1), rtol=1E-13) 

304 

305 def computeChiSq(m): 

306 data, model = self.makeImages(m) 

307 return numpy.sum((data.getArray() - model.getArray())**2) 

308 bestChiSq = computeChiSq(msf) 

309 self.assertLessEqual(bestChiSq, computeChiSq(prev)) 

310 step = 1E-4 

311 for component in msf.getComponents(): 

312 # 0th-order amplitude perturbation 

313 original = component.getCoefficients()[0] 

314 component.getCoefficients()[0] = original + step 

315 self.assertLessEqual(bestChiSq, computeChiSq(msf)) 

316 component.getCoefficients()[0] = original - step 

317 self.assertLessEqual(bestChiSq, computeChiSq(msf)) 

318 component.getCoefficients()[0] = original 

319 # Radius perturbation 

320 original = component.getEllipse() 

321 component.getEllipse().getCore().scale(1.0 + step) 

322 self.assertLessEqual(bestChiSq, computeChiSq(msf)) 

323 component.setEllipse(original) 

324 component.getEllipse().getCore().scale(1.0 - step) 

325 self.assertLessEqual(bestChiSq, computeChiSq(msf)) 

326 component.setEllipse(original) 

327 

328 def testFitShapelets(self): 

329 """Test that fitShapelets() does not modify the zeroth order coefficients or ellipse, 

330 that it improves the fit, and that small perturbations to the higher-order coefficients 

331 do not improve the fit. 

332 """ 

333 image = self.psf.computeKernelImage(self.psf.getAveragePosition()) 

334 msf = self.Algorithm.initializeResult(self.ctrl) 

335 self.Algorithm.fitMoments(msf, self.ctrl, image) 

336 self.Algorithm.fitProfile(msf, self.ctrl, image) 

337 prev = lsst.shapelet.MultiShapeletFunction(msf) 

338 self.Algorithm.fitShapelets(msf, self.ctrl, image) 

339 self.assertFloatsAlmostEqual( 

340 prev.getComponents()[0].getEllipse().getParameterVector(), 

341 msf.getComponents()[0].getEllipse().getParameterVector() 

342 ) 

343 self.assertFloatsAlmostEqual( 

344 prev.getComponents()[1].getEllipse().getParameterVector(), 

345 msf.getComponents()[1].getEllipse().getParameterVector() 

346 ) 

347 

348 def computeChiSq(m): 

349 data, model = self.makeImages(m) 

350 return numpy.sum((data.getArray() - model.getArray())**2) 

351 bestChiSq = computeChiSq(msf) 

352 self.assertLessEqual(bestChiSq, computeChiSq(prev)) 

353 step = 1E-4 

354 for component in msf.getComponents(): 

355 for i in range(1, len(component.getCoefficients())): 

356 original = component.getCoefficients()[i] 

357 component.getCoefficients()[i] = original + step 

358 self.assertLessEqual(bestChiSq, computeChiSq(msf)) 

359 component.getCoefficients()[i] = original - step 

360 self.assertLessEqual(bestChiSq, computeChiSq(msf)) 

361 component.getCoefficients()[i] = original 

362 

363 def testSingleFrameConfigIO(self): 

364 config1 = lsst.meas.base.SingleFrameMeasurementTask.ConfigClass() 

365 config2 = lsst.meas.base.SingleFrameMeasurementTask.ConfigClass() 

366 self.setupTaskConfig(config1) 

367 stream = StringIO() 

368 config1.saveToStream(stream) 

369 config2.loadFromStream(stream.getvalue()) 

370 self.assertEqual(config1, config2) 

371 

372 

373class SingleGaussianTestCase(DoubleShapeletPsfApproxTestMixin, lsst.utils.tests.TestCase): 

374 

375 def setUp(self): 

376 numpy.random.seed(500) 

377 DoubleShapeletPsfApproxTestMixin.initialize( 

378 self, psf=lsst.afw.detection.GaussianPsf(25, 25, 2.0), 

379 innerOrder=0, outerOrder=0, peakRatio=0.0 

380 ) 

381 

382 

383class HigherOrderTestCase0(DoubleShapeletPsfApproxTestMixin, lsst.utils.tests.TestCase): 

384 

385 def setUp(self): 

386 numpy.random.seed(500) 

387 image = lsst.afw.image.ImageD(os.path.join(os.path.dirname(os.path.realpath(__file__)), 

388 "data", "psfs/great3-0.fits")) 

389 DoubleShapeletPsfApproxTestMixin.initialize( 

390 self, psf=image, 

391 innerOrder=3, outerOrder=2, 

392 atol=0.0005 

393 ) 

394 

395 

396class HigherOrderTestCase1(DoubleShapeletPsfApproxTestMixin, lsst.utils.tests.TestCase): 

397 

398 def setUp(self): 

399 numpy.random.seed(500) 

400 image = lsst.afw.image.ImageD(os.path.join(os.path.dirname(os.path.realpath(__file__)), 

401 "data", "psfs/great3-1.fits")) 

402 DoubleShapeletPsfApproxTestMixin.initialize( 

403 self, psf=image, 

404 innerOrder=2, outerOrder=1, 

405 atol=0.002 

406 ) 

407 

408 

409class TestMemory(lsst.utils.tests.MemoryTestCase): 

410 pass 

411 

412 

413def setup_module(module): 

414 lsst.utils.tests.init() 

415 

416 

417if __name__ == "__main__": 417 ↛ 418line 417 didn't jump to line 418, because the condition on line 417 was never true

418 lsst.utils.tests.init() 

419 unittest.main()