Coverage for tests/test_diaCalculation.py: 41%

99 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-04 02:37 -0700

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import pandas as pd 

24import unittest 

25 

26from lsst.meas.base import ( 

27 DiaObjectCalculationTask, 

28 DiaObjectCalculationConfig, 

29 DiaObjectCalculationPlugin) 

30from lsst.meas.base.pluginRegistry import register 

31import lsst.utils.tests 

32 

33 

@register("testCount")
class CountDiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin that counts the DiaSources associated with a DiaObject.

    The count is taken from ``diaSources`` (not ``filterDiaSources``) and is
    written to the ``count`` column of ``diaObjects``.
    """
    outputCols = ["count"]

    @classmethod
    def getExecutionOrder(cls):
        return cls.DEFAULT_CATALOGCALCULATION

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Store the number of DiaSources for one DiaObject.

        Parameters
        ----------
        diaObjects : `pandas.DataFrame`
            Summary objects; the ``count`` cell for ``diaObjectId`` is
            written in place.
        diaObjectId : `int`
            Row label in ``diaObjects`` to update.
        diaSources : `pandas.DataFrame`
            DiaSources for this object; the length of its ``psfFlux``
            column is the stored count.
        filterDiaSources : `pandas.DataFrame`
            Unused by this plugin.
        band : `str`
            Unused by this plugin.
        **kwargs
            Unused by this plugin.
        """
        diaObjects.at[diaObjectId, "count"] = len(diaSources["psfFlux"])

54 

55 

