Coverage for tests/test_isrFunctions.py: 17%

184 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-07-03 01:44 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008-2017 AURA/LSST. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <https://www.lsstcorp.org/LegalNotices/>. 

21# 

22 

23import unittest 

24import numpy as np 

25 

26import lsst.afw.image as afwImage 

27import lsst.utils.tests 

28import lsst.ip.isr as ipIsr 

29import lsst.ip.isr.isrMock as isrMock 

30import lsst.pipe.base as pipeBase 

31 

32 

def countMaskedPixels(maskedImage, maskPlane):
    """Count the number of pixels that have a given mask plane set.

    Parameters
    ----------
    maskedImage : `lsst.afw.image.MaskedImage`
        Image whose mask is inspected.
    maskPlane : `str`
        Name of the mask plane to count.

    Returns
    -------
    nMask : `int`
        Number of pixels with ``maskPlane`` set.
    """
    mask = maskedImage.getMask()
    bitMask = mask.getPlaneBitMask(maskPlane)
    # Compare with ``!= 0`` rather than ``> 0``: if the plane is the sign bit
    # of a signed mask array, the AND result is negative but still "set".
    isBit = (mask.getArray() & bitMask) != 0
    # Convert from a numpy integer so the return matches the documented type.
    return int(np.count_nonzero(isBit))

53 

54 

def computeImageMedianAndStd(image):
    """Return the NaN-ignoring median and standard deviation of an image.

    Parameters
    ----------
    image : `lsst.afw.image.Image`
        Image whose pixel values are summarized.

    Returns
    -------
    median : `float`
        Median of the non-NaN pixel values.
    std : `float`
        Standard deviation of the non-NaN pixel values.
    """
    pixels = image.getArray()
    return np.nanmedian(pixels), np.nanstd(pixels)

74 

75 

