Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

#
# LSST Data Management System
# Copyright 2008-2017 AURA/LSST.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <https://www.lsstcorp.org/LegalNotices/>.
#

22 

23import unittest 

24import numpy as np 

25 

26import lsst.afw.image as afwImage 

27import lsst.utils.tests 

28import lsst.ip.isr as ipIsr 

29import lsst.pex.exceptions as pexExcept 

30import lsst.ip.isr.isrMock as isrMock 

31import lsst.pipe.base as pipeBase 

32 

33 

def countMaskedPixels(maskedImage, maskPlane):
    """Count the pixels that have the bit of a given mask plane set.

    Parameters
    ----------
    maskedImage : `lsst.afw.image.MaskedImage`
        Image whose mask is inspected.
    maskPlane : `str`
        Name of the mask plane to count.

    Returns
    -------
    nMask : `int`
        Number of pixels with the ``maskPlane`` bit set.
    """
    mask = maskedImage.getMask()
    bitMask = mask.getPlaneBitMask(maskPlane)
    # Count non-zero results of the bitwise AND rather than testing ``> 0``,
    # so a set high (sign) bit in a signed mask array is still counted.
    # Wrap in int() so callers get the builtin type the docstring promises.
    return int(np.count_nonzero(mask.getArray() & bitMask))

54 

55 

def computeImageMedianAndStd(image):
    """Return the NaN-ignoring median and standard deviation of an image.

    Parameters
    ----------
    image : `lsst.afw.image.Image`
        Image whose pixel array is summarised.

    Returns
    -------
    median : `float`
        Median of the non-NaN pixel values.
    std : `float`
        Standard deviation of the non-NaN pixel values.
    """
    pixels = image.getArray()
    return np.nanmedian(pixels), np.nanstd(pixels)

75 

76 