@register("testDiaPlugin")
class DiaPlugin(DiaObjectCalculationPlugin):
    """Test "multi" plugin computing the mean and standard deviation of
    ``psfFlux`` for every DiaObject in one call.
    """
    outputCols = ["MeanFlux", "StdFlux"]

    plugType = "multi"

    @classmethod
    def getExecutionOrder(cls):
        return cls.DEFAULT_CATALOGCALCULATION

    def calculate(self,
                  diaObjects,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Fill ``<band>MeanFlux`` and ``<band>StdFlux`` in ``diaObjects``.

        Parameters
        ----------
        diaObjects : `pandas.DataFrame`
            Summary objects; both output columns are assigned via ``.loc``.
        diaSources
            Unused by this plugin.
        filterDiaSources
            Sources for ``band``; its ``psfFlux`` attribute is aggregated
            with ``agg("mean")`` and ``agg("std")`` (presumably a pandas
            groupby over diaObjectId, so results align by index -- confirm
            against DiaObjectCalculationTask).
        band : `str`
            Prefix for the output column names.
        **kwargs
            Unused by this plugin.
        """
        meanColumn = f"{band}MeanFlux"
        stdColumn = f"{band}StdFlux"
        fluxes = filterDiaSources.psfFlux
        diaObjects.loc[:, meanColumn] = fluxes.agg("mean")
        diaObjects.loc[:, stdColumn] = fluxes.agg("std")

80 

81 

@register("testDependentDiaPlugin")
class DependentDiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin that consumes ``MeanFlux`` (produced by another plugin)
    to compute a chi-square-like statistic of ``psfFlux``.
    """
    inputCols = ["MeanFlux"]
    outputCols = ["ChiFlux"]

    @classmethod
    def getExecutionOrder(cls):
        return cls.FLUX_MOMENTS_CALCULATED

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Store sum(((psfFlux - mean) / psfFluxErr) ** 2) in
        ``<band>ChiFlux`` for one DiaObject.

        The mean is read back from the ``<band>MeanFlux`` cell of the same
        row, so this must run after the plugin that writes it.
        """
        meanFlux = diaObjects.at[diaObjectId, f"{band}MeanFlux"]
        residuals = filterDiaSources["psfFlux"] - meanFlux
        normalized = residuals / filterDiaSources["psfFluxErr"]
        diaObjects.at[diaObjectId, f"{band}ChiFlux"] = np.sum(normalized ** 2)

104 

105 

@register("testCollidingDiaPlugin")
class CollidingDiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin whose ``MeanFlux`` output duplicates ``DiaPlugin``'s.

    It exists so testConflictingPlugins can configure two plugins writing
    the same output column and check that the task raises ValueError.
    Its calculation simply overwrites the mean with 0.0.
    """
    outputCols = ["MeanFlux"]

    @classmethod
    def getExecutionOrder(cls):
        return cls.FLUX_MOMENTS_CALCULATED

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Set ``<band>MeanFlux`` for ``diaObjectId`` to 0.0.

        All other arguments are unused.
        """
        diaObjects.at[diaObjectId, "%sMeanFlux" % band] = 0.0

124 

125 

class TestDiaCalcluation(unittest.TestCase):
    """Exercise DiaObjectCalculationTask using the test plugins registered
    in this module.
    """

    def setUp(self):
        # Create diaObjects: ids 0-5 plus one object that only has a newly
        # associated DiaSource.
        self.newDiaObjectId = 13
        self.diaObjects = pd.DataFrame(
            data=[{"diaObjectId": objId}
                  for objId in [0, 1, 2, 3, 4, 5, self.newDiaObjectId]])

        # Create diaSources from "previous runs" and newly created ones.
        # Previous runs: one g-band and one r-band source with psfFlux=0
        # for each of objects 0-4.
        diaSources = [{"diaSourceId": objId, "diaObjectId": objId,
                       "psfFlux": 0., "psfFluxErr": 1.,
                       "scienceFlux": 0., "scienceFluxErr": 1.,
                       "midpointMjdTai": 0, "band": "g"}
                      for objId in range(5)]
        diaSources.extend([{"diaSourceId": 5 + objId, "diaObjectId": objId,
                            "psfFlux": 0., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 1.,
                            "midpointMjdTai": 0, "band": "r"}
                           for objId in range(5)])
        # New g-band sources: objects 0 and 1 get psfFlux=1, object 2 gets
        # a NaN flux, and the new object gets its first source.
        diaSources.extend([{"diaSourceId": 10, "diaObjectId": 0,
                            "psfFlux": 1., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"},
                           {"diaSourceId": 11, "diaObjectId": 1,
                            "psfFlux": 1., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"},
                           {"diaSourceId": 12, "diaObjectId": 2,
                            "psfFlux": np.nan, "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"},
                           {"diaSourceId": self.newDiaObjectId,
                            "diaObjectId": self.newDiaObjectId,
                            "psfFlux": 1., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"}])
        self.diaSources = pd.DataFrame(data=diaSources)

        self.updatedDiaObjectIds = np.array([0, 1, 2, self.newDiaObjectId],
                                            dtype=np.int64)

        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDiaPlugin",
                        "testDependentDiaPlugin"]
        self.diaObjCalTask = DiaObjectCalculationTask(config=conf)

    def testRun(self):
        """Test the run method and that diaObjects are updated correctly.
        """
        results = self.diaObjCalTask.run(self.diaObjects,
                                         self.diaSources,
                                         self.updatedDiaObjectIds,
                                         ["g"])
        diaObjectCat = results.diaObjectCat
        updatedDiaObjects = results.updatedDiaObjects
        updatedDiaObjects.set_index("diaObjectId", inplace=True)
        # Test the lengths of the output dataframes.
        self.assertEqual(len(diaObjectCat), len(self.diaObjects))
        self.assertEqual(len(updatedDiaObjects),
                         len(self.updatedDiaObjectIds))

        # Test the values computed and stored by the task. The expectations
        # assume NaN fluxes are skipped and the std is the sample (ddof=1)
        # std, so a single valid flux gives a NaN std.
        for objId, diaObject in updatedDiaObjects.iterrows():
            with self.subTest(diaObjectId=objId):
                if objId == self.newDiaObjectId:
                    # One g-band flux of 1.
                    self.assertEqual(diaObject["gMeanFlux"], 1.)
                    self.assertTrue(np.isnan(diaObject["gStdFlux"]))
                    self.assertAlmostEqual(diaObject["gChiFlux"], 0.0)
                elif objId == 2:
                    # g-band fluxes 0 and NaN; only the 0 contributes.
                    self.assertAlmostEqual(diaObject["gMeanFlux"], 0.0)
                    self.assertTrue(np.isnan(diaObject["gStdFlux"]))
                    self.assertAlmostEqual(diaObject["gChiFlux"], 0.0)
                else:
                    # g-band fluxes 0 and 1 with unit errors: mean 0.5,
                    # std sqrt(0.5), chi = 0.5**2 + 0.5**2 = 0.5.
                    self.assertAlmostEqual(diaObject["gMeanFlux"], 0.5)
                    self.assertAlmostEqual(diaObject["gStdFlux"],
                                           0.7071067811865476)
                    self.assertAlmostEqual(diaObject["gChiFlux"], 0.5)

    def testRunUnindexed(self):
        """Test inputting un-indexed catalogs.
        """
        unindexedDiaSources = pd.DataFrame(data=[
            {"diaSourceId": objId, "diaObjectId": 0,
             "psfFlux": 0., "psfFluxErr": 1.,
             "scienceFlux": 0., "scienceFluxErr": 1.,
             "midpointMjdTai": 0, "band": "g"}
            for objId in range(1000)])
        # pd.concat keeps both frames' default integer index, so the result
        # has duplicate index labels (0-999 followed by 0-9).
        unindexedDiaSources = pd.concat(
            (
                unindexedDiaSources,
                pd.DataFrame(
                    data=[
                        {
                            "diaSourceId": objId + 1000,
                            "diaObjectId": 0,
                            "psfFlux": 0., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 1.,
                            "midpointMjdTai": 0, "band": "g",
                        }
                        for objId in range(10)
                    ]
                )
            )
        )

        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testCount"]
        diaObjectCalTask = DiaObjectCalculationTask(config=conf)
        # self.diaObjects already has a default index with diaObjectId as a
        # plain column; the former `self.diaObjects.reset_index()` call here
        # discarded its result and had no effect, so it was removed.
        results = diaObjectCalTask.run(self.diaObjects,
                                       unindexedDiaSources,
                                       np.array([0], dtype=np.int64),
                                       ["g"])
        updatedDiaObjects = results.updatedDiaObjects
        self.assertEqual(updatedDiaObjects.at[0, "count"],
                         len(unindexedDiaSources))

    def testConflictingPlugins(self):
        """Test that code properly exits upon plugin collision.
        """
        # Dependent plugin without the plugin that produces its input.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDependentDiaPlugin"]
        with self.assertRaises(ValueError):
            DiaObjectCalculationTask(config=conf)

        # Two plugins writing the same output column.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDiaPlugin",
                        "testCollidingDiaPlugin",
                        "testDependentDiaPlugin"]
        with self.assertRaises(ValueError):
            DiaObjectCalculationTask(config=conf)

        # Test that ordering in the config does not matter and dependent
        # plugin is instantiated after independent plugin. Would raise
        # ValueError on failure.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDependentDiaPlugin",
                        "testDiaPlugin"]
        DiaObjectCalculationTask(config=conf)

265 

266 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    ``module`` is accepted for the hook signature but not used.
    """
    lsstTests = lsst.utils.tests
    lsstTests.init()

269 

270 

# Allow running this file directly as a script, outside pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()