Coverage for python/lsst/ap/association/association.py: 26%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

81 statements  

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Nearest-neighbour association of DIASources with existing DIAObjects.

A simple implementation of the source association task for ap_verify:
new DIASources are matched to the closest existing DIAObject within a
configurable radius using a kd-tree over unit vectors on the sphere.
"""

__all__ = ["AssociationConfig", "AssociationTask"]

26 

27import numpy as np 

28import pandas as pd 

29from scipy.spatial import cKDTree 

30 

31import lsst.geom as geom 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34 

# Turn pandas' chained-assignment warning into a hard error so that any
# attempt to write into a DataFrame slice (which may silently modify only a
# temporary copy) fails loudly.
# NOTE(review): this is a process-wide pandas option; importing this module
# changes the behavior of all pandas code running in the same interpreter.
pd.set_option("mode.chained_assignment", "raise")

37 

38 

class AssociationConfig(pexConfig.Config):
    """Configuration options for `AssociationTask`.

    The only tunable is the matching radius used when pairing DIASources
    with existing DIAObjects.
    """

    maxDistArcSeconds = pexConfig.Field(
        default=1.0,
        dtype=float,
        doc=("Maximum distance in arcseconds to test for a DIASource to be a "
             "match to a DIAObject."),
    )

48 

49 

class AssociationTask(pipeBase.Task):
    """Associate DIASources into existing DIAObjects.

    This task matches the DIASources detected in a visit to the DIAObjects
    detected previously, pairing each DIAObject with at most one DIASource
    (its closest within ``maxDistArcSeconds``). Matched DIASources have their
    ``diaObjectId`` set to the matched DIAObject's id; DIASources that do not
    match keep their existing ``diaObjectId``. This task does not itself
    create new DIAObjects.
    """

    ConfigClass = AssociationConfig
    _DefaultName = "association"

    @pipeBase.timeMethod
    def run(self,
            diaSources,
            diaObjects):
        """Associate the new DiaSources with existing DiaObjects.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            New DIASources to be associated with existing DIAObjects. Must
            contain ``ra``, ``decl``, ``diaSourceId`` and ``diaObjectId``
            columns.
        diaObjects : `pandas.DataFrame`
            Existing diaObjects from the Apdb, indexed on ``diaObjectId``.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources after matching, with
              any rows having NaN RA/DEC removed. Matched sources have their
              diaObjectId set to the id of the diaObject they were matched
              to. (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              matched to new DiaSources. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        diaSources = self.check_dia_source_radec(diaSources)
        if len(diaObjects) == 0:
            return pipeBase.Struct(
                diaSources=diaSources,
                nUpdatedDiaObjects=0,
                nUnassociatedDiaObjects=0)

        matchResult = self.associate_sources(diaObjects, diaSources)

        return pipeBase.Struct(
            diaSources=matchResult.diaSources,
            nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
            nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects)

    def check_dia_source_radec(self, dia_sources):
        """Drop DiaSources whose RA or DEC is NaN, logging each one dropped.

        Parameters
        ----------
        dia_sources : `pandas.DataFrame`
            Input DiaSources to check for NaN values. Must contain ``ra``,
            ``decl`` and ``diaSourceId`` columns.

        Returns
        -------
        trimmed_sources : `pandas.DataFrame`
            DataFrame of DiaSources trimmed of all entries with NaN values for
            RA/DEC. When nothing is dropped the input object itself is
            returned; otherwise an independent copy is returned.
        """
        nan_mask = (dia_sources.loc[:, "ra"].isnull()
                    | dia_sources.loc[:, "decl"].isnull()).to_numpy()
        if nan_mask.any():
            # Select the offending ids with the boolean mask itself so the
            # lookup is positional and works for any index, not only a
            # default 0..n-1 RangeIndex.
            for dia_source_id in dia_sources.loc[nan_mask, "diaSourceId"]:
                self.log.warning(
                    "DiaSource %i has NaN value for RA/DEC, "
                    "dropping from association.", dia_source_id)
            # ``.copy()`` detaches the result from the parent frame. Without
            # it pandas flags the result as a slice, and the later in-place
            # ``diaObjectId`` update in `match` raises under the module's
            # ``chained_assignment='raise'`` setting.
            dia_sources = dia_sources[~nan_mask].copy()
        return dia_sources

    @pipeBase.timeMethod
    def associate_sources(self, dia_objects, dia_sources):
        """Associate the input DIASources with the catalog of DIAObjects.

        DiaObject DataFrame must be indexed on ``diaObjectId``.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            Catalog of DIAObjects to attempt to associate the input
            DIASources into.
        dia_sources : `pandas.DataFrame`
            DIASources to associate into the DIAObject catalog.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              associated. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        scores = self.score(
            dia_objects, dia_sources,
            self.config.maxDistArcSeconds * geom.arcseconds)
        return self.match(dia_objects, dia_sources, scores)

    @pipeBase.timeMethod
    def score(self, dia_objects, dia_sources, max_dist):
        """Find the nearest DIAObject within ``max_dist`` of each DIASource.

        Pairs separated by more than ``max_dist`` are not scored.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A contiguous catalog of DIAObjects to score against dia_sources,
            indexed on ``diaObjectId``.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources to "score" based on distance
            and (in the future) other metrics.
        max_dist : `lsst.geom.Angle`
            Maximum angular separation for which a score is computed for a
            given DIAObject/DIASource pair.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components, each aligned with the rows of
            ``dia_sources``:

            - ``"scores"``: chord distance on the unit sphere to the nearest
              DIAObject (smaller is better) (array-like of `float`).
            - ``"obj_idxs"``: positional indexes of the matched DIAObjects in
              ``dia_objects`` (array-like of `int`).
            - ``"obj_ids"``: index values (``diaObjectId``) of the matched
              DIAObjects (array-like of `int`).

            Default values for these arrays are
            INF, -1, and 0 respectively for unassociated sources.
        """
        n_sources = len(dia_sources)
        scores = np.full(n_sources, np.inf, dtype=np.float64)
        obj_idxs = np.full(n_sources, -1, dtype=np.int64)
        obj_ids = np.full(n_sources, 0, dtype=np.int64)

        if len(dia_objects) == 0 or n_sources == 0:
            return pipeBase.Struct(
                scores=scores,
                obj_idxs=obj_idxs,
                obj_ids=obj_ids)

        spatial_tree = self._make_spatial_tree(dia_objects)

        # The tree stores 3-D unit vectors, so the distances it works with are
        # chord lengths, not great-circle angles. Convert the angular limit
        # (clamped to pi, where the chord length peaks at 2).
        max_dist_rad = min(max_dist.asRadians(), np.pi)
        max_chord = 2.0 * np.sin(0.5 * max_dist_rad)

        vectors = self._radec_to_xyz(dia_sources)

        scores, tree_idxs = spatial_tree.query(
            vectors,
            distance_upper_bound=max_chord)
        # cKDTree reports "no neighbour" as distance inf and index
        # len(dia_objects); keep the documented -1 / 0 defaults there.
        matched = np.isfinite(scores)
        obj_idxs[matched] = tree_idxs[matched]
        obj_ids[matched] = dia_objects.index.to_numpy()[tree_idxs[matched]]

        return pipeBase.Struct(
            scores=scores,
            obj_idxs=obj_idxs,
            obj_ids=obj_ids)

    def _make_spatial_tree(self, dia_objects):
        """Create a searchable kd-tree of the input dia_object positions.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A catalog of DIAObjects to create the tree from.

        Returns
        -------
        kd_tree : `scipy.spatial.cKDTree`
            Searchable kd-tree created from the positions of the DIAObjects.
        """
        vectors = self._radec_to_xyz(dia_objects)
        return cKDTree(vectors)

    def _radec_to_xyz(self, catalog):
        """Convert input ra/dec coordinates (degrees) to unit vectors.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog with ``ra`` and ``decl`` columns in degrees.

        Returns
        -------
        vectors : `numpy.ndarray`, (N, 3)
            Output unit-vectors.
        """
        ras = np.radians(np.asarray(catalog["ra"], dtype=np.float64))
        decs = np.radians(np.asarray(catalog["decl"], dtype=np.float64))
        cos_dec = np.cos(decs)

        vectors = np.empty((len(ras), 3))
        vectors[:, 0] = cos_dec * np.cos(ras)
        vectors[:, 1] = cos_dec * np.sin(ras)
        vectors[:, 2] = np.sin(decs)
        return vectors

    @pipeBase.timeMethod
    def match(self, dia_objects, dia_sources, score_struct):
        """Match DIAsources to DiaObjects given a score.

        The ``diaObjectId`` column of ``dia_sources`` is updated in place for
        matched rows.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A catalog of DIAObjects to associate to DIASources.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources for which the set of scores
            has been computed with `score`. Must contain a ``diaObjectId``
            column.
        score_struct : `lsst.pipe.base.Struct`
            Results struct with components, aligned positionally with the
            rows of ``dia_sources``:

            - ``"scores"``: array of floats of match quality, smaller is
              better (array-like of `float`).
            - ``"obj_ids"``: ids of the matched DIAObjects
              (array-like of `int`).
            - ``"obj_idxs"``: indexes of the matched DIAObjects in the catalog.
              (array-like of `int`)

            Unassociated sources are expected to have a non-finite score.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              associated. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        n_previous_dia_objects = len(dia_objects)
        used_dia_object = np.zeros(n_previous_dia_objects, dtype=bool)
        matched_src_positions = []
        matched_obj_ids = []

        # Greedy best-first assignment: visit sources from best (smallest)
        # score to worst so each DIAObject is claimed by the closest source
        # whose nearest neighbour it is. Unmatched sources have an infinite
        # score, which argsort places after every finite value.
        for src_pos in np.argsort(score_struct.scores, axis=None):
            if not np.isfinite(score_struct.scores[src_pos]):
                # Every remaining source is unmatched.
                break
            dia_obj_idx = score_struct.obj_idxs[src_pos]
            if used_dia_object[dia_obj_idx]:
                continue
            used_dia_object[dia_obj_idx] = True
            matched_src_positions.append(src_pos)
            matched_obj_ids.append(score_struct.obj_ids[src_pos])

        if matched_src_positions:
            # ``src_pos`` is a row position (the score arrays are aligned
            # with row order), so assign positionally; label-based ``.loc``
            # would hit the wrong rows, or append new ones, for any index
            # other than a default 0..n-1 RangeIndex.
            id_col = dia_sources.columns.get_loc("diaObjectId")
            dia_sources.iloc[matched_src_positions, id_col] = matched_obj_ids

        n_updated_dia_objects = len(matched_src_positions)
        return pipeBase.Struct(
            diaSources=dia_sources,
            nUpdatedDiaObjects=n_updated_dia_objects,
            nUnassociatedDiaObjects=(n_previous_dia_objects
                                     - n_updated_dia_objects))