class IsrFunctionsCases(lsst.utils.tests.TestCase):
    """Test that functions for ISR produce expected outputs.
    """

    def setUp(self):
        # A fresh trimmed raw exposure for every test; many tests modify
        # ``self.mi`` / ``self.inputExp`` in place.
        self.inputExp = isrMock.TrimmedRawMock().run()
        self.mi = self.inputExp.getMaskedImage()

    def test_transposeMaskedImage(self):
        """Expect height and width to be exchanged.
        """
        transposed = ipIsr.transposeMaskedImage(self.mi)
        self.assertEqual(transposed.getImage().getBBox().getHeight(),
                         self.mi.getImage().getBBox().getWidth())
        self.assertEqual(transposed.getImage().getBBox().getWidth(),
                         self.mi.getImage().getBBox().getHeight())

    def test_interpolateDefectList(self):
        """Expect the INTRP mask plane to have no flagged pixels, whether
        or not the plane was already present on the mask.

        NOTE(review): despite its name this test never calls
        ``interpolateDefectList`` and ``fallbackValue`` is unused; it only
        exercises the mask-plane setup. TODO: call the function under test.
        """
        defectList = isrMock.DefectMock().run()
        self.assertEqual(len(defectList), 1)

        for fallbackValue in (None, -999.0):
            for haveMask in (True, False):
                with self.subTest(fallbackValue=fallbackValue, haveMask=haveMask):
                    if not haveMask:
                        if 'INTRP' in self.mi.getMask().getMaskPlaneDict():
                            self.mi.getMask().removeAndClearMaskPlane('INTRP')
                    elif 'INTRP' not in self.mi.getMask().getMaskPlaneDict():
                        self.mi.getMask().addMaskPlane('INTRP')
                    numBit = countMaskedPixels(self.mi, "INTRP")
                    self.assertEqual(numBit, 0)

    def test_transposeDefectList(self):
        """Expect bbox dimension values to flip.
        """
        defectList = isrMock.DefectMock().run()
        transposed = defectList.transpose()

        for d, t in zip(defectList, transposed):
            self.assertEqual(d.getBBox().getDimensions().getX(), t.getBBox().getDimensions().getY())
            self.assertEqual(d.getBBox().getDimensions().getY(), t.getBBox().getDimensions().getX())

    def test_makeThresholdMask(self):
        """Expect exactly one defect from thresholding the mock at 200.
        """
        defectList = ipIsr.makeThresholdMask(self.mi, 200,
                                             growFootprints=2,
                                             maskName='SAT')

        self.assertEqual(len(defectList), 1)

    def test_interpolateFromMask(self):
        """Expect a fixed, non-zero number of INTRP pixels for each
        ``growSaturatedFootprints`` value tried.
        """
        ipIsr.makeThresholdMask(self.mi, 200, growFootprints=2,
                                maskName='SAT')
        for growFootprints in range(0, 3):
            interpMaskedImage = ipIsr.interpolateFromMask(self.mi, 2.0,
                                                          growSaturatedFootprints=growFootprints,
                                                          maskNameList=['SAT'])
            numBit = countMaskedPixels(interpMaskedImage, "INTRP")
            self.assertEqual(numBit, 40800, msg=f"interpolateFromMask with growFootprints={growFootprints}")

    def test_saturationCorrectionInterpolate(self):
        """Expect a fixed, non-zero number of SAT pixels with interpolation.
        """
        corrMaskedImage = ipIsr.saturationCorrection(self.mi, 200, 2.0,
                                                     growFootprints=2, interpolate=True,
                                                     maskName='SAT')
        numBit = countMaskedPixels(corrMaskedImage, "SAT")
        self.assertEqual(numBit, 40800)

    def test_saturationCorrectionNoInterpolate(self):
        """Expect a fixed, non-zero number of SAT pixels without interpolation.
        """
        corrMaskedImage = ipIsr.saturationCorrection(self.mi, 200, 2.0,
                                                     growFootprints=2, interpolate=False,
                                                     maskName='SAT')
        numBit = countMaskedPixels(corrMaskedImage, "SAT")
        self.assertEqual(numBit, 40800)

    def test_trimToMatchCalibBBox(self):
        """Expect bounding boxes to match.
        """
        darkExp = isrMock.DarkMock().run()
        darkMi = darkExp.getMaskedImage()

        nEdge = 2
        darkMi = darkMi[nEdge:-nEdge, nEdge:-nEdge, afwImage.LOCAL]
        newInput = ipIsr.trimToMatchCalibBBox(self.mi, darkMi)

        self.assertEqual(newInput.getImage().getBBox(), darkMi.getImage().getBBox())

    def test_darkCorrection(self):
        """Expect round-trip application to be equal.
        Expect RuntimeError if sizes are different.
        """
        darkExp = isrMock.DarkMock().run()
        darkMi = darkExp.getMaskedImage()

        mi = self.mi.clone()

        # The `invert` parameter controls the direction of the
        # application. This will apply, and un-apply the dark.
        ipIsr.darkCorrection(self.mi, darkMi, 1.0, 1.0, trimToFit=True)
        ipIsr.darkCorrection(self.mi, darkMi, 1.0, 1.0, trimToFit=True, invert=True)

        self.assertMaskedImagesAlmostEqual(self.mi, mi, atol=1e-3)

        darkMi = darkMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.darkCorrection(self.mi, darkMi, 1.0, 1.0, trimToFit=False)

    def test_biasCorrection(self):
        """Expect smaller median image value after.
        Expect RuntimeError if sizes are different.
        """
        biasExp = isrMock.BiasMock().run()
        biasMi = biasExp.getMaskedImage()

        mi = self.mi.clone()
        ipIsr.biasCorrection(self.mi, biasMi, trimToFit=True)
        self.assertLess(computeImageMedianAndStd(self.mi.getImage())[0],
                        computeImageMedianAndStd(mi.getImage())[0])

        biasMi = biasMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.biasCorrection(self.mi, biasMi, trimToFit=False)

    def test_flatCorrection(self):
        """Expect round-trip application to be equal.
        Expect RuntimeError if sizes are different.
        """
        flatExp = isrMock.FlatMock().run()
        flatMi = flatExp.getMaskedImage()

        mi = self.mi.clone()
        for scaling in ('USER', 'MEAN', 'MEDIAN'):
            # The `invert` parameter controls the direction of the
            # application. This will apply, and un-apply the flat.
            ipIsr.flatCorrection(self.mi, flatMi, scaling, userScale=1.0, trimToFit=True)
            ipIsr.flatCorrection(self.mi, flatMi, scaling, userScale=1.0,
                                 trimToFit=True, invert=True)

            self.assertMaskedImagesAlmostEqual(self.mi, mi, atol=1e-3,
                                               msg=f"flatCorrection with scaling {scaling}")

        flatMi = flatMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.flatCorrection(self.mi, flatMi, 'USER', userScale=1.0, trimToFit=False)

    def test_flatCorrectionUnknown(self):
        """Raise if an unknown scaling is used.

        The `scaling` parameter must be a known type. If not, the
        flat correction will raise a RuntimeError.
        """
        flatExp = isrMock.FlatMock().run()
        flatMi = flatExp.getMaskedImage()

        with self.assertRaises(RuntimeError):
            ipIsr.flatCorrection(self.mi, flatMi, "UNKNOWN", userScale=1.0, trimToFit=True)

    def test_illumCorrection(self):
        """Expect larger median value after.
        Expect RuntimeError if sizes are different.
        """
        flatExp = isrMock.FlatMock().run()
        flatMi = flatExp.getMaskedImage()

        mi = self.mi.clone()
        ipIsr.illuminationCorrection(self.mi, flatMi, 1.0)
        self.assertGreater(computeImageMedianAndStd(self.mi.getImage())[0],
                           computeImageMedianAndStd(mi.getImage())[0])

        flatMi = flatMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.illuminationCorrection(self.mi, flatMi, 1.0, trimToFit=False)

    def _checkOverscanCorrection(self, overscanIsInt, label):
        """Run ``overscanCorrection`` on the first amp of a raw mock for
        every fit type and check the result.

        Known fit types must return a `lsst.pipe.base.Struct` with non-None
        ``imageFit``/``overscanFit`` and a `~lsst.afw.image.MaskedImageF`
        ``overscanImage``; ``'UNKNOWN'`` must raise `pexExcept.Exception`.

        Parameters
        ----------
        overscanIsInt : `bool`
            Value passed through as ``overscanIsInt``.
        label : `str`
            Tag inserted into every assertion message.
        """
        inputExp = isrMock.RawMock().run()

        amp = inputExp.getDetector()[0]
        ampI = inputExp.maskedImage[amp.getRawDataBBox()]
        overscanI = inputExp.maskedImage[amp.getRawHorizontalOverscanBBox()]

        for fitType in ('MEAN', 'MEDIAN', 'MEDIAN_PER_ROW', 'MEANCLIP', 'POLY', 'CHEB',
                        'NATURAL_SPLINE', 'CUBIC_SPLINE', 'UNKNOWN'):
            # Spline fit types are run with order 3; all others with order 1.
            order = 3 if fitType in ('NATURAL_SPLINE', 'CUBIC_SPLINE') else 1

            if fitType == 'UNKNOWN':
                with self.assertRaises(pexExcept.Exception,
                                       msg=f"overscanCorrection {label} fitType: {fitType}"):
                    ipIsr.overscanCorrection(ampI, overscanI, fitType=fitType,
                                             order=order, collapseRej=3.0,
                                             statControl=None, overscanIsInt=overscanIsInt)
                continue

            response = ipIsr.overscanCorrection(ampI, overscanI, fitType=fitType,
                                                order=order, collapseRej=3.0,
                                                statControl=None, overscanIsInt=overscanIsInt)
            self.assertIsInstance(response, pipeBase.Struct,
                                  msg=f"overscanCorrection {label} Bad response: {fitType}")
            self.assertIsNotNone(response.imageFit,
                                 msg=f"overscanCorrection {label} Bad imageFit: {fitType}")
            self.assertIsNotNone(response.overscanFit,
                                 msg=f"overscanCorrection {label} Bad overscanFit: {fitType}")
            self.assertIsInstance(response.overscanImage, afwImage.MaskedImageF,
                                  msg=f"overscanCorrection {label} Bad overscanImage: {fitType}")

    def test_overscanCorrection_isInt(self):
        """Expect a well-formed response for each known fit type and an
        exception for an unknown fit type, with ``overscanIsInt=True``.
        """
        self._checkOverscanCorrection(True, "overscanIsInt")

    def test_overscanCorrection_isNotInt(self):
        """Expect a well-formed response for each known fit type and an
        exception for an unknown fit type, with ``overscanIsInt=False``.
        """
        self._checkOverscanCorrection(False, "overscanIsNotInt")

    def test_brighterFatterCorrection(self):
        """Expect the image std to be smaller before the correction than after.
        """
        bfKern = isrMock.BfKernelMock().run()

        before = computeImageMedianAndStd(self.inputExp.getImage())
        ipIsr.brighterFatterCorrection(self.inputExp, bfKern, 10, 1e-2, False)
        after = computeImageMedianAndStd(self.inputExp.getImage())

        self.assertLess(before[1], after[1])

    def test_gainContext(self):
        """Expect image to be unmodified before and after.
        """
        mi = self.inputExp.getMaskedImage().clone()
        with ipIsr.gainContext(self.inputExp, self.inputExp.getImage(), apply=True):
            pass

        self.assertIsNotNone(mi)
        self.assertMaskedImagesAlmostEqual(self.inputExp.getMaskedImage(), mi)

    def test_addDistortionModel(self):
        """Expect RuntimeError if no camera is supplied; a missing detector
        or WCS is expected not to raise.
        """
        camera = isrMock.IsrMock().getCamera()
        ipIsr.addDistortionModel(self.inputExp, camera)

        with self.assertRaises(RuntimeError):
            ipIsr.addDistortionModel(self.inputExp, None)

        self.inputExp.setDetector(None)
        ipIsr.addDistortionModel(self.inputExp, camera)

        self.inputExp.setWcs(None)
        ipIsr.addDistortionModel(self.inputExp, camera)

    def test_widenSaturationTrails(self):
        """Expect the number of SAT mask pixels not to decrease.
        """
        numBitBefore = countMaskedPixels(self.mi, "SAT")

        ipIsr.widenSaturationTrails(self.mi.getMask())
        numBitAfter = countMaskedPixels(self.mi, "SAT")

        self.assertGreaterEqual(numBitAfter, numBitBefore)

    def test_setBadRegions(self):
        """Expect RuntimeError if an unknown statistic is given.
        Otherwise expect a value whose magnitude compares >= 0 (i.e. not NaN).
        """
        for badStatistic in ('MEDIAN', 'MEANCLIP', 'UNKNOWN'):
            if badStatistic == 'UNKNOWN':
                with self.assertRaises(RuntimeError,
                                       msg=f"setBadRegions did not fail for stat {badStatistic}"):
                    ipIsr.setBadRegions(self.inputExp, badStatistic=badStatistic)
            else:
                nBad, value = ipIsr.setBadRegions(self.inputExp, badStatistic=badStatistic)
                self.assertGreaterEqual(abs(value), 0.0,
                                        msg=f"setBadRegions did not find valid value for stat {badStatistic}")

    def test_attachTransmissionCurve(self):
        """Expect no failure and non-None output from attachTransmissionCurve.
        """
        curve = isrMock.TransmissionMock().run()
        combined = ipIsr.attachTransmissionCurve(self.inputExp,
                                                 opticsTransmission=curve,
                                                 filterTransmission=curve,
                                                 sensorTransmission=curve,
                                                 atmosphereTransmission=curve)
        # DM-19707: ip_isr functionality not fully tested by unit tests
        self.assertIsNotNone(combined)

    def test_attachTransmissionCurve_None(self):
        """Expect no failure and non-None output from attachTransmissionCurve
        when every transmission curve is None.
        """
        curve = None
        combined = ipIsr.attachTransmissionCurve(self.inputExp,
                                                 opticsTransmission=curve,
                                                 filterTransmission=curve,
                                                 sensorTransmission=curve,
                                                 atmosphereTransmission=curve)
        # DM-19707: ip_isr functionality not fully tested by unit tests
        self.assertIsNotNone(combined)

413 

414 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``lsst.utils.tests.MemoryTestCase``
    unchanged for this test module.
    """

417 

418 

def setup_module(module):
    """Module-level setup hook (presumably invoked by pytest).

    Initialises the LSST test utilities before any test in this module runs.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

421 

422 

if __name__ == "__main__":
    # When run as a script: initialise the LSST test utilities first,
    # then hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()