Coverage for tests/test_psfSelectTest.py: 10%

269 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-06 01:42 -0700

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import math 

23import unittest 

24import time 

25 

26import numpy as np 

27 

28import lsst.daf.base as dafBase 

29import lsst.geom 

30import lsst.afw.image as afwImage 

31import lsst.afw.geom as afwGeom 

32import lsst.afw.table as afwTable 

33import lsst.meas.algorithms as measAlg 

34import lsst.meas.base as measBase 

35 

36import lsst.afw.cameraGeom as cameraGeom 

37from lsst.afw.cameraGeom.testUtils import DetectorWrapper 

38 

39import lsst.utils.tests 

40 

# Optional interactive display switch.  If a name ``display`` is already
# resolvable when this module is executed (presumably pre-defined by an
# interactive session before running the file), afwDisplay is imported and
# configured, whatever the value of ``display``; otherwise ``display`` is set
# to False and every ``if display:`` section in the tests is skipped.
try:
    display  # probe only: raises NameError when ``display`` is not defined
except NameError:
    display = False
else:
    import lsst.afw.display as afwDisplay
    afwDisplay.setDefaultMaskTransparency(75)

48 

49 

def plantSources(x0, y0, nx, ny, sky, nObj, wid, detector, useRandom=False):
    """Plant Gaussian sources into a pair of simulated exposures.

    Two images are built side by side:

    - the "undistorted" image receives round Gaussians of sigma ``wid``;
    - the "distorted" image receives Gaussians whose second moments were
      transformed by the local linearization of the inverse of the detector's
      PIXELS -> TAN_PIXELS transform.

    Source centres are not moved, only their shapes.  Sources closer than
    ``40*wid`` pixels to any image edge are skipped.  Both images then get
    ``sky`` added, are replaced by per-pixel Poisson draws, receive EDGE-masked
    borders, and finally have ``sky`` subtracted from the image plane again.

    Parameters
    ----------
    x0, y0 : `int`
        Unused; kept only for call-signature compatibility.
    nx, ny : `int`
        Image dimensions in pixels.
    sky : `float`
        Sky level added before the Poisson draws and subtracted afterwards
        from the image plane (the variance plane keeps it).
    nObj : `int`
        Requested number of objects.  Non-random positions are laid out on a
        grid ``int(sqrt(nObj))`` columns wide; with ``useRandom`` the count is
        rounded down to ``int(sqrt(nObj))**2``.
    wid : `float`
        Gaussian sigma, in pixels, of the planted sources.
    detector : `lsst.afw.cameraGeom.Detector`
        Detector whose PIXELS -> TAN_PIXELS transform defines the distortion.
    useRandom : `bool`, optional
        If True, draw positions uniformly at random instead of using a grid.

    Returns
    -------
    expos : `lsst.afw.image.Exposure`
        Exposure containing the distorted sources.
    goodAdded : `list` of ``[x, y]``
        Centres of distorted sources whose whole stamp fell inside the image.
    expos0 : `lsst.afw.image.Exposure`
        Exposure containing the undistorted sources.
    goodAdded0 : `list` of ``[x, y]``
        Centres of undistorted sources whose whole stamp fell inside the image.
    """
    pixToTanPix = detector.getTransform(cameraGeom.PIXELS, cameraGeom.TAN_PIXELS)

    img0 = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))
    img = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))

    ixx0, iyy0, ixy0 = wid*wid, wid*wid, 0.0

    edgeBuffer = 40.0*wid

    flux = 1.0e4
    # Stamp size: roughly +/- 5 sigma around the centre.
    nkx, nky = int(10*wid) + 1, int(10*wid) + 1
    xhwid, yhwid = nkx//2, nky//2

    # Peak amplitude of the round (undistorted) Gaussian; identical for all objects.
    peak0 = flux/(2*math.pi*wid*wid)

    nRow = int(math.sqrt(nObj))
    # Grid spacing (kept as float, as before).
    xstep = float((nx - 1)//(nRow + 1))
    ystep = float((ny - 1)//(nRow + 1))

    if useRandom:
        nObj = nRow*nRow

    goodAdded0 = []
    goodAdded = []

    for i in range(nObj):

        # get our position
        if useRandom:
            xcen0, ycen0 = np.random.uniform(nx), np.random.uniform(ny)
        else:
            xcen0, ycen0 = xstep*((i % nRow) + 1), ystep*(i//nRow + 1)
        ixcen0, iycen0 = int(xcen0), int(ycen0)

        # distort the shape using the local linear approximation of the
        # inverse PIXELS -> TAN_PIXELS transform at this position
        pTan = lsst.geom.Point2D(xcen0, ycen0)
        p = pixToTanPix.applyInverse(pTan)
        linTransform = afwGeom.linearizeTransform(pixToTanPix, p).inverted().getLinear()
        m = afwGeom.Quadrupole(ixx0, iyy0, ixy0)
        m.transform(linTransform)

        # The centre is deliberately left undistorted (the distorted centre
        # would be p.getX(), p.getY()); only the shape changes.
        xcen, ycen = xcen0, ycen0
        if (xcen < edgeBuffer or (nx - xcen) < edgeBuffer
                or ycen < edgeBuffer or (ny - ycen) < edgeBuffer):
            continue
        ixcen, iycen = int(xcen), int(ycen)
        ixx, iyy, ixy = m.getIxx(), m.getIyy(), m.getIxy()

        # plant the object: convert moments to semi-axes and orientation
        tmp = 0.25*(ixx-iyy)**2 + ixy**2
        a2 = 0.5*(ixx+iyy) + np.sqrt(tmp)
        b2 = 0.5*(ixx+iyy) - np.sqrt(tmp)

        theta = 0.5*np.arctan2(2.0*ixy, ixx-iyy)
        a = np.sqrt(a2)
        b = np.sqrt(b2)

        c, s = math.cos(theta), math.sin(theta)
        # Peak amplitude of this (elliptical) Gaussian; constant over the stamp.
        peak = flux/(2*math.pi*a*b)

        good0, good = True, True
        for y in range(nky):
            iy = iycen + y - yhwid
            iy0 = iycen0 + y - yhwid

            for x in range(nkx):
                ix = ixcen + x - xhwid
                ix0 = ixcen0 + x - xhwid

                # distorted image: evaluate in the rotated (u, v) frame
                if 0 <= ix < nx and 0 <= iy < ny:
                    dx, dy = ix - xcen, iy - ycen
                    u = c*dx + s*dy
                    v = -s*dx + c*dy
                    val = peak*math.exp(-0.5*((u/a)**2 + (v/b)**2))
                    img[ix, iy, afwImage.LOCAL] += val
                else:
                    good = False

                # undistorted image: use the undistorted stamp coordinates
                if 0 <= ix0 < nx and 0 <= iy0 < ny:
                    dx0, dy0 = ix0 - xcen0, iy0 - ycen0
                    val0 = peak0*math.exp(-0.5*((dx0/wid)**2 + (dy0/wid)**2))
                    img0[ix0, iy0, afwImage.LOCAL] += val0
                else:
                    good0 = False

        if good0:
            goodAdded0.append([xcen, ycen])
        if good:
            goodAdded.append([xcen, ycen])

    # add sky and noise
    img += sky
    img0 += sky
    noise = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))
    noise0 = afwImage.ImageF(lsst.geom.ExtentI(nx, ny))
    # Draws are interleaved per pixel (distorted, then undistorted) in x-major
    # order; keep this order so results under a fixed seed do not change.
    for i in range(nx):
        for j in range(ny):
            noise[i, j, afwImage.LOCAL] = np.random.poisson(img[i, j, afwImage.LOCAL])
            noise0[i, j, afwImage.LOCAL] = np.random.poisson(img0[i, j, afwImage.LOCAL])

    # Flag a border of half the edge buffer on every side as EDGE.
    edgeWidth = int(0.5*edgeBuffer)
    mask = afwImage.Mask(lsst.geom.ExtentI(nx, ny))
    left = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.ExtentI(edgeWidth, ny))
    right = lsst.geom.Box2I(lsst.geom.Point2I(nx - edgeWidth, 0), lsst.geom.ExtentI(edgeWidth, ny))
    top = lsst.geom.Box2I(lsst.geom.Point2I(0, ny - edgeWidth), lsst.geom.ExtentI(nx, edgeWidth))
    bottom = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.ExtentI(nx, edgeWidth))

    for pos in [left, right, top, bottom]:
        # deep=False: the sub-mask shares pixels with ``mask``
        msk = afwImage.Mask(mask, pos, deep=False)
        msk.set(msk.getPlaneBitMask('EDGE'))

    # Variance planes are deep copies of the noisy images (sky included).
    expos = afwImage.makeExposure(afwImage.makeMaskedImage(noise, mask, afwImage.ImageF(noise, True)))
    expos0 = afwImage.makeExposure(afwImage.makeMaskedImage(noise0, mask, afwImage.ImageF(noise0, True)))

    im = expos.getMaskedImage().getImage()
    im0 = expos0.getMaskedImage().getImage()
    im -= sky
    im0 -= sky

    return expos, goodAdded, expos0, goodAdded0