class IsrFunctionsCases(lsst.utils.tests.TestCase):
    """Test that functions for ISR produce expected outputs.
    """

    def setUp(self):
        # Each test gets a fresh mock exposure; many tests modify ``self.mi``
        # in place.
        self.inputExp = isrMock.TrimmedRawMock().run()
        self.mi = self.inputExp.getMaskedImage()

    def test_transposeMaskedImage(self):
        """Expect height and width to be exchanged.
        """
        transposed = ipIsr.transposeMaskedImage(self.mi)
        self.assertEqual(transposed.getImage().getBBox().getHeight(),
                         self.mi.getImage().getBBox().getWidth())
        self.assertEqual(transposed.getImage().getBBox().getWidth(),
                         self.mi.getImage().getBBox().getHeight())

    def test_interpolateDefectList(self):
        """Expect the mock defect list to hold one defect, and expect zero
        pixels flagged INTRP for each combination of ``fallbackValue`` and
        presence of the INTRP mask plane.

        NOTE(review): ``ipIsr.interpolateDefectList`` is never called here,
        so interpolation itself is not exercised; ``fallbackValue`` only
        labels the subTest. DM-19707 tracks incomplete ip_isr test coverage.
        """
        defectList = isrMock.DefectMock().run()
        self.assertEqual(len(defectList), 1)

        for fallbackValue in (None, -999.0):
            for haveMask in (True, False):
                with self.subTest(fallbackValue=fallbackValue, haveMask=haveMask):
                    if not haveMask:
                        if 'INTRP' in self.mi.getMask().getMaskPlaneDict():
                            self.mi.getMask().removeAndClearMaskPlane('INTRP')
                    else:
                        if 'INTRP' not in self.mi.getMask().getMaskPlaneDict():
                            self.mi.getMask().addMaskPlane('INTRP')
                    numBit = countMaskedPixels(self.mi, "INTRP")
                    self.assertEqual(numBit, 0)

    def test_transposeDefectList(self):
        """Expect bbox dimension values to flip.
        """
        defectList = isrMock.DefectMock().run()
        transposed = defectList.transpose()

        for d, t in zip(defectList, transposed):
            self.assertEqual(d.getBBox().getDimensions().getX(), t.getBBox().getDimensions().getY())
            self.assertEqual(d.getBBox().getDimensions().getY(), t.getBBox().getDimensions().getX())

    def test_makeThresholdMask(self):
        """Expect thresholding at 200 to produce exactly one defect.
        """
        defectList = ipIsr.makeThresholdMask(self.mi, 200,
                                             growFootprints=2,
                                             maskName='SAT')

        self.assertEqual(len(defectList), 1)

    def test_interpolateFromMask(self):
        """Expect exactly 40800 pixels flagged INTRP after interpolating over
        SAT-masked pixels, for each ``growSaturatedFootprints`` value tried.
        """
        ipIsr.makeThresholdMask(self.mi, 200, growFootprints=2,
                                maskName='SAT')
        for growFootprints in range(0, 3):
            with self.subTest(growFootprints=growFootprints):
                interpMaskedImage = ipIsr.interpolateFromMask(self.mi, 2.0,
                                                              growSaturatedFootprints=growFootprints,
                                                              maskNameList=['SAT'])
                numBit = countMaskedPixels(interpMaskedImage, "INTRP")
                self.assertEqual(numBit, 40800,
                                 msg=f"interpolateFromMask with growFootprints={growFootprints}")

    def test_saturationCorrectionInterpolate(self):
        """Expect exactly 40800 pixels flagged SAT after saturation
        correction with interpolation enabled.
        """
        corrMaskedImage = ipIsr.saturationCorrection(self.mi, 200, 2.0,
                                                     growFootprints=2, interpolate=True,
                                                     maskName='SAT')
        numBit = countMaskedPixels(corrMaskedImage, "SAT")
        self.assertEqual(numBit, 40800)

    def test_saturationCorrectionNoInterpolate(self):
        """Expect exactly 40800 pixels flagged SAT after saturation
        correction with interpolation disabled.
        """
        corrMaskedImage = ipIsr.saturationCorrection(self.mi, 200, 2.0,
                                                     growFootprints=2, interpolate=False,
                                                     maskName='SAT')
        numBit = countMaskedPixels(corrMaskedImage, "SAT")
        self.assertEqual(numBit, 40800)

    def test_trimToMatchCalibBBox(self):
        """Expect bounding boxes to match.
        """
        darkExp = isrMock.DarkMock().run()
        darkMi = darkExp.getMaskedImage()

        nEdge = 2
        darkMi = darkMi[nEdge:-nEdge, nEdge:-nEdge, afwImage.LOCAL]
        newInput = ipIsr.trimToMatchCalibBBox(self.mi, darkMi)

        self.assertEqual(newInput.getImage().getBBox(), darkMi.getImage().getBBox())

    def test_darkCorrection(self):
        """Expect round-trip application to be equal.
        Expect RuntimeError if sizes are different.
        """
        darkExp = isrMock.DarkMock().run()
        darkMi = darkExp.getMaskedImage()

        mi = self.mi.clone()

        # The `invert` parameter controls the direction of the
        # application. This will apply, and un-apply the dark.
        ipIsr.darkCorrection(self.mi, darkMi, 1.0, 1.0, trimToFit=True)
        ipIsr.darkCorrection(self.mi, darkMi, 1.0, 1.0, trimToFit=True, invert=True)

        self.assertMaskedImagesAlmostEqual(self.mi, mi, atol=1e-3)

        darkMi = darkMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.darkCorrection(self.mi, darkMi, 1.0, 1.0, trimToFit=False)

    def test_biasCorrection(self):
        """Expect smaller median image value after.
        Expect RuntimeError if sizes are different.
        """
        biasExp = isrMock.BiasMock().run()
        biasMi = biasExp.getMaskedImage()

        mi = self.mi.clone()
        ipIsr.biasCorrection(self.mi, biasMi, trimToFit=True)
        self.assertLess(computeImageMedianAndStd(self.mi.getImage())[0],
                        computeImageMedianAndStd(mi.getImage())[0])

        biasMi = biasMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.biasCorrection(self.mi, biasMi, trimToFit=False)

    def test_flatCorrection(self):
        """Expect round-trip application to be equal for each scaling mode.
        Expect RuntimeError if sizes are different.
        """
        flatExp = isrMock.FlatMock().run()
        flatMi = flatExp.getMaskedImage()

        mi = self.mi.clone()
        for scaling in ('USER', 'MEAN', 'MEDIAN'):
            with self.subTest(scaling=scaling):
                # The `invert` parameter controls the direction of the
                # application. This will apply, and un-apply the flat.
                ipIsr.flatCorrection(self.mi, flatMi, scaling, userScale=1.0, trimToFit=True)
                ipIsr.flatCorrection(self.mi, flatMi, scaling, userScale=1.0,
                                     trimToFit=True, invert=True)

                self.assertMaskedImagesAlmostEqual(self.mi, mi, atol=1e-3,
                                                   msg=f"flatCorrection with scaling {scaling}")

        flatMi = flatMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.flatCorrection(self.mi, flatMi, 'USER', userScale=1.0, trimToFit=False)

    def test_flatCorrectionUnknown(self):
        """Raise if an unknown scaling is used.

        The `scaling` parameter must be a known type. If not, the
        flat correction will raise a RuntimeError.
        """
        flatExp = isrMock.FlatMock().run()
        flatMi = flatExp.getMaskedImage()

        with self.assertRaises(RuntimeError):
            ipIsr.flatCorrection(self.mi, flatMi, "UNKNOWN", userScale=1.0, trimToFit=True)

    def test_illumCorrection(self):
        """Expect larger median value after.
        Expect RuntimeError if sizes are different.
        """
        flatExp = isrMock.FlatMock().run()
        flatMi = flatExp.getMaskedImage()

        mi = self.mi.clone()
        ipIsr.illuminationCorrection(self.mi, flatMi, 1.0)
        self.assertGreater(computeImageMedianAndStd(self.mi.getImage())[0],
                           computeImageMedianAndStd(mi.getImage())[0])

        flatMi = flatMi[1:-1, 1:-1, afwImage.LOCAL]
        with self.assertRaises(RuntimeError):
            ipIsr.illuminationCorrection(self.mi, flatMi, 1.0, trimToFit=False)

    def _checkOverscanCorrection(self, overscanIsInt, label):
        """Run ``ipIsr.overscanCorrection`` on the first amplifier of a raw
        mock for each fit type, checking the result structure, and expect
        ValueError for an unknown fit type.

        Parameters
        ----------
        overscanIsInt : `bool`
            Passed through to ``ipIsr.overscanCorrection``.
        label : `str`
            Tag inserted into failure messages (e.g. ``"overscanIsInt"``).
        """
        inputExp = isrMock.RawMock().run()

        amp = inputExp.getDetector()[0]
        ampI = inputExp.maskedImage[amp.getRawDataBBox()]
        overscanI = inputExp.maskedImage[amp.getRawHorizontalOverscanBBox()]

        for fitType in ('MEAN', 'MEDIAN', 'MEDIAN_PER_ROW', 'MEANCLIP', 'POLY', 'CHEB',
                        'NATURAL_SPLINE', 'CUBIC_SPLINE', 'UNKNOWN'):
            # Spline fits are requested with order 3; all other fit types use 1.
            order = 3 if fitType in ('NATURAL_SPLINE', 'CUBIC_SPLINE') else 1
            kwargs = dict(fitType=fitType, order=order, collapseRej=3.0,
                          statControl=None, overscanIsInt=overscanIsInt)

            with self.subTest(fitType=fitType):
                if fitType == 'UNKNOWN':
                    with self.assertRaises(ValueError,
                                           msg=f"overscanCorrection {label} fitType: {fitType}"):
                        ipIsr.overscanCorrection(ampI, overscanI, **kwargs)
                    continue

                response = ipIsr.overscanCorrection(ampI, overscanI, **kwargs)
                self.assertIsInstance(response, pipeBase.Struct,
                                      msg=f"overscanCorrection {label} Bad response: {fitType}")
                self.assertIsNotNone(response.imageFit,
                                     msg=f"overscanCorrection {label} Bad imageFit: {fitType}")
                self.assertIsNotNone(response.overscanFit,
                                     msg=f"overscanCorrection {label} Bad overscanFit: {fitType}")
                self.assertIsInstance(response.overscanImage, afwImage.MaskedImageF,
                                      msg=f"overscanCorrection {label} Bad overscanImage: {fitType}")

    def test_overscanCorrection_isInt(self):
        """Expect a populated result Struct for each known fit type, and
        ValueError for an unknown fit type, with ``overscanIsInt=True``.
        """
        self._checkOverscanCorrection(overscanIsInt=True, label="overscanIsInt")

    def test_overscanCorrection_isNotInt(self):
        """Expect a populated result Struct for each known fit type, and
        ValueError for an unknown fit type, with ``overscanIsInt=False``.
        """
        self._checkOverscanCorrection(overscanIsInt=False, label="overscanIsNotInt")

    def test_brighterFatterCorrection(self):
        """Expect the image standard deviation to be larger after
        brighter-fatter correction than before.
        """
        bfKern = isrMock.BfKernelMock().run()

        before = computeImageMedianAndStd(self.inputExp.getImage())
        ipIsr.brighterFatterCorrection(self.inputExp, bfKern, 10, 1e-2, False)
        after = computeImageMedianAndStd(self.inputExp.getImage())

        self.assertLess(before[1], after[1])

    def test_gainContext(self):
        """Expect the masked image to be unchanged after entering and
        leaving ``gainContext`` with ``apply=True``.
        """
        mi = self.inputExp.getMaskedImage().clone()
        with ipIsr.gainContext(self.inputExp, self.inputExp.getImage(), apply=True):
            pass

        self.assertMaskedImagesAlmostEqual(self.inputExp.getMaskedImage(), mi)

    def test_widenSaturationTrails(self):
        """Expect at least as many SAT-masked pixels after widening.
        """
        numBitBefore = countMaskedPixels(self.mi, "SAT")

        ipIsr.widenSaturationTrails(self.mi.getMask())
        numBitAfter = countMaskedPixels(self.mi, "SAT")

        self.assertGreaterEqual(numBitAfter, numBitBefore)

    def test_setBadRegions(self):
        """Expect RuntimeError if an unknown statistic is given.
        Expect a non-NaN value otherwise.
        """
        for badStatistic in ('MEDIAN', 'MEANCLIP', 'UNKNOWN'):
            with self.subTest(badStatistic=badStatistic):
                if badStatistic == 'UNKNOWN':
                    with self.assertRaises(RuntimeError,
                                           msg=f"setBadRegions did not fail for stat {badStatistic}"):
                        ipIsr.setBadRegions(self.inputExp, badStatistic=badStatistic)
                else:
                    _, value = ipIsr.setBadRegions(self.inputExp, badStatistic=badStatistic)
                    # The previous check, ``abs(value) >= 0``, only ever failed
                    # for NaN; state that intent directly.
                    self.assertFalse(np.isnan(value),
                                     msg=f"setBadRegions did not find valid value for stat {badStatistic}")

    def test_attachTransmissionCurve(self):
        """Expect no failure and non-None output from attachTransmissionCurve.
        """
        curve = isrMock.TransmissionMock().run()
        combined = ipIsr.attachTransmissionCurve(self.inputExp,
                                                 opticsTransmission=curve,
                                                 filterTransmission=curve,
                                                 sensorTransmission=curve,
                                                 atmosphereTransmission=curve)
        # DM-19707: ip_isr functionality not fully tested by unit tests
        self.assertIsNotNone(combined)

    def test_attachTransmissionCurve_None(self):
        """Expect no failure and non-None output from attachTransmissionCurve
        when every transmission curve is None.
        """
        curve = None
        combined = ipIsr.attachTransmissionCurve(self.inputExp,
                                                 opticsTransmission=curve,
                                                 filterTransmission=curve,
                                                 sensorTransmission=curve,
                                                 atmosphereTransmission=curve)
        # DM-19707: ip_isr functionality not fully tested by unit tests
        self.assertIsNotNone(combined)

397 

398 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``lsst.utils.tests.MemoryTestCase``;
    no additional tests are defined here.
    """

401 

402 

def setup_module(module):
    """Initialize ``lsst.utils.tests`` before this module's tests run.

    Parameters
    ----------
    module : `module`
        Module object, presumably supplied by pytest's module-level setup
        hook; it is not used.
    """
    lsst.utils.tests.init()

405 

406 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()