Coverage for python/lsst/ap/association/association.py: 32%

83 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-09 11:35 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""A simple implementation of source association task for ap_verify. 

23""" 

24 

25__all__ = ["AssociationConfig", "AssociationTask"] 

26 

27import numpy as np 

28import pandas as pd 

29from scipy.spatial import cKDTree 

30 

31import lsst.geom as geom 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34from lsst.utils.timer import timeMethod 

35 

36# Enforce an error for unsafe column/array value setting in pandas. 

37pd.options.mode.chained_assignment = 'raise' 

38 

39 

40class AssociationConfig(pexConfig.Config): 

41 """Config class for AssociationTask. 

42 """ 

43 

44 maxDistArcSeconds = pexConfig.Field( 

45 dtype=float, 

46 doc="Maximum distance in arcseconds to test for a DIASource to be a " 

47 "match to a DIAObject.", 

48 default=1.0, 

49 ) 

50 

51 

52class AssociationTask(pipeBase.Task): 

53 """Associate DIAOSources into existing DIAObjects. 

54 

55 This task performs the association of detected DIASources in a visit 

56 with the previous DIAObjects detected over time. It also creates new 

57 DIAObjects out of DIASources that cannot be associated with previously 

58 detected DIAObjects. 

59 """ 

60 

61 ConfigClass = AssociationConfig 

62 _DefaultName = "association" 

63 

64 @timeMethod 

65 def run(self, 

66 diaSources, 

67 diaObjects): 

68 """Associate the new DiaSources with existing DiaObjects. 

69 

70 Parameters 

71 ---------- 

72 diaSources : `pandas.DataFrame` 

73 New DIASources to be associated with existing DIAObjects. 

74 diaObjects : `pandas.DataFrame` 

75 Existing diaObjects from the Apdb. 

76 

77 Returns 

78 ------- 

79 result : `lsst.pipe.base.Struct` 

80 Results struct with components. 

81 

82 - ``matchedDiaSources`` : DiaSources that were matched. Matched 

83 Sources have their diaObjectId updated and set to the id of the 

84 diaObject they were matched to. (`pandas.DataFrame`) 

85 - ``unAssocDiaSources`` : DiaSources that were not matched. 

86 Unassociated sources have their diaObject set to 0 as they 

87 were not associated with any existing DiaObjects. 

88 (`pandas.DataFrame`) 

89 - ``nUpdatedDiaObjects`` : Number of DiaObjects that were 

90 matched to new DiaSources. (`int`) 

91 - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were 

92 not matched a new DiaSource. (`int`) 

93 """ 

94 diaSources = self.check_dia_source_radec(diaSources) 

95 

96 if len(diaObjects) == 0: 

97 return pipeBase.Struct( 

98 matchedDiaSources=pd.DataFrame(columns=diaSources.columns), 

99 unAssocDiaSources=diaSources, 

100 nUpdatedDiaObjects=0, 

101 nUnassociatedDiaObjects=0) 

102 

103 matchResult = self.associate_sources(diaObjects, diaSources) 

104 

105 mask = matchResult.diaSources["diaObjectId"] != 0 

106 

107 return pipeBase.Struct( 

108 matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True), 

109 unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True), 

110 nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects, 

111 nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects) 

112 

113 def check_dia_source_radec(self, dia_sources): 

114 """Check that all DiaSources have non-NaN values for RA/DEC. 

115 

116 If one or more DiaSources are found to have NaN values, throw a 

117 warning to the log with the ids of the offending sources. Drop them 

118 from the table. 

119 

120 Parameters 

121 ---------- 

122 dia_sources : `pandas.DataFrame` 

123 Input DiaSources to check for NaN values. 

124 

125 Returns 

126 ------- 

127 trimmed_sources : `pandas.DataFrame` 

128 DataFrame of DiaSources trimmed of all entries with NaN values for 

129 RA/DEC. 

130 """ 

131 nan_mask = (dia_sources.loc[:, "ra"].isnull() 

132 | dia_sources.loc[:, "dec"].isnull()) 

133 if np.any(nan_mask): 

134 nan_idxs = np.argwhere(nan_mask.to_numpy()).flatten() 

135 for nan_idx in nan_idxs: 

136 self.log.warning( 

137 "DiaSource %i has NaN value for RA/DEC, " 

138 "dropping from association." % 

139 dia_sources.loc[nan_idx, "diaSourceId"]) 

140 dia_sources = dia_sources[~nan_mask] 

141 return dia_sources 

142 

143 @timeMethod 

144 def associate_sources(self, dia_objects, dia_sources): 

145 """Associate the input DIASources with the catalog of DIAObjects. 

