Coverage for tests / test_crosstalk.py: 14%

205 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:42 +0000

1# This file is part of ip_isr. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

import itertools
import os
import tempfile
import unittest

25 

26import numpy as np 

27 

28import lsst.geom 

29import lsst.utils.tests 

30import lsst.afw.image 

31import lsst.afw.table 

32import lsst.afw.cameraGeom as cameraGeom 

33 

34from lsst.ip.isr import IsrTask, CrosstalkCalib, NullCrosstalkTask, IsrTaskLSST 

35 

# Debug-display hook.  Referencing the bare name ``display`` raises
# NameError unless something already bound it before this point; in that
# case the display machinery is switched off for the whole module.
try:
    display
except NameError:
    display = False
else:
    # ``display`` was already bound: load the display package and set the
    # default mask transparency used by the debugging frames.
    import lsst.afw.display as afwDisplay

    afwDisplay.setDefaultMaskTransparency(75)


# Specify a name (as a string) to save the output crosstalk coeffs.
outputName = None

46 

47 

class CrosstalkTestCase(lsst.utils.tests.TestCase):
    """Tests of crosstalk subtraction and `CrosstalkCalib` persistence.

    Every test builds its own fixture by calling `setUp_general`, which
    creates a synthetic four-amplifier exposure with known linear (and,
    optionally, quadratic) crosstalk added, together with the matching
    crosstalk-free image used as the reference.
    """

    # Define a new set up function to be able to pass
    # NL-crosstalk correction boolean.
    def setUp_general(self, doSqrCrosstalk=False):
        """Build the synthetic exposures and detectors used by the tests.

        Sets ``self.numAmps``, ``self.crosstalk``, ``self.crosstalk_sqr``,
        ``self.value``, ``self.crosstalkStr``, ``self.noise``,
        ``self.exposure``, ``self.ctSource`` and ``self.corrected``.

        Parameters
        ----------
        doSqrCrosstalk : `bool`, optional
            If `True`, use non-zero quadratic crosstalk coefficients when
            contaminating the amplifier images; otherwise the quadratic
            coefficients are all zero.
        """
        width, height = 250, 500
        self.numAmps = 4
        numPixelsPerAmp = 1000
        # crosstalk[i][j] is the fraction of the j-th amp present on the i-th
        # amp.
        self.crosstalk = [[0.0, 1e-4, 2e-4, 3e-4],
                          [3e-4, 0.0, 2e-4, 1e-4],
                          [4e-4, 5e-4, 0.0, 6e-4],
                          [7e-4, 8e-4, 9e-4, 0.0]]
        if doSqrCrosstalk:
            # Measured quadratic crosstalk from spots is O[-10], O[-11]
            self.crosstalk_sqr = [[0.0, 1e-10, 2e-10, 3e-10],
                                  [3e-10, 0.0, 2e-10, 1e-10],
                                  [4e-10, 5e-10, 0.0, 6e-10],
                                  [7e-10, 8e-10, 9e-10, 0.0]]
        else:
            self.crosstalk_sqr = np.zeros((self.numAmps, self.numAmps))
        self.value = 12345
        self.crosstalkStr = "XTLK"

        # A bit of noise is important, because otherwise the pixel
        # distributions are razor-thin and then rejection doesn't work.
        rng = np.random.RandomState(12345)
        self.noise = rng.normal(0.0, 0.1, (2*height, 2*width))

        # Create amp images: blank frames with randomly placed bright pixels.
        withoutCrosstalk = [lsst.afw.image.ImageF(width, height) for _ in range(self.numAmps)]
        for image in withoutCrosstalk:
            image.set(0)
            xx = rng.randint(0, width, numPixelsPerAmp)
            yy = rng.randint(0, height, numPixelsPerAmp)
            image.getArray()[yy, xx] = self.value

        # The squared source images depend only on the source amp, so build
        # them once instead of once per (target, source) pair.
        withoutCrosstalkSqr = []
        for image in withoutCrosstalk:
            imageSqr = image.clone()
            imageSqr.scaledMultiplies(1.0, image)
            withoutCrosstalkSqr.append(imageSqr)

        # Add in crosstalk.  The quadratic term is a no-op when the
        # quadratic coefficients are zero.
        withCrosstalk = [image.Factory(image, True) for image in withoutCrosstalk]
        for ii, iImage in enumerate(withCrosstalk):
            for jj, (jImage, jImageSqr) in enumerate(zip(withoutCrosstalk, withoutCrosstalkSqr)):
                iImage.scaledPlus(self.crosstalk[ii][jj], jImage)
                iImage.scaledPlus(self.crosstalk_sqr[ii][jj], jImageSqr)

        # Put amp images together
        def construct(imageList):
            """Assemble four amp images into one noisy 2x2 mosaic."""
            image = lsst.afw.image.ImageF(2*width, 2*height)
            image.getArray()[:height, :width] = imageList[0].getArray()
            image.getArray()[:height, width:] = imageList[1].getArray()[:, ::-1]  # flip in x
            image.getArray()[height:, :width] = imageList[2].getArray()[::-1, :]  # flip in y
            image.getArray()[height:, width:] = imageList[3].getArray()[::-1, ::-1]  # flip in x and y
            image.getArray()[:] += self.noise
            return image

        # Construct two identically configured detectors in one fake camera;
        # the second one is the source for the inter-chip test.
        camBuilder = cameraGeom.Camera.Builder("fakeCam")
        detBuilders = []
        for detId in (1, 2):
            detBuilder = camBuilder.add("detector %d" % detId, detId)
            detBuilder.setSerial("serial %d" % detId)
            detBuilder.setBBox(lsst.geom.Box2I(lsst.geom.Point2I(0, 0),
                                               lsst.geom.Extent2I(2*width, 2*height)))
            detBuilder.setOrientation(cameraGeom.Orientation())
            detBuilder.setPixelSize(lsst.geom.Extent2D(1, 1))
            detBuilder.setCrosstalk(np.array(self.crosstalk, dtype=np.float32))
            detBuilders.append(detBuilder)

        # Create amp info
        ampLayout = [(0, 0, lsst.afw.cameraGeom.ReadoutCorner.LL),
                     (width, 0, lsst.afw.cameraGeom.ReadoutCorner.LR),
                     (0, height, lsst.afw.cameraGeom.ReadoutCorner.UL),
                     (width, height, lsst.afw.cameraGeom.ReadoutCorner.UR)]
        for ii, (xx, yy, corner) in enumerate(ampLayout):
            amp = cameraGeom.Amplifier.Builder()
            amp.setName("amp %d" % ii)
            amp.setBBox(lsst.geom.Box2I(lsst.geom.Point2I(xx, yy),
                                        lsst.geom.Extent2I(width, height)))
            amp.setRawDataBBox(lsst.geom.Box2I(lsst.geom.Point2I(xx, yy),
                                               lsst.geom.Extent2I(width, height)))
            amp.setReadoutCorner(corner)
            for detBuilder in detBuilders:
                detBuilder.append(amp)

        cam = camBuilder.finish()
        ccd1 = cam.get('detector 1')
        ccd2 = cam.get('detector 2')

        self.exposure = lsst.afw.image.makeExposure(lsst.afw.image.makeMaskedImage(construct(withCrosstalk)))
        self.exposure.setDetector(ccd1)

        # Add a saturated region so we can confirm it doesn't get duplicated.
        saturatedBit = self.exposure.mask.getPlaneBitMask("SAT")
        self.exposure.mask.array[0:5, 0:5] |= saturatedBit

        # Create a single ctSource that will be used for interChip CT
        # correction.
        self.ctSource = lsst.afw.image.makeExposure(lsst.afw.image.makeMaskedImage(construct(withCrosstalk)))
        self.ctSource.setDetector(ccd2)

        self.corrected = construct(withoutCrosstalk)

        if display:
            disp = lsst.afw.display.Display(frame=1)
            disp.mtv(self.exposure, title="exposure")
            disp = lsst.afw.display.Display(frame=0)
            disp.mtv(self.corrected, title="corrected exposure")

    def tearDown(self):
        """Release the images created by `setUp_general`.

        Attributes that were never created (for example because fixture
        construction failed part-way) are skipped, so the original error is
        not hidden behind an `AttributeError` raised here.
        """
        for name in ("exposure", "ctSource", "corrected"):
            if hasattr(self, name):
                delattr(self, name)

    def checkCoefficients(self, coeff, coeffErr, coeffNum):
        """Check that coefficients are as expected

        Parameters
        ----------
        coeff : `numpy.ndarray`
            Crosstalk coefficients.
        coeffErr : `numpy.ndarray`
            Crosstalk coefficient errors.
        coeffNum : `numpy.ndarray`
            Number of pixels to produce each coefficient.
        """
        for matrix in (coeff, coeffErr, coeffNum):
            self.assertEqual(matrix.shape, (self.numAmps, self.numAmps))
        self.assertFloatsAlmostEqual(coeff, np.array(self.crosstalk), atol=1.0e-6)

        for ii in range(self.numAmps):
            self.assertEqual(coeff[ii, ii], 0.0)
            self.assertTrue(np.isnan(coeffErr[ii, ii]))
            self.assertEqual(coeffNum[ii, ii], 1)

        # Select the off-diagonal elements with a boolean mask.  Passing a
        # generator to np.all() always evaluates as true, so it checks
        # nothing.
        offDiagonal = ~np.eye(self.numAmps, dtype=bool)
        self.assertTrue(np.all(coeffErr[offDiagonal] > 0))
        self.assertTrue(np.all(coeffNum[offDiagonal] > 0))

    def checkSubtracted(self, exposure):
        """Check that the subtracted image is as expected

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Crosstalk-subtracted exposure.
        """
        image = exposure.getMaskedImage().getImage()
        mask = exposure.getMaskedImage().getMask()
        self.assertFloatsAlmostEqual(image.getArray(), self.corrected.getArray(), atol=2.0e-2)
        self.assertIn(self.crosstalkStr, mask.getMaskPlaneDict())

        # Some pixels must have been flagged with the crosstalk plane ...
        crosstalkBit = mask.getPlaneBitMask(self.crosstalkStr)
        self.assertGreater(((mask.getArray() & crosstalkBit) > 0).sum(), 0)
        # ... while the 5x5 saturated patch must not have been duplicated.
        saturatedBit = mask.getPlaneBitMask("SAT")
        self.assertEqual(((mask.getArray() & saturatedBit) > 0).sum(), 25)

    def checkTaskAPI_NL(self, this_isr_task, doSubtrahendMasking=False, ignoreVariance=False):
        """Check the crosstalk task under different ISR tasks.
        (e.g., IsrTask and IsrTaskLSST)

        The fixture is rebuilt with quadratic crosstalk enabled on every
        call, so this method can be called repeatedly from a single test.

        Parameters
        ----------
        this_isr_task : `lsst.pipe.base.PipelineTask`
            The ISR Task class to use; it is instantiated here with a
            freshly built configuration.
        doSubtrahendMasking : `bool`, optional
            Enable subtrahend masking code.
        ignoreVariance : `bool`, optional
            Use optimized code to ignore variance.
        """
        self.setUp_general(doSqrCrosstalk=True)
        coeff = np.array(self.crosstalk).transpose()
        coeffSqr = np.array(self.crosstalk_sqr).transpose()
        config = this_isr_task.ConfigClass()
        config.crosstalk.minPixelToMask = self.value - 1
        config.crosstalk.crosstalkMaskPlane = self.crosstalkStr
        if doSubtrahendMasking:
            config.crosstalk.doSubtrahendMasking = True
            config.crosstalk.minPixelToMask = 1.0
        # Turn on the NL correction
        config.crosstalk.doQuadraticCrosstalkCorrection = True
        isr = this_isr_task(config=config)
        calib = CrosstalkCalib().fromDetector(self.exposure.getDetector(),
                                              coeffVector=coeff,
                                              coeffSqrVector=coeffSqr)
        isr.crosstalk.run(self.exposure, crosstalk=calib, ignoreVariance=ignoreVariance)
        self.checkSubtracted(self.exposure)

    def testDirectAPI(self):
        """Test that individual function calls work"""
        self.setUp_general()
        calib = CrosstalkCalib()
        calib.coeffs = np.array(self.crosstalk).transpose()
        calib.subtractCrosstalk(self.exposure, crosstalkCoeffs=calib.coeffs,
                                minPixelToMask=self.value - 1,
                                crosstalkStr=self.crosstalkStr)
        self.checkSubtracted(self.exposure)

        if outputName is None:
            # Nothing was asked to be kept: write into a scratch directory
            # that is removed afterwards rather than leaking a mktemp() path.
            with tempfile.TemporaryDirectory() as tmpDir:
                calib.writeText(os.path.join(tmpDir, "isrCrosstalk.yaml"))
        else:
            calib.writeText("{}-isrCrosstalk".format(outputName) + '.yaml')

    def testTaskAPI(self):
        """Test that the crosstalk subtask of IsrTask works.

        Runs the linear correction through ``IsrTask.crosstalk`` with a
        calibration built from the detector and the known coefficients.
        """
        self.setUp_general()
        coeff = np.array(self.crosstalk).transpose()
        coeffSqr = np.array(self.crosstalk_sqr).transpose()
        config = IsrTask.ConfigClass()
        config.crosstalk.minPixelToMask = self.value - 1
        config.crosstalk.crosstalkMaskPlane = self.crosstalkStr
        isr = IsrTask(config=config)
        calib = CrosstalkCalib().fromDetector(self.exposure.getDetector(),
                                              coeffVector=coeff,
                                              coeffSqrVector=coeffSqr)
        isr.crosstalk.run(self.exposure, crosstalk=calib)
        self.checkSubtracted(self.exposure)

    def testTaskAPI_NL(self):
        """Test the quadratic (non-linear) crosstalk correction.

        Every combination of ISR task class, subtrahend masking and
        variance handling is run as its own sub-test so that a failure
        reports which configuration broke.
        """
        combinations = itertools.product((IsrTask, IsrTaskLSST), (False, True), (False, True))
        for this_isr_task, subtrahendMasking, ignoreVariance in combinations:
            with self.subTest(task=this_isr_task.__name__,
                              doSubtrahendMasking=subtrahendMasking,
                              ignoreVariance=ignoreVariance):
                self.checkTaskAPI_NL(this_isr_task, subtrahendMasking, ignoreVariance)

    def test_nullCrosstalkTask(self):
        """Test that the null crosstalk task does not create an error.
        """
        self.setUp_general()
        task = NullCrosstalkTask()
        result = task.run(self.exposure, crosstalkSources=None)
        self.assertIsNone(result)

    def test_interChip(self):
        """Test that passing an external exposure as the crosstalk source
        works.
        """
        self.setUp_general()
        exposure = self.exposure
        ctSources = [self.ctSource]

        coeff = np.array(self.crosstalk).transpose()
        calib = CrosstalkCalib().fromDetector(exposure.getDetector(), coeffVector=coeff)
        # Now convert this into zero intra-chip, full inter-chip:
        calib.interChip['detector 2'] = coeff
        calib.coeffs = np.zeros_like(coeff)

        # Process and check as above
        config = IsrTask.ConfigClass()
        config.crosstalk.minPixelToMask = self.value - 1
        config.crosstalk.crosstalkMaskPlane = self.crosstalkStr
        isr = IsrTask(config=config)
        isr.crosstalk.run(exposure, crosstalk=calib, crosstalkSources=ctSources)
        self.checkSubtracted(exposure)

    def test_crosstalkIO(self):
        """Test that crosstalk doesn't change on being converted to persistable
        formats.
        """
        self.setUp_general()
        exposure = self.exposure

        coeff = np.array(self.crosstalk).transpose()
        calib = CrosstalkCalib().fromDetector(exposure.getDetector(), coeffVector=coeff)
        # Add inter-chip coefficients as well, so that both the intra-chip
        # and inter-chip parts of the calibration are round-tripped.
        calib.interChip['detector 2'] = coeff

        # Write into a scratch directory that is removed afterwards instead
        # of leaking files at mktemp() paths.
        with tempfile.TemporaryDirectory() as tmpDir:
            outPath = os.path.join(tmpDir, "crosstalk.yaml")
            calib.writeText(outPath)
            newCrosstalk = CrosstalkCalib().readText(outPath)
            self.assertEqual(calib, newCrosstalk)

            outPath = os.path.join(tmpDir, "crosstalk.fits")
            calib.writeFits(outPath)
            newCrosstalk = CrosstalkCalib().readFits(outPath)
            self.assertEqual(calib, newCrosstalk)

355 

356 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No behaviour is added or overridden here.
    """

359 

360 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; accepted for the caller's benefit but not
        used here.
    """
    lsst.utils.tests.init()

363 

364 

# Direct execution: perform the module-level setup by hand, then hand
# control to the unittest runner.
if __name__ == "__main__":
    import sys

    setup_module(sys.modules[__name__])
    unittest.main()