176 

177 

class PsfSelectionTestCase(lsst.utils.tests.TestCase):
    """Test PSF star selection and PCA PSF determination on simulated images.

    Sources are planted with `plantSources` using either a radially distorted
    detector or an undistorted ("flat") one.  They are then run through
    detection, measurement, the ``objectSize`` star selector,
    `MakePsfCandidatesTask` and the ``pca`` PSF determiner.
    """

    def setUp(self):
        np.random.seed(500)  # make test repeatable
        self.x0, self.y0 = 0, 0
        self.nx, self.ny = 512, 512  # 2048, 4096
        self.sky = 100.0
        self.nObj = 100

        # make a detector with distortion
        self.detector = DetectorWrapper(
            bbox=lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(self.nx, self.ny)),
            orientation=cameraGeom.Orientation(lsst.geom.Point2D(255.0, 255.0)),
            radialDistortion=0.925,
        ).detector

        # make a detector with no distortion
        self.flatDetector = DetectorWrapper(
            bbox=lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Extent2I(self.nx, self.ny)),
            orientation=cameraGeom.Orientation(lsst.geom.Point2D(255.0, 255.0)),
            radialDistortion=0.0,
        ).detector

        # detection policies
        detConfig = measAlg.SourceDetectionConfig()
        # Cannot use default background approximation order (6) for such a small image.
        detConfig.background.approxOrderX = 4

        # measurement policies
        measConfig = measBase.SingleFrameMeasurementConfig()
        measConfig.algorithms.names = [
            "base_SdssCentroid",
            "base_SdssShape",
            "base_GaussianFlux",
            "base_PsfFlux",
        ]
        measConfig.slots.centroid = "base_SdssCentroid"
        measConfig.slots.shape = "base_SdssShape"
        measConfig.slots.psfFlux = "base_PsfFlux"
        measConfig.plugins["base_SdssCentroid"].doFootprintCheck = False
        measConfig.slots.apFlux = None
        measConfig.slots.modelFlux = None
        measConfig.slots.gaussianFlux = None
        measConfig.slots.calibFlux = None

        self.schema = afwTable.SourceTable.makeMinimalSchema()
        detConfig.validate()
        measConfig.validate()
        self.detTask = measAlg.SourceDetectionTask(config=detConfig, schema=self.schema)
        self.measTask = measBase.SingleFrameMeasurementTask(config=measConfig, schema=self.schema)

        # psf star selector
        starSelectorClass = measAlg.sourceSelectorRegistry["objectSize"]
        starSelectorConfig = starSelectorClass.ConfigClass()
        starSelectorConfig.fluxMin = 5000.0
        starSelectorConfig.badFlags = []
        self.starSelector = starSelectorClass(config=starSelectorConfig)

        self.makePsfCandidates = measAlg.MakePsfCandidatesTask()

        # psf determiner
        psfDeterminerFactory = measAlg.psfDeterminerRegistry["pca"]
        psfDeterminerConfig = psfDeterminerFactory.ConfigClass()
        width, height = self.nx, self.ny
        nEigenComponents = 3
        psfDeterminerConfig.sizeCellX = width//3
        psfDeterminerConfig.sizeCellY = height//3
        psfDeterminerConfig.nEigenComponents = nEigenComponents
        psfDeterminerConfig.spatialOrder = 1
        psfDeterminerConfig.kernelSizeMin = 31
        psfDeterminerConfig.nStarPerCell = 0
        psfDeterminerConfig.nStarPerCellSpatialFit = 0  # unlimited
        self.psfDeterminer = psfDeterminerFactory(psfDeterminerConfig)

    def tearDown(self):
        # Release every task, schema and detector object created in setUp.
        del self.detTask
        del self.measTask
        del self.schema
        del self.detector
        del self.flatDetector
        del self.starSelector
        del self.makePsfCandidates
        del self.psfDeterminer

    def detectAndMeasure(self, exposure):
        """Quick and dirty detection (note: we already subtracted background)
        """
        table = afwTable.SourceTable.make(self.schema)
        # detect
        sources = self.detTask.run(table, exposure).sources
        # ... and measure
        self.measTask.run(sources, exposure)
        return sources

    def testPsfCandidate(self):
        """Determine a PCA PSF from selected stars and subtract it from all sources.

        With the distortion comparison disabled (see below), this only checks
        that selection, determination and subtraction run end to end; the
        residual statistics are printed, not asserted.
        """
        detector = self.detector

        # make an exposure; only the distorted image is used here
        print("Planting")
        psfSigma = 1.5
        exposDist, _, _, _ = plantSources(self.x0, self.y0,
                                          self.nx, self.ny,
                                          self.sky, self.nObj, psfSigma, detector)

        # set the psf
        kwid = 21
        psf = measAlg.SingleGaussianPsf(kwid, kwid, psfSigma)
        exposDist.setPsf(psf)
        exposDist.setDetector(detector)

        # detect
        print("detection")
        sourceList = self.detectAndMeasure(exposDist)

        # select psf stars
        print("PSF selection")
        stars = self.starSelector.run(sourceList, exposure=exposDist)
        psfCandidateList = self.makePsfCandidates.run(stars.sourceCat, exposDist).psfCandidates

        # determine the PSF (replaces the input Gaussian PSF used for measurement)
        print("PSF determination")
        metadata = dafBase.PropertyList()
        t0 = time.perf_counter()
        psf, cellSet = self.psfDeterminer.determinePsf(exposDist, psfCandidateList, metadata)
        print("... determination time: ", time.perf_counter() - t0)
        print("PSF kernel width: ", psf.getKernel().getWidth())

        #######################################################################
        # try to subtract off the stars and check the residuals

        imgOrig = exposDist.getMaskedImage().getImage().getArray()
        maxFlux = imgOrig.max()

        ############
        # first try it with no distortion in the psf
        exposDist.setDetector(self.flatDetector)

        print("uncorrected subtraction")
        subImg = afwImage.MaskedImageF(exposDist.getMaskedImage(), True)
        for s in sourceList:
            x, y = s.getX(), s.getY()
            measAlg.subtractPsf(psf, subImg, x, y)

        if display:
            afwDisplay.Display(frame=1).mtv(exposDist, title=self._testMethodName + ": full")
            afwDisplay.Display(frame=0).mtv(subImg, title=self._testMethodName + ": subtracted")

        img = subImg.getImage().getArray()
        norm = img/math.sqrt(maxFlux)

        smin0, smax0, srms0 = norm.min(), norm.max(), norm.std()

        print("min:", smin0, "max: ", smax0, "rms: ", srms0)

        if False:
            # This section has been disabled as distortion was removed from PsfCandidate and Psf;
            # it will be reintroduced in the future with a different API, at which point this
            # test code should be re-enabled.

            ##############
            # try it with the correct distortion in the psf
            exposDist.setDetector(self.detector)

            print("corrected subtraction")
            subImg = afwImage.MaskedImageF(exposDist.getMaskedImage(), True)
            for s in sourceList:
                x, y = s.getX(), s.getY()
                measAlg.subtractPsf(psf, subImg, x, y)

            if display:
                afwDisplay.Display(frame=2).mtv(exposDist, title=self._testMethodName + ": full")
                afwDisplay.Display(frame=3).mtv(subImg, title=self._testMethodName + ": subtracted")

            img = subImg.getImage().getArray()
            norm = img/math.sqrt(maxFlux)

            smin, smax, srms = norm.min(), norm.max(), norm.std()

            # with proper distortion, residuals should be < 4sigma (even for 512x512 pixels)
            print("min:", smin, "max: ", smax, "rms: ", srms)

            # the distrib of residuals should be tighter
            self.assertLess(smin0, smin)
            self.assertGreater(smax0, smax)
            self.assertGreater(srms0, srms)

    def testDistortedImage(self):
        """Check star/galaxy separation of PSF candidates on a distorted image.

        Stars and slightly larger faint galaxies are planted; candidates are
        selected with the flat and with the distorted detector attached, and
        matched back to the planted positions within 2 pixels.
        """
        detector = self.detector

        psfSigma = 1.5
        planted = plantSources(self.x0, self.y0, self.nx, self.ny, self.sky, self.nObj, psfSigma, detector)
        expos, starXy = planted[0], planted[1]

        # add some faint round galaxies ... only slightly bigger than the psf
        gxy = plantSources(self.x0, self.y0, self.nx, self.ny, self.sky, 10, 1.07*psfSigma, detector)
        mi = expos.getMaskedImage()
        mi += gxy[0].getMaskedImage()
        gxyXy = gxy[1]

        kwid = 15  # int(10*psfSigma) + 1
        psf = measAlg.SingleGaussianPsf(kwid, kwid, psfSigma)
        expos.setPsf(psf)

        ########################
        # try without distorter
        expos.setDetector(self.flatDetector)
        print("Testing PSF selection *without* distortion")
        sourceList = self.detectAndMeasure(expos)
        stars = self.starSelector.run(sourceList, exposure=expos)
        psfCandidateList = self.makePsfCandidates.run(stars.sourceCat, expos).psfCandidates

        ########################
        # try with distorter
        expos.setDetector(self.detector)
        print("Testing PSF selection *with* distortion")
        sourceList = self.detectAndMeasure(expos)
        stars = self.starSelector.run(sourceList, exposure=expos)
        psfCandidateListCorrected = self.makePsfCandidates.run(stars.sourceCat, expos).psfCandidates

        def countObjects(candList):
            """Count candidates lying within 2 pixels (per axis) of a planted star / galaxy."""
            nStar, nGxy = 0, 0
            for c in candList:
                s = c.getSource()
                x, y = s.getX(), s.getY()
                for xs, ys in starXy:
                    if abs(x-xs) < 2.0 and abs(y-ys) < 2.0:
                        nStar += 1
                for xg, yg in gxyXy:
                    if abs(x-xg) < 2.0 and abs(y-yg) < 2.0:
                        nGxy += 1
            return nStar, nGxy

        nstar, ngxy = countObjects(psfCandidateList)
        nstarC, ngxyC = countObjects(psfCandidateListCorrected)

        print("uncorrected nStar, nGxy: ", nstar, "/", len(starXy), " ", ngxy, '/', len(gxyXy))
        print("dist-corrected nStar, nGxy: ", nstarC, '/', len(starXy), " ", ngxyC, '/', len(gxyXy))

        ########################
        # display
        if display:
            iDisp = 1
            disp = afwDisplay.Display(frame=iDisp)
            disp.mtv(expos, title=self._testMethodName + ": image")
            size = 40
            for c in psfCandidateList:
                s = c.getSource()
                ixx, iyy, ixy = size*s.getIxx(), size*s.getIyy(), size*s.getIxy()
                disp.dot("@:%g,%g,%g" % (ixx, ixy, iyy), s.getX(), s.getY(), ctype=afwDisplay.RED)
            size *= 2.0
            for c in psfCandidateListCorrected:
                s = c.getSource()
                ixx, iyy, ixy = size*s.getIxx(), size*s.getIyy(), size*s.getIxy()
                disp.dot("@:%g,%g,%g" % (ixx, ixy, iyy), s.getX(), s.getY(), ctype=afwDisplay.GREEN)

        # we shouldn't expect to get all available stars without distortion correcting
        self.assertLess(nstar, len(starXy))

        # here we should get all of them, occasionally 1 or 2 might get missed
        self.assertGreaterEqual(nstarC, 0.95*len(starXy))

        # no contamination by small gxys
        self.assertEqual(ngxyC, 0)

445 

446 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Collect the checks inherited from ``lsst.utils.tests.MemoryTestCase``.

    Nothing is added here; the subclass exists only so the test runner picks
    those inherited checks up for this module (presumably leak detection —
    see the base class).
    """

449 

450 

def setup_module(module):
    """Module-level setup hook: initialise the LSST test utilities.

    ``module`` (the module object handed in by the runner, presumably
    pytest's xunit-style hook) is accepted but not used.
    """
    testUtils = lsst.utils.tests
    testUtils.init()

453 

454 

if __name__ == "__main__":
    # Direct execution of this file: initialise the LSST test utilities
    # (mirroring setup_module) and then hand over to unittest's runner.
    testUtils = lsst.utils.tests
    testUtils.init()
    unittest.main()