146 

147 DiaObject DataFrame must be indexed on ``diaObjectId``. 

148 

149 Parameters 

150 ---------- 

151 dia_objects : `pandas.DataFrame` 

152 Catalog of DIAObjects to attempt to associate the input 

153 DIASources into. 

154 dia_sources : `pandas.DataFrame` 

155 DIASources to associate into the DIAObjectCollection. 

156 

157 Returns 

158 ------- 

159 result : `lsst.pipe.base.Struct` 

160 Results struct with components. 

161 

162 - ``diaSources`` : Full set of diaSources both matched and not. 

163 (`pandas.DataFrame`) 

164 - ``nUpdatedDiaObjects`` : Number of DiaObjects that were 

165 associated. (`int`) 

166 - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were 

167 not matched a new DiaSource. (`int`) 

168 """ 

169 scores = self.score( 

170 dia_objects, dia_sources, 

171 self.config.maxDistArcSeconds * geom.arcseconds) 

172 match_result = self.match(dia_objects, dia_sources, scores) 

173 

174 return match_result 

175 

176 @timeMethod 

177 def score(self, dia_objects, dia_sources, max_dist): 

178 """Compute a quality score for each dia_source/dia_object pair 

179 between this catalog of DIAObjects and the input DIASource catalog. 

180 

181 ``max_dist`` sets maximum separation in arcseconds to consider a 

182 dia_source a possible match to a dia_object. If the pair is 

183 beyond this distance no score is computed. 

184 

185 Parameters 

186 ---------- 

187 dia_objects : `pandas.DataFrame` 

188 A contiguous catalog of DIAObjects to score against dia_sources. 

189 dia_sources : `pandas.DataFrame` 

190 A contiguous catalog of dia_sources to "score" based on distance 

191 and (in the future) other metrics. 

192 max_dist : `lsst.geom.Angle` 

193 Maximum allowed distance to compute a score for a given DIAObject 

194 DIASource pair. 

195 

196 Returns 

197 ------- 

198 result : `lsst.pipe.base.Struct` 

199 Results struct with components: 

200 

201 - ``scores``: array of floats of match quality updated DIAObjects 

202 (array-like of `float`). 

203 - ``obj_idxs``: indexes of the matched DIAObjects in the catalog. 

204 (array-like of `int`) 

205 - ``obj_ids``: array of floats of match quality updated DIAObjects 

206 (array-like of `int`). 

207 

208 Default values for these arrays are 

209 INF, -1, and -1 respectively for unassociated sources. 

210 """ 

211 scores = np.full(len(dia_sources), np.inf, dtype=np.float64) 

212 obj_idxs = np.full(len(dia_sources), -1, dtype=np.int64) 

213 obj_ids = np.full(len(dia_sources), 0, dtype=np.int64) 

214 

215 if len(dia_objects) == 0: 

216 return pipeBase.Struct( 

217 scores=scores, 

218 obj_idxs=obj_idxs, 

219 obj_ids=obj_ids) 

220 

221 spatial_tree = self._make_spatial_tree(dia_objects) 

222 

223 max_dist_rad = max_dist.asRadians() 

224 

225 vectors = self._radec_to_xyz(dia_sources) 

226 

227 scores, obj_idxs = spatial_tree.query( 

228 vectors, 

229 distance_upper_bound=max_dist_rad) 

230 matched_src_idxs = np.argwhere(np.isfinite(scores)) 

231 obj_ids[matched_src_idxs] = dia_objects.index.to_numpy()[ 

232 obj_idxs[matched_src_idxs]] 

233 

234 return pipeBase.Struct( 

235 scores=scores, 

236 obj_idxs=obj_idxs, 

237 obj_ids=obj_ids) 

238 

239 def _make_spatial_tree(self, dia_objects): 

240 """Create a searchable kd-tree the input dia_object positions. 

241 

242 Parameters 

243 ---------- 

244 dia_objects : `pandas.DataFrame` 

245 A catalog of DIAObjects to create the tree from. 

246 

247 Returns 

248 ------- 

249 kd_tree : `scipy.spatical.cKDTree` 

