Coverage for tests/test_psfSelectTest.py: 9%

267 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 03:52 -0700

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import math 

23import unittest 

24import time 

25 

26import numpy as np 

27 

28import lsst.daf.base as dafBase 

29import lsst.geom 

30import lsst.afw.image as afwImage 

31import lsst.afw.geom as afwGeom 

32import lsst.afw.table as afwTable 

33import lsst.meas.algorithms as measAlg 

34import lsst.meas.base as measBase 

35 

36import lsst.afw.cameraGeom as cameraGeom 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38 

39import lsst.utils.tests 

40 

# Image display is off unless a global ``display`` already exists when this
# module is executed (presumably injected by whoever runs the file for
# interactive debugging -- not shown here). When it exists, keep its value and
# configure afwDisplay; otherwise define ``display = False``.
try:
    display
except NameError:
    display = False
else:
    import lsst.afw.display as afwDisplay
    afwDisplay.setDefaultMaskTransparency(75)

48 

49 

def plantSources(x0, y0, nx, ny, sky, nObj, wid, detector, useRandom=False):
    """Plant Gaussian sources into a distorted and an undistorted image.

    Two images are rendered together:

    - In the "distorted" image, each source's second moments are transformed by
      the local linearization of the detector's TAN_PIXELS -> PIXELS mapping.
    - In the "undistorted" image, every source is a round Gaussian of sigma
      ``wid``.

    Both images are planted at the same nominal positions. Each then gets
    ``sky`` added, Poisson noise drawn, an EDGE mask band, and finally ``sky``
    subtracted again. The variance plane holds the noisy counts before sky
    subtraction.

    Parameters
    ----------
    x0, y0 : `int`
        Unused; kept so existing callers keep working.
    nx, ny : `int`
        Image dimensions in pixels.
    sky : `float`
        Sky level added before drawing Poisson noise and subtracted afterwards.
    nObj : `int`
        Number of sources to attempt. On the regular grid, sources are laid out
        ``int(sqrt(nObj))`` per row. With ``useRandom`` the count is reduced to
        ``int(sqrt(nObj))**2``. Sources whose centre falls within
        ``40*wid`` pixels of an edge are skipped.
    wid : `float`
        Gaussian sigma of the undistorted sources, in pixels.
    detector : `lsst.afw.cameraGeom.Detector`
        Detector supplying the PIXELS -> TAN_PIXELS transform.
    useRandom : `bool`, optional
        Place sources uniformly at random instead of on a regular grid.

    Returns
    -------
    expos : `lsst.afw.image.Exposure`
        Exposure with the distorted sources.
    goodAdded : `list` [`list` [`float`]]
        ``[x, y]`` of distorted sources whose whole stamp fit in the image.
    expos0 : `lsst.afw.image.Exposure`
        Exposure with the undistorted sources.
    goodAdded0 : `list` [`list` [`float`]]
        ``[x, y]`` of undistorted sources whose whole stamp fit in the image.
    """
    pixToTanPix = detector.getTransform(cameraGeom.PIXELS, cameraGeom.TAN_PIXELS)

    img0 = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))  # round sources
    img = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))  # distorted-shape sources

    ixx0, iyy0, ixy0 = wid*wid, wid*wid, 0.0

    # Sources centred closer than this to an edge are skipped; a band half this
    # wide is flagged EDGE in the mask below.
    edgeBuffer = 40.0*wid

    flux = 1.0e4
    # Each source is rendered into an odd-sized stamp about 10 sigma across.
    nkx, nky = int(10*wid) + 1, int(10*wid) + 1
    xhwid, yhwid = nkx//2, nky//2

    # Grid spacing (kept as floats, so positions are floats as well).
    nRow = int(math.sqrt(nObj))
    xstep = (nx - 1.0)//(nRow + 1)
    ystep = (ny - 1.0)//(nRow + 1)

    if useRandom:
        nObj = nRow*nRow

    goodAdded0 = []
    goodAdded = []

    for i in range(nObj):
        # Nominal source position.
        if useRandom:
            # Pass both bounds explicitly: a single positional argument would be
            # taken as ``low`` with ``high=1.0``.
            xcen0, ycen0 = np.random.uniform(0.0, nx), np.random.uniform(0.0, ny)
        else:
            xcen0, ycen0 = xstep*((i % nRow) + 1), ystep*(i//nRow + 1)
        ixcen0, iycen0 = int(xcen0), int(ycen0)

        # Distort the shape (not the position): transform the round moments by
        # the local linearization of the inverse of pixToTanPix.
        pTan = lsst.geom.Point2D(xcen0, ycen0)
        p = pixToTanPix.applyInverse(pTan)
        linTransform = afwGeom.linearizeTransform(pixToTanPix, p).inverted().getLinear()
        m = afwGeom.Quadrupole(ixx0, iyy0, ixy0)
        m.transform(linTransform)

        # The distorted copy is planted at the nominal position as well.
        xcen, ycen = xcen0, ycen0
        if (xcen < edgeBuffer or (nx - xcen) < edgeBuffer
                or ycen < edgeBuffer or (ny - ycen) < edgeBuffer):
            continue
        ixcen, iycen = int(xcen), int(ycen)
        ixx, iyy, ixy = m.getIxx(), m.getIyy(), m.getIxy()

        # Semi-axes and position angle of the distorted ellipse.
        tmp = 0.25*(ixx - iyy)**2 + ixy**2
        a = np.sqrt(0.5*(ixx + iyy) + np.sqrt(tmp))
        b = np.sqrt(0.5*(ixx + iyy) - np.sqrt(tmp))
        theta = 0.5*np.arctan2(2.0*ixy, ixx - iyy)
        c, s = math.cos(theta), math.sin(theta)

        # Peak amplitudes normalizing each Gaussian to ``flux``; fixed per source.
        amp = flux/(2*math.pi*a*b)
        amp0 = flux/(2*math.pi*wid*wid)

        good0, good = True, True
        for y in range(nky):
            iy = iycen + y - yhwid
            iy0 = iycen0 + y - yhwid
            for x in range(nkx):
                ix = ixcen + x - xhwid
                ix0 = ixcen0 + x - xhwid

                if 0 <= ix < nx and 0 <= iy < ny:
                    dx, dy = ix - xcen, iy - ycen
                    u = c*dx + s*dy
                    v = -s*dx + c*dy
                    img[ix, iy, afwImage.LOCAL] += amp*math.exp(-0.5*((u/a)**2 + (v/b)**2))
                else:
                    good = False

                if 0 <= ix0 < nx and 0 <= iy0 < ny:
                    # Offsets relative to this image's own pixel and centre, so the
                    # round image stays correct even if the two centres differ.
                    dx0, dy0 = ix0 - xcen0, iy0 - ycen0
                    img0[ix0, iy0, afwImage.LOCAL] += amp0*math.exp(-0.5*((dx0/wid)**2 + (dy0/wid)**2))
                else:
                    good0 = False

        if good0:
            goodAdded0.append([xcen, ycen])
        if good:
            goodAdded.append([xcen, ycen])

    # Add sky, then Poisson noise. All draws are made in one vectorized call.
    # The stacked array is ordered (x, y, image) in C order, so draws are taken
    # x outer, y inner, distorted image before undistorted -- the same order as a
    # per-pixel loop -- and a seeded RNG gives the same noise realization.
    img += sky
    img0 += sky
    lam = np.stack([img.getArray().T, img0.getArray().T], axis=-1)
    draws = np.random.poisson(lam)
    noise = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))
    noise0 = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))
    noise.getArray()[:, :] = draws[:, :, 0].T
    noise0.getArray()[:, :] = draws[:, :, 1].T

    # Flag a band of width edgeWidth along every image edge as EDGE.
    edgeWidth = int(0.5*edgeBuffer)
    mask = afwImage.Mask(lsst.geom.ExtentI(nx, ny))
    left = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.ExtentI(edgeWidth, ny))
    right = lsst.geom.Box2I(lsst.geom.Point2I(nx - edgeWidth, 0), lsst.geom.ExtentI(edgeWidth, ny))
    top = lsst.geom.Box2I(lsst.geom.Point2I(0, ny - edgeWidth), lsst.geom.ExtentI(nx, edgeWidth))
    bottom = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.ExtentI(nx, edgeWidth))

    for pos in [left, right, top, bottom]:
        msk = afwImage.Mask(mask, pos, deep=False)  # view into ``mask``
        msk.set(msk.getPlaneBitMask('EDGE'))

    # Variance = noisy counts (sky included). Each exposure gets its own mask,
    # so mask bits set on one (e.g. by detection) do not leak into the other.
    expos = afwImage.makeExposure(afwImage.makeMaskedImage(noise, mask, afwImage.ImageF(noise, True)))
    expos0 = afwImage.makeExposure(afwImage.makeMaskedImage(noise0, afwImage.Mask(mask, True),
                                                            afwImage.ImageF(noise0, True)))

    im = expos.getMaskedImage().getImage()
    im0 = expos0.getMaskedImage().getImage()
    im -= sky
    im0 -= sky

    return expos, goodAdded, expos0, goodAdded0

176 

177 

class PsfSelectionTestCase(lsst.utils.tests.TestCase):
    """Test PSF star selection and PCA PSF determination on simulated images
    planted through a detector with optical distortion.
    """

    def setUp(self):
        np.random.seed(500)  # make test repeatable
        self.x0, self.y0 = 0, 0
        self.nx, self.ny = 512, 512  # 2048, 4096
        self.sky = 100.0
        self.nObj = 100

        bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(self.nx, self.ny))

        def makeDetector(radialDistortion):
            # Both detectors share the bbox and optical centre; only distortion differs.
            return DetectorWrapper(
                bbox=bbox,
                orientation=cameraGeom.Orientation(lsst.geom.Point2D(255.0, 255.0)),
                radialDistortion=radialDistortion,
            ).detector

        # a detector with distortion, and one without
        self.detector = makeDetector(0.925)
        self.flatDetector = makeDetector(0.0)

        # detection policies
        detConfig = measAlg.SourceDetectionConfig()
        # Cannot use default background approximation order (6) for such a small image.
        detConfig.background.approxOrderX = 4
        # This test depends on footprints grown with the old Manhattan metric.
        detConfig.isotropicGrow = False

        # measurement policies
        measConfig = measBase.SingleFrameMeasurementConfig()
        measConfig.algorithms.names = [
            "base_SdssCentroid",
            "base_SdssShape",
            "base_GaussianFlux",
            "base_PsfFlux",
        ]
        measConfig.slots.centroid = "base_SdssCentroid"
        measConfig.slots.shape = "base_SdssShape"
        measConfig.slots.psfFlux = "base_PsfFlux"
        measConfig.plugins["base_SdssCentroid"].doFootprintCheck = False
        measConfig.slots.apFlux = None
        measConfig.slots.modelFlux = None
        measConfig.slots.gaussianFlux = None
        measConfig.slots.calibFlux = None

        self.schema = afwTable.SourceTable.makeMinimalSchema()
        self.detTask = measAlg.SourceDetectionTask(config=detConfig, schema=self.schema)
        self.measTask = measBase.SingleFrameMeasurementTask(config=measConfig, schema=self.schema)

        # psf star selector
        starSelectorClass = measAlg.sourceSelectorRegistry["objectSize"]
        starSelectorConfig = starSelectorClass.ConfigClass()
        starSelectorConfig.fluxMin = 5000.0
        starSelectorConfig.badFlags = []
        self.starSelector = starSelectorClass(config=starSelectorConfig)

        self.makePsfCandidates = measAlg.MakePsfCandidatesTask()

        # psf determiner: 3x3 grid of cells, linear spatial variation
        psfDeterminerFactory = measAlg.psfDeterminerRegistry["pca"]
        psfDeterminerConfig = psfDeterminerFactory.ConfigClass()
        width, height = self.nx, self.ny
        nEigenComponents = 3
        psfDeterminerConfig.sizeCellX = width//3
        psfDeterminerConfig.sizeCellY = height//3
        psfDeterminerConfig.nEigenComponents = nEigenComponents
        psfDeterminerConfig.spatialOrder = 1
        psfDeterminerConfig.nStarPerCell = 0
        psfDeterminerConfig.nStarPerCellSpatialFit = 0  # unlimited
        self.psfDeterminer = psfDeterminerFactory(psfDeterminerConfig)

    def tearDown(self):
        del self.detTask
        del self.measTask
        del self.schema
        del self.detector
        del self.flatDetector
        del self.starSelector
        del self.makePsfCandidates
        del self.psfDeterminer

    def detectAndMeasure(self, exposure):
        """Detect sources on ``exposure`` and run single-frame measurement on them.

        Quick and dirty: no background is subtracted here, because
        ``plantSources`` already removed the sky.
        """
        table = afwTable.SourceTable.make(self.schema)
        # detect
        sources = self.detTask.run(table, exposure).sources
        # ... and measure
        self.measTask.run(sources, exposure)
        return sources

    def testPsfCandidate(self):

        detector = self.detector

        # make an exposure (only the distorted image is used here)
        print("Planting")
        psfSigma = 1.5
        exposDist, _, _, _ = plantSources(self.x0, self.y0,
                                          self.nx, self.ny,
                                          self.sky, self.nObj, psfSigma, detector)

        # set the psf
        kwid = 21
        psf = measAlg.SingleGaussianPsf(kwid, kwid, psfSigma)
        exposDist.setPsf(psf)
        exposDist.setDetector(detector)

        # detect
        print("detection")
        sourceList = self.detectAndMeasure(exposDist)

        # select psf stars
        print("PSF selection")
        stars = self.starSelector.run(sourceList, exposure=exposDist)
        psfCandidateList = self.makePsfCandidates.run(stars.sourceCat, exposDist).psfCandidates

        # determine the PSF
        print("PSF determination")
        metadata = dafBase.PropertyList()
        t0 = time.time()
        psf, cellSet = self.psfDeterminer.determinePsf(exposDist, psfCandidateList, metadata)
        print("... determination time: ", time.time() - t0)
        print("PSF kernel width: ", psf.getKernel().getWidth())

        #######################################################################
        # try to subtract off the stars and check the residuals

        imgOrig = exposDist.getMaskedImage().getImage().getArray()
        maxFlux = imgOrig.max()

        ############
        # first try it with no distortion in the psf
        exposDist.setDetector(self.flatDetector)

        print("uncorrected subtraction")
        subImg = afwImage.MaskedImageF(exposDist.getMaskedImage(), True)
        for s in sourceList:
            x, y = s.getX(), s.getY()
            measAlg.subtractPsf(psf, subImg, x, y)

        if display:
            afwDisplay.Display(frame=1).mtv(exposDist, title=self._testMethodName + ": full")
            afwDisplay.Display(frame=0).mtv(subImg, title=self._testMethodName + ": subtracted")

        img = subImg.getImage().getArray()
        norm = img/math.sqrt(maxFlux)

        smin0, smax0, srms0 = norm.min(), norm.max(), norm.std()

        print("min:", smin0, "max: ", smax0, "rms: ", srms0)

        if False:
            # This section has been disabled as distortion was removed from PsfCandidate and Psf;
            # it will be reintroduced in the future with a different API, at which point this
            # test code should be re-enabled.

            ##############
            # try it with the correct distortion in the psf
            exposDist.setDetector(self.detector)

            print("corrected subtraction")
            subImg = afwImage.MaskedImageF(exposDist.getMaskedImage(), True)
            for s in sourceList:
                x, y = s.getX(), s.getY()
                measAlg.subtractPsf(psf, subImg, x, y)

            if display:
                afwDisplay.Display(frame=2).mtv(exposDist, title=self._testMethodName + ": full")
                afwDisplay.Display(frame=3).mtv(subImg, title=self._testMethodName + ": subtracted")

            img = subImg.getImage().getArray()
            norm = img/math.sqrt(maxFlux)

            smin, smax, srms = norm.min(), norm.max(), norm.std()

            # with proper distortion, residuals should be < 4sigma (even for 512x512 pixels)
            print("min:", smin, "max: ", smax, "rms: ", srms)

            # the distrib of residuals should be tighter
            self.assertLess(smin0, smin)
            self.assertGreater(smax0, smax)
            self.assertGreater(srms0, srms)

    def testDistortedImage(self):

        detector = self.detector

        psfSigma = 1.5
        stars = plantSources(self.x0, self.y0, self.nx, self.ny, self.sky, self.nObj, psfSigma, detector)
        expos, starXy = stars[0], stars[1]

        # add some faint round galaxies ... only slightly bigger than the psf
        gxy = plantSources(self.x0, self.y0, self.nx, self.ny, self.sky, 10, 1.07*psfSigma, detector)
        mi = expos.getMaskedImage()
        mi += gxy[0].getMaskedImage()
        gxyXy = gxy[1]

        kwid = 15  # int(10*psfSigma) + 1
        psf = measAlg.SingleGaussianPsf(kwid, kwid, psfSigma)
        expos.setPsf(psf)

        ########################
        # try without distorter
        expos.setDetector(self.flatDetector)
        print("Testing PSF selection *without* distortion")
        sourceList = self.detectAndMeasure(expos)
        stars = self.starSelector.run(sourceList, exposure=expos)
        psfCandidateList = self.makePsfCandidates.run(stars.sourceCat, expos).psfCandidates

        ########################
        # try with distorter
        expos.setDetector(self.detector)
        print("Testing PSF selection *with* distortion")
        sourceList = self.detectAndMeasure(expos)
        stars = self.starSelector.run(sourceList, exposure=expos)
        psfCandidateListCorrected = self.makePsfCandidates.run(stars.sourceCat, expos).psfCandidates

        def countObjects(candList):
            """Count (candidate, planted object) matches within 2 pixels in x and y."""
            nStar, nGxy = 0, 0
            for c in candList:
                s = c.getSource()
                x, y = s.getX(), s.getY()
                for xs, ys in starXy:
                    if abs(x-xs) < 2.0 and abs(y-ys) < 2.0:
                        nStar += 1
                for xg, yg in gxyXy:
                    if abs(x-xg) < 2.0 and abs(y-yg) < 2.0:
                        nGxy += 1
            return nStar, nGxy

        nstar, ngxy = countObjects(psfCandidateList)
        nstarC, ngxyC = countObjects(psfCandidateListCorrected)

        print("uncorrected nStar, nGxy: ", nstar, "/", len(starXy), " ", ngxy, '/', len(gxyXy))
        print("dist-corrected nStar, nGxy: ", nstarC, '/', len(starXy), " ", ngxyC, '/', len(gxyXy))

        ########################
        # display
        if display:
            iDisp = 1
            disp = afwDisplay.Display(frame=iDisp)
            disp.mtv(expos, title=self._testMethodName + ": image")
            size = 40
            for c in psfCandidateList:
                s = c.getSource()
                ixx, iyy, ixy = size*s.getIxx(), size*s.getIyy(), size*s.getIxy()
                disp.dot("@:%g,%g,%g" % (ixx, ixy, iyy), s.getX(), s.getY(), ctype=afwDisplay.RED)
            size *= 2.0
            for c in psfCandidateListCorrected:
                s = c.getSource()
                ixx, iyy, ixy = size*s.getIxx(), size*s.getIyy(), size*s.getIxy()
                disp.dot("@:%g,%g,%g" % (ixx, ixy, iyy), s.getX(), s.getY(), ctype=afwDisplay.GREEN)

        # we shouldn't expect to get all available stars without distortion correcting
        self.assertLess(nstar, len(starXy))

        # here we should get all of them, occasionally 1 or 2 might get missed
        self.assertGreaterEqual(nstarC, 0.95*len(starXy))

        # no contamination by small gxys
        self.assertEqual(ngxyC, 0)

444 

445 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the resource checks inherited from `lsst.utils.tests.MemoryTestCase`."""

448 

449 

def setup_module(module):
    """pytest module-level hook: initialise the LSST test utilities.

    ``module`` is passed in by pytest and is not used here.
    """
    testUtils = lsst.utils.tests
    testUtils.init()

452 

453 

if __name__ == "__main__":
    # Run directly (not through pytest): initialise the LSST test utilities,
    # as setup_module does under pytest, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()