Coverage for tests/test_matchFakes.py: 30%

82 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-17 08:57 +0000

#
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

23 

24import numpy as np 

25import pandas as pd 

26import shutil 

27import tempfile 

28import unittest 

29import uuid 

30 

31import lsst.daf.butler.tests as butlerTests 

32import lsst.sphgeom as sphgeom 

33import lsst.geom as geom 

34import lsst.meas.base.tests as measTests 

35from lsst.pipe.base import testUtils 

36import lsst.skymap as skyMap 

37import lsst.utils.tests 

38 

39from lsst.pipe.tasks.matchFakes import MatchFakesTask, MatchFakesConfig 

40 

41 

class TestMatchFakes(lsst.utils.tests.TestCase):
    """Tests for `MatchFakesTask` using a synthetic exposure and catalogs."""

    def setUp(self):
        """Create fake data to use in the tests.

        Builds:

        - ``self.exposure``: a test exposure covering ``self.bbox``.
        - ``self.fakeCat``: ``targetSources`` fake sources with ra/dec (radians)
          drawn uniformly in a box of half-width ``bRadius`` around the centre
          of the tract's bounding circle. The first half of the rows are
          flagged ``isVisitSource`` and the remainder ``isTemplateSource``.
        - ``self.inExp``: boolean mask of fakes whose pixel position falls
          inside the exposure's bounding box.
        - ``self.sourceCat``: a DIA source catalog made from the first half of
          the in-exposure fakes (ra/dec in degrees). It has an extra random
          integer column and a (diaObjectId, filterName, extraColumn) index.
        """
        self.bbox = geom.Box2I(geom.Point2I(0, 0),
                               geom.Extent2I(1024, 1153))
        dataset = measTests.TestDataset(self.bbox)
        self.exposure = dataset.exposure

        skyOrigin = dataset.exposure.getWcs().getSkyOrigin()
        simpleMapConfig = skyMap.discreteSkyMap.DiscreteSkyMapConfig()
        simpleMapConfig.raList = [skyOrigin.getRa().asDegrees()]
        simpleMapConfig.decList = [skyOrigin.getDec().asDegrees()]
        simpleMapConfig.radiusList = [0.1]

        self.simpleMap = skyMap.DiscreteSkyMap(simpleMapConfig)
        self.tractId = 0
        bCircle = self.simpleMap.generateTract(self.tractId).getInnerSkyPolygon().getBoundingCircle()
        bCenter = sphgeom.LonLat(bCircle.getCenter())
        bRadius = bCircle.getOpeningAngle().asRadians()
        targetSources = 10000
        self.sourceDensity = (targetSources
                              / (bCircle.getArea() * (180 / np.pi) ** 2))
        self.rng = np.random.default_rng(1234)

        nVisit = targetSources // 2
        nTemplate = targetSources - nVisit
        # An ordered tuple, not a set: set iteration order for strings depends
        # on the per-process hash seed. With a set, the seeded RNG draws would
        # land in different band columns from run to run.
        bands = ("u", "g", "r", "i", "z", "y")
        unitColumns = ("DiskHalfLightRadius", "BulgeHalfLightRadius",
                       "disk_n", "bulge_n", "a_d", "a_b", "b_d", "b_b",
                       "pa_disk", "pa_bulge")

        self.fakeCat = pd.DataFrame({
            # Random 64-bit ids; these are not drawn from the seeded RNG.
            "fakeId": [uuid.uuid4().int & (1 << 64) - 1 for _ in range(targetSources)],
            # Quick-and-dirty values for testing
            "raJ2000": bCenter.getLon().asRadians() + bRadius * (2.0 * self.rng.random(targetSources) - 1.0),
            "decJ2000": bCenter.getLat().asRadians() + bRadius * (2.0 * self.rng.random(targetSources) - 1.0),
            "isVisitSource": np.concatenate([np.ones(nVisit, dtype="bool"),
                                             np.zeros(nTemplate, dtype="bool")]),
            "isTemplateSource": np.concatenate([np.zeros(nVisit, dtype="bool"),
                                                np.ones(nTemplate, dtype="bool")]),
            **{band: self.rng.uniform(20, 30, size=targetSources) for band in bands},
            **{column: np.ones(targetSources, dtype="float") for column in unitColumns},
            "sourceType": targetSources * ["star"],
        })

        # Flag fakes whose sky position maps inside the exposure. Zipping the
        # two columns avoids building a Series per row as iterrows() does.
        wcs = self.exposure.getWcs()
        bbox = geom.Box2D(self.exposure.getBBox())
        self.inExp = np.array(
            [bbox.contains(wcs.skyToPixel(geom.SpherePoint(ra, dec, geom.radians)))
             for ra, dec in zip(self.fakeCat["raJ2000"], self.fakeCat["decJ2000"])],
            dtype=bool)

        tmpCat = self.fakeCat[self.inExp].iloc[:int(self.inExp.sum()) // 2]
        extraColumnData = self.rng.integers(0, 100, size=len(tmpCat))
        self.sourceCat = pd.DataFrame(
            data={"ra": np.degrees(tmpCat["raJ2000"]),
                  "decl": np.degrees(tmpCat["decJ2000"]),
                  "diaObjectId": np.arange(1, len(tmpCat) + 1, dtype=int),
                  "filterName": "g",
                  "diaSourceId": np.arange(1, len(tmpCat) + 1, dtype=int),
                  "extraColumn": extraColumnData})
        self.sourceCat.set_index(["diaObjectId", "filterName", "extraColumn"],
                                 drop=False,
                                 inplace=True)

    def testRunQuantum(self):
        """Test the run quantum method with a gen3 butler.

        Registers the task's input/output dataset types in a temporary test
        repository and stores the fake inputs. It then checks that
        ``runQuantum`` calls ``run`` exactly once.
        """
        root = tempfile.mkdtemp()
        # Register removal immediately so the temporary repo is deleted even
        # if a butler call or assertion below fails.
        self.addCleanup(shutil.rmtree, root, ignore_errors=True)

        dimensions = {"instrument": ["notACam"],
                      "skymap": ["deepCoadd_skyMap"],
                      "tract": [0, 42],
                      "visit": [1234, 4321],
                      "detector": [25, 26]}
        testRepo = butlerTests.makeTestRepo(root, dimensions)
        matchTask = MatchFakesTask()
        connections = matchTask.config.ConnectionsClass(
            config=matchTask.config)

        fakesDataId = {"skymap": "deepCoadd_skyMap",
                       "tract": 0}
        imgDataId = {"instrument": "notACam",
                     "visit": 1234,
                     "detector": 25}

        for connection in (connections.fakeCat,
                           connections.diffIm,
                           connections.associatedDiaSources,
                           connections.matchedDiaSources):
            butlerTests.addDatasetType(
                testRepo,
                connection.name,
                connection.dimensions,
                connection.storageClass)
        butler = butlerTests.makeTestCollection(testRepo)

        butler.put(self.fakeCat, connections.fakeCat.name, dict(fakesDataId))
        butler.put(self.exposure, connections.diffIm.name, dict(imgDataId))
        butler.put(self.sourceCat, connections.associatedDiaSources.name, dict(imgDataId))

        quantumDataId = {**imgDataId, **fakesDataId}
        quantum = testUtils.makeQuantum(
            matchTask, butler, quantumDataId,
            {"fakeCat": fakesDataId,
             "diffIm": imgDataId,
             "associatedDiaSources": imgDataId,
             "matchedDiaSources": imgDataId})
        run = testUtils.runTestQuantum(matchTask, butler, quantum)
        # Actual input dataset omitted for simplicity
        run.assert_called_once()

    def testRun(self):
        """Test the run method.

        Checks two things:

        - Every fake inside the exposure appears in ``matchedDiaSources``.
        - The number of rows with a finite ``extraColumn`` equals the number
          of input DIA sources.
        """
        matchFakesConfig = MatchFakesConfig()
        matchFakesConfig.matchDistanceArcseconds = 0.1
        matchFakes = MatchFakesTask(config=matchFakesConfig)
        result = matchFakes.run(self.fakeCat,
                                self.exposure,
                                self.sourceCat)
        self.assertEqual(self.inExp.sum(), len(result.matchedDiaSources))
        self.assertEqual(
            len(self.sourceCat),
            np.sum(np.isfinite(result.matchedDiaSources["extraColumn"])))

    def testTrimCat(self):
        """Test that the correct number of sources are in the ccd area.
        """
        matchTask = MatchFakesTask()
        result = matchTask._trimFakeCat(self.fakeCat, self.exposure)
        self.assertEqual(len(result), self.inExp.sum())

202 

203 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Resource checks for this module.

    All test methods are inherited unchanged from
    ``lsst.utils.tests.MemoryTestCase``.
    """

206 

207 

def setup_module(module):
    """Initialise LSST test utilities before this module's tests run.

    Parameters
    ----------
    module : module
        The test module, supplied by pytest. It is not used here.
    """
    lsst.utils.tests.init()

210 

211 

if __name__ == "__main__":
    # Direct ``python <file>`` runs bypass pytest's setup_module hook, so do
    # the same initialisation here before starting unittest.
    lsst.utils.tests.init()
    unittest.main()