250 Searchable kd-tree created from the positions of the DIAObjects. 

251 """ 

252 vectors = self._radec_to_xyz(dia_objects) 

253 return cKDTree(vectors) 

254 

255 def _radec_to_xyz(self, catalog): 

256 """Convert input ra/dec coordinates to spherical unit-vectors. 

257 

258 Parameters 

259 ---------- 

260 catalog : `pandas.DataFrame` 

261 Catalog to produce spherical unit-vector from. 

262 

263 Returns 

264 ------- 

265 vectors : `numpy.ndarray`, (N, 3) 

266 Output unit-vectors 

267 """ 

268 ras = np.radians(catalog["ra"]) 

269 decs = np.radians(catalog["dec"]) 

270 vectors = np.empty((len(ras), 3)) 

271 

272 sin_dec = np.sin(np.pi / 2 - decs) 

273 vectors[:, 0] = sin_dec * np.cos(ras) 

274 vectors[:, 1] = sin_dec * np.sin(ras) 

275 vectors[:, 2] = np.cos(np.pi / 2 - decs) 

276 

277 return vectors 

278 

279 @timeMethod 

280 def match(self, dia_objects, dia_sources, score_struct): 

281 """Match DIAsources to DiaObjects given a score. 

282 

283 Parameters 

284 ---------- 

285 dia_objects : `pandas.DataFrame` 

286 A SourceCatalog of DIAObjects to associate to DIASources. 

287 dia_sources : `pandas.DataFrame` 

288 A contiguous catalog of dia_sources for which the set of scores 

289 has been computed on with DIAObjectCollection.score. 

290 score_struct : `lsst.pipe.base.Struct` 

291 Results struct with components: 

292 

293 - ``"scores"``: array of floats of match quality 

294 updated DIAObjects (array-like of `float`). 

295 - ``"obj_ids"``: array of floats of match quality 

296 updated DIAObjects (array-like of `int`). 

297 - ``"obj_idxs"``: indexes of the matched DIAObjects in the catalog. 

298 (array-like of `int`) 

299 

300 Default values for these arrays are 

301 INF, -1 and -1 respectively for unassociated sources. 

302 

303 Returns 

304 ------- 

305 result : `lsst.pipe.base.Struct` 

306 Results struct with components. 

307 

308 - ``"diaSources"`` : Full set of diaSources both matched and not. 

309 (`pandas.DataFrame`) 

310 - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were 

311 associated. (`int`) 

312 - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were 

313 not matched a new DiaSource. (`int`) 

314 """ 

315 n_previous_dia_objects = len(dia_objects) 

316 used_dia_object = np.zeros(n_previous_dia_objects, dtype=bool) 

317 used_dia_source = np.zeros(len(dia_sources), dtype=bool) 

318 associated_dia_object_ids = np.zeros(len(dia_sources), 

319 dtype=np.uint64) 

320 n_updated_dia_objects = 0 

321 

322 # We sort from best match to worst to effectively perform a 

323 # "handshake" match where both the DIASources and DIAObjects agree 

324 # their the best match. By sorting this way, scores with NaN (those 

325 # sources that have no match and will create new DIAObjects) will be 

326 # placed at the end of the array. 

327 score_args = score_struct.scores.argsort(axis=None) 

328 for score_idx in score_args: 

329 if not np.isfinite(score_struct.scores[score_idx]): 

330 # Thanks to the sorting the rest of the sources will be 

331 # NaN for their score. We therefore exit the loop to append 

332 # sources to a existing DIAObject, leaving these for 

333 # the loop creating new objects. 

334 break 

335 dia_obj_idx = score_struct.obj_idxs[score_idx] 

336 if used_dia_object[dia_obj_idx]: 

337 continue 

338 used_dia_object[dia_obj_idx] = True 

339 used_dia_source[score_idx] = True 

340 obj_id = score_struct.obj_ids[score_idx] 

341 associated_dia_object_ids[score_idx] = obj_id 

342 dia_sources.loc[score_idx, "diaObjectId"] = obj_id 

343 n_updated_dia_objects += 1 

344 

345 return pipeBase.Struct( 

346 diaSources=dia_sources, 

347 nUpdatedDiaObjects=n_updated_dia_objects, 

348 nUnassociatedDiaObjects=(n_previous_dia_objects 

349 - n_updated_dia_objects))