Coverage for tests/test_register.py: 18%

130 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-21 11:31 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25 

26import lsst.utils.tests 

27import lsst.geom as geom 

28import lsst.afw.image as afwImage 

29import lsst.afw.table as afwTable 

30import lsst.afw.geom as afwGeom 

31from lsst.pipe.base import Struct 

32from lsst.pipe.tasks.registerImage import RegisterConfig, RegisterTask 

33 

# Optional interactive image display for debugging.
# NOTE(review): presumably ``display`` is pre-defined by whoever runs this file
# interactively (e.g. an IPython session); if no such name exists, display is
# disabled and ``showImage`` becomes a no-op.
try:
    display  # probe only: raises NameError when the name is undefined
except NameError:
    display = False
else:
    # The name exists (whatever its value), so prepare afwDisplay for showImage.
    import lsst.afw.display as afwDisplay
    afwDisplay.setDefaultMaskTransparency(75)

41 

42 

43class RegisterTestCase(unittest.TestCase): 

44 

45 """A test case for RegisterTask.""" 

46 

47 def setUp(self): 

48 self.dx = -5 

49 self.dy = +3 

50 self.numSources = 123 

51 self.border = 10 # Must be larger than dx,dy 

52 self.width = 1000 

53 self.height = 1000 

54 self.pixelScale = 0.1 * geom.arcseconds # So dx,dy is not larger than RegisterConfig.matchRadius 

55 

56 def tearDown(self): 

57 del self.pixelScale 

58 

59 def create(self): 

60 """Create test images and sources 

61 

62 We will create two fake images with some 'sources', which are just single bright pixels. 

63 The images will have the same sources with a constant offset between them. The WCSes 

64 of the two images are identical, despite the offset; this simulates a small e.g., pointing 

65 error, or misalignment that the RegisterTask should rectify. 

66 """ 

67 np.random.seed(0) 

68 templateImage = afwImage.MaskedImageF(self.width, self.height) 

69 templateImage.set(0) 

70 inputImage = afwImage.MaskedImageF(self.width, self.height) 

71 inputImage.set(0) 

72 

73 templateArray = templateImage.getImage().getArray() 

74 inputArray = inputImage.getImage().getArray() 

75 

76 # Sources are at integer positions to ensure warped pixels have value of unity 

77 xTemplate = np.random.randint(self.border, self.width - self.border, self.numSources) 

78 yTemplate = np.random.randint(self.border, self.width - self.border, self.numSources) 

79 xInput = xTemplate + self.dx 

80 yInput = yTemplate + self.dy 

81 

82 # Note: numpy indices are backwards: [y,x] 

83 templateArray[(yTemplate).astype(int), (xTemplate).astype(int)] = 1 

84 inputArray[(yInput).astype(int), (xInput).astype(int)] = 1 

85 

86 # Create WCSes 

87 centerCoord = geom.SpherePoint(0, 0, geom.degrees) 

88 centerPixel = geom.Point2D(self.width/2, self.height/2) 

89 cdMatrix = afwGeom.makeCdMatrix(scale=self.pixelScale) 

90 wcs = afwGeom.makeSkyWcs(crpix=centerPixel, crval=centerCoord, cdMatrix=cdMatrix) 

91 

92 # Note that one of the WCSes must be "wrong", since they are the same, but the sources are offset. 

93 # It is the job of the RegisterTask to align the images, despite the "wrong" WCS. 

94 templateExp = afwImage.makeExposure(templateImage, wcs) 

95 inputExp = afwImage.makeExposure(inputImage, wcs) 

96 

97 # Generate catalogues 

98 schema = afwTable.SourceTable.makeMinimalSchema() 

99 centroidKey = afwTable.Point2DKey.addFields(schema, "center", "center", "pixel") 

100 

101 def newCatalog(): 

102 catalog = afwTable.SourceCatalog(schema) 

103 catalog.getTable().defineCentroid("center") 

104 return catalog 

105 

106 templateSources = newCatalog() 

107 inputSources = newCatalog() 

108 

109 coordKey = templateSources.getCoordKey() 

110 for xt, yt, xi, yi in zip(xTemplate, yTemplate, xInput, yInput): 

111 tRecord = templateSources.addNew() 

112 iRecord = inputSources.addNew() 

113 

114 tPoint = geom.Point2D(float(xt), float(yt)) 

115 iPoint = geom.Point2D(float(xi), float(yi)) 

116 

117 tRecord.set(centroidKey, tPoint) 

118 iRecord.set(centroidKey, iPoint) 

119 tRecord.set(coordKey, wcs.pixelToSky(tPoint)) 

120 iRecord.set(coordKey, wcs.pixelToSky(iPoint)) 

121 

122 self.showImage(inputExp, inputSources, "Input", 1) 

123 self.showImage(templateExp, templateSources, "Template", 2) 

124 

125 return Struct(xInput=xInput, yInput=yInput, xTemplate=xTemplate, yTemplate=yTemplate, wcs=wcs, 

126 inputExp=inputExp, inputSources=inputSources, 

127 templateExp=templateExp, templateSources=templateSources) 

128 

129 def runTask(self, inData, config=RegisterConfig()): 

130 """Run the task on the data""" 

131 config.sipOrder = 2 

132 task = RegisterTask(name="register", config=config) 

133 results = task.run(inData.inputSources, inData.inputExp.getWcs(), 

134 inData.inputExp.getBBox(afwImage.LOCAL), inData.templateSources) 

135 warpedExp = task.warpExposure(inData.inputExp, results.wcs, inData.templateExp.getWcs(), 

136 inData.templateExp.getBBox(afwImage.LOCAL)) 

137 warpedSources = task.warpSources(inData.inputSources, results.wcs, inData.templateExp.getWcs(), 

138 inData.templateExp.getBBox(afwImage.LOCAL)) 

139 

140 self.showImage(warpedExp, warpedSources, "Aligned", 3) 

141 return Struct(warpedExp=warpedExp, warpedSources=warpedSources, matches=results.matches, 

142 wcs=results.wcs, task=task) 

143 

144 def assertRegistered(self, inData, outData, bad=set()): 

145 """Assert that the registration task is registering images""" 

146 xTemplate = np.array([x for i, x in enumerate(inData.xTemplate) if i not in bad]) 

147 yTemplate = np.array([y for i, y in enumerate(inData.yTemplate) if i not in bad]) 

148 alignedArray = outData.warpedExp.getMaskedImage().getImage().getArray() 

149 self.assertTrue((alignedArray[yTemplate, xTemplate] == 1.0).all()) 

150 for dx in (-1, 0, +1): 

151 for dy in range(-1, 0, +1): 

152 # The density of points is such that I can assume that no point is next to another. 

153 # The values are not quite zero because the "image" is undersampled, so we get ringing. 

154 self.assertTrue((alignedArray[yTemplate+dy, xTemplate+dx] < 0.1).all()) 

155 

156 xAligned = np.array([x for i, x in enumerate(outData.warpedSources["center_x"]) if i not in bad]) 

157 yAligned = np.array([y for i, y in enumerate(outData.warpedSources["center_y"]) if i not in bad]) 

158 self.assertAlmostEqual((xAligned - xTemplate).mean(), 0, 8) 

159 self.assertAlmostEqual((xAligned - xTemplate).std(), 0, 8) 

160 self.assertAlmostEqual((yAligned - yTemplate).mean(), 0, 8) 

161 self.assertAlmostEqual((yAligned - yTemplate).std(), 0, 8) 

162 

163 def assertMetadata(self, outData, numRejected=0): 

164 """Assert that the registration task is populating the metadata""" 

165 metadata = outData.task.metadata 

166 self.assertEqual(metadata.getScalar("MATCH_NUM"), self.numSources) 

167 self.assertAlmostEqual(metadata.getScalar("SIP_RMS"), 0.0) 

168 self.assertEqual(metadata.getScalar("SIP_GOOD"), self.numSources-numRejected) 

169 self.assertEqual(metadata.getScalar("SIP_REJECTED"), numRejected) 

170 

171 def testRegister(self): 

172 """Test image registration""" 

173 inData = self.create() 

174 outData = self.runTask(inData) 

175 self.assertRegistered(inData, outData) 

176 self.assertMetadata(outData) 

177 

178 def testRejection(self): 

179 """Test image registration with rejection""" 

180 inData = self.create() 

181 

182 # Tweak a source to have a bad offset 

183 badIndex = 111 

184 

185 coordKey = inData.inputSources[badIndex].getTable().getCoordKey() 

186 centroidKey = inData.inputSources[badIndex].getTable().getCentroidSlot().getMeasKey() 

187 x, y = float(inData.xInput[badIndex] + 0.01), float(inData.yInput[badIndex] - 0.01) 

188 point = geom.Point2D(x, y) 

189 inData.inputSources[badIndex].set(centroidKey, point) 

190 inData.inputSources[badIndex].set(coordKey, inData.wcs.pixelToSky(point)) 

191 

192 config = RegisterConfig() 

193 config.sipRej = 10.0 

194 

195 outData = self.runTask(inData) 

196 self.assertRegistered(inData, outData, bad=set([badIndex])) 

197 self.assertMetadata(outData, numRejected=1) 

198 

199 def showImage(self, image, sources, title, frame): 

200 """Display an image 

201 

202 Images are only displayed if 'display' is turned on. 

203 

204 @param image: Image to display 

205 @param sources: Sources to mark on the display 

206 @param title: Title to give frame 

207 @param frame: Frame on which to display 

208 """ 

209 if not display: 

210 return 

211 disp = afwDisplay.Display(frame=frame) 

212 disp.mtv(image, title=title) 

213 with disp.Buffering(): 

214 for s in sources: 

215 center = s.getCentroid() 

216 disp.dot("o", center.getX(), center.getY()) 

217 

218 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from ``lsst.utils.tests.MemoryTestCase``; adds nothing of its own."""

221 

222 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up (not used).
    """
    lsst.utils.tests.init()

225 

226 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities, then hand over to unittest.
    lsst.utils.tests.init()
    unittest.main()