Coverage for python/lsst/ap/association/association.py: 32%

94 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-19 11:23 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""A simple implementation of source association task for ap_verify.

Provides `AssociationTask`, which matches new DIASources to existing
DIAObjects by nearest-neighbour search on the unit sphere, and its
configuration class `AssociationConfig`.
"""

__all__ = ["AssociationConfig", "AssociationTask"]

import numpy as np
import pandas as pd
from scipy.spatial import cKDTree

import lsst.geom as geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
from .trailedSourceFilter import TrailedSourceFilterTask

# Enforce an error for unsafe column/array value setting in pandas.
# NOTE(review): this is a process-wide pandas option set at import time, so
# it changes chained-assignment behaviour for every pandas user in the same
# interpreter, not only for this module.
pd.options.mode.chained_assignment = 'raise'

39 

40 

class AssociationConfig(pexConfig.Config):
    """Config class for AssociationTask.
    """

    # Matching radius used by AssociationTask.associate_sources; converted
    # to an lsst.geom.Angle there.
    maxDistArcSeconds = pexConfig.Field(
        dtype=float,
        doc="Maximum distance in arcseconds to test for a DIASource to be a "
            "match to a DIAObject.",
        default=1.0,
    )

    # Subtask that splits trailed sources out of the input catalog.
    trailedSourceFilter = pexConfig.ConfigurableField(
        target=TrailedSourceFilterTask,
        doc="Subtask to remove long trailed sources based on catalog source "
            "morphological measurements.",
    )

    # When False the trailedSourceFilter subtask is neither constructed nor
    # run.
    doTrailedSourceFilter = pexConfig.Field(
        doc="Run trailedSourceFilter to remove long trailed sources from "
            "output catalog.",
        dtype=bool,
        default=True,
    )

64 

65 

class AssociationTask(pipeBase.Task):
    """Associate DIASources into existing DIAObjects.

    This task performs the association of detected DIASources in a visit
    with the previous DIAObjects detected over time. DIASources that cannot
    be associated with a previously detected DIAObject are returned
    separately (``unAssocDiaSources``); this task does not itself create
    new DIAObjects from them.
    """

    ConfigClass = AssociationConfig
    _DefaultName = "association"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.config.doTrailedSourceFilter:
            self.makeSubtask("trailedSourceFilter")

    @timeMethod
    def run(self,
            diaSources,
            diaObjects,
            exposure_time=None):
        """Associate the new DiaSources with existing DiaObjects.

        DiaSources with NaN RA/Dec are dropped first, then (if
        ``doTrailedSourceFilter`` is set) long-trailed sources are removed,
        and only then are the remaining sources matched. The trailed-source
        filter is applied even when ``diaObjects`` is empty.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            New DIASources to be associated with existing DIAObjects.
        diaObjects : `pandas.DataFrame`
            Existing diaObjects from the Apdb, indexed on ``diaObjectId``.
        exposure_time : `float`, optional
            Exposure time from difference image; passed to the trailed
            source filter.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``matchedDiaSources`` : DiaSources that were matched. Matched
              Sources have their diaObjectId updated and set to the id of the
              diaObject they were matched to. (`pandas.DataFrame`)
            - ``unAssocDiaSources`` : DiaSources that were not matched.
              Unassociated sources keep a diaObjectId of 0 as they
              were not associated with any existing DiaObjects.
              (`pandas.DataFrame`)
            - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
              matched to new DiaSources. (`int`)
            - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
              not matched to a new DiaSource. (`int`)
        """
        diaSources = self.check_dia_source_radec(diaSources)

        # Filter before the empty-catalog shortcut so long-trailed sources
        # are dropped consistently whether or not DIAObjects exist yet.
        if self.config.doTrailedSourceFilter:
            diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
            diaSources = diaTrailedResult.diaSources
            self.log.info("%i DIASources exceed max_trail_length, dropping "
                          "from source catalog.",
                          len(diaTrailedResult.trailedDiaSources))

        if len(diaObjects) == 0:
            return pipeBase.Struct(
                matchedDiaSources=pd.DataFrame(columns=diaSources.columns),
                unAssocDiaSources=diaSources.reset_index(drop=True),
                nUpdatedDiaObjects=0,
                nUnassociatedDiaObjects=0)

        matchResult = self.associate_sources(diaObjects, diaSources)

        # A diaObjectId of 0 marks a source that was not associated.
        mask = matchResult.diaSources["diaObjectId"] != 0

        return pipeBase.Struct(
            matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True),
            unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True),
            nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
            nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects)

    def check_dia_source_radec(self, dia_sources):
        """Check that all DiaSources have non-NaN values for RA/DEC.

        If one or more DiaSources are found to have NaN values, log a
        warning with the id of each offending source and drop them from
        the table.

        Parameters
        ----------
        dia_sources : `pandas.DataFrame`
            Input DiaSources to check for NaN values.

        Returns
        -------
        trimmed_sources : `pandas.DataFrame`
            DataFrame of DiaSources trimmed of all entries with NaN values for
            RA/DEC.
        """
        nan_mask = (dia_sources.loc[:, "ra"].isnull()
                    | dia_sources.loc[:, "dec"].isnull())
        if np.any(nan_mask):
            # Select by position (boolean array) so the lookup is correct
            # whatever the DataFrame's index labels are.
            bad_ids = dia_sources.loc[nan_mask.to_numpy(), "diaSourceId"]
            for dia_source_id in bad_ids:
                self.log.warning(
                    "DiaSource %i has NaN value for RA/DEC, "
                    "dropping from association.",
                    dia_source_id)
            dia_sources = dia_sources[~nan_mask]
        return dia_sources

    @timeMethod
    def associate_sources(self, dia_objects, dia_sources):
        """Associate the input DIASources with the catalog of DIAObjects.

        DiaObject DataFrame must be indexed on ``diaObjectId``; the index
        values are used as the ids written into matched DIASources.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            Catalog of DIAObjects to attempt to associate the input
            DIASources into.
        dia_sources : `pandas.DataFrame`
            DIASources to associate into the DIAObjectCollection.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSources`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
              associated. (`int`)
            - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
              not matched to a new DiaSource. (`int`)
        """
        scores = self.score(
            dia_objects, dia_sources,
            self.config.maxDistArcSeconds * geom.arcseconds)
        match_result = self.match(dia_objects, dia_sources, scores)

        return match_result

    @timeMethod
    def score(self, dia_objects, dia_sources, max_dist):
        """Compute a quality score for each dia_source/dia_object pair
        between this catalog of DIAObjects and the input DIASource catalog.

        Each dia_source is scored against its single nearest dia_object.
        ``max_dist`` sets the maximum angular separation to consider a
        dia_source a possible match to a dia_object. If the pair is
        beyond this distance no score is computed.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A contiguous catalog of DIAObjects to score against dia_sources.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources to "score" based on distance
            and (in the future) other metrics.
        max_dist : `lsst.geom.Angle`
            Maximum allowed distance to compute a score for a given DIAObject
            DIASource pair.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components:

            - ``scores``: chord distance between the unit vectors of each
              dia_source and its nearest dia_object (approximately the
              angular separation in radians for small separations)
              (array-like of `float`).
            - ``obj_idxs``: positional indexes of the matched DIAObjects in
              the catalog. (array-like of `int`)
            - ``obj_ids``: ids (index values) of the matched DIAObjects
              (array-like of `int`).

            Default values for these arrays are
            INF, -1, and 0 respectively for unassociated sources.
        """
        n_sources = len(dia_sources)
        scores = np.full(n_sources, np.inf, dtype=np.float64)
        obj_idxs = np.full(n_sources, -1, dtype=np.int64)
        obj_ids = np.full(n_sources, 0, dtype=np.int64)

        if len(dia_objects) == 0 or n_sources == 0:
            return pipeBase.Struct(
                scores=scores,
                obj_idxs=obj_idxs,
                obj_ids=obj_ids)

        spatial_tree = self._make_spatial_tree(dia_objects)

        # The kd-tree measures straight-line (chord) distances between unit
        # vectors, so convert the angular limit to the equivalent chord.
        max_dist_chord = 2.0 * np.sin(max_dist.asRadians() / 2.0)

        vectors = self._radec_to_xyz(dia_sources)

        dists, idxs = spatial_tree.query(
            vectors,
            distance_upper_bound=max_dist_chord)

        # cKDTree reports a missing neighbour as distance inf with index
        # len(dia_objects); copy over only real matches so unmatched sources
        # keep the documented defaults.
        matched = np.isfinite(dists)
        scores[matched] = dists[matched]
        obj_idxs[matched] = idxs[matched]
        obj_ids[matched] = dia_objects.index.to_numpy()[idxs[matched]]

        return pipeBase.Struct(
            scores=scores,
            obj_idxs=obj_idxs,
            obj_ids=obj_ids)

    def _make_spatial_tree(self, dia_objects):
        """Create a searchable kd-tree from the input dia_object positions.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A catalog of DIAObjects to create the tree from.

        Returns
        -------
        kd_tree : `scipy.spatial.cKDTree`
            Searchable kd-tree created from the positions of the DIAObjects.
        """
        vectors = self._radec_to_xyz(dia_objects)
        return cKDTree(vectors)

    def _radec_to_xyz(self, catalog):
        """Convert input ra/dec coordinates to spherical unit-vectors.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog to produce spherical unit-vector from; its ``ra`` and
            ``dec`` columns are read in degrees.

        Returns
        -------
        vectors : `numpy.ndarray`, (N, 3)
            Output unit-vectors
        """
        ras = np.radians(catalog["ra"].to_numpy(dtype=np.float64))
        decs = np.radians(catalog["dec"].to_numpy(dtype=np.float64))
        cos_decs = np.cos(decs)

        vectors = np.empty((len(ras), 3))
        vectors[:, 0] = cos_decs * np.cos(ras)
        vectors[:, 1] = cos_decs * np.sin(ras)
        vectors[:, 2] = np.sin(decs)

        return vectors

    @timeMethod
    def match(self, dia_objects, dia_sources, score_struct):
        """Match DIAsources to DiaObjects given a score.

        Matching is greedy and best-first: pairs are visited in order of
        increasing score and each DIAObject is assigned to at most one
        DIASource. The input ``dia_sources`` frame is not modified; a copy
        with updated ``diaObjectId`` values is returned.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A SourceCatalog of DIAObjects to associate to DIASources.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources for which the set of scores
            has been computed on with DIAObjectCollection.score.
        score_struct : `lsst.pipe.base.Struct`
            Results struct with components:

            - ``"scores"``: array of floats of match quality (lower is
              better) for each dia_source (array-like of `float`).
            - ``"obj_ids"``: ids of the matched DIAObjects
              (array-like of `int`).
            - ``"obj_idxs"``: positional indexes of the matched DIAObjects
              in the catalog. (array-like of `int`)

            Default values for these arrays are
            INF, 0 and -1 respectively for unassociated sources.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              associated. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched to a new DiaSource. (`int`)
        """
        n_previous_dia_objects = len(dia_objects)
        used_dia_object = np.zeros(n_previous_dia_objects, dtype=bool)
        used_dia_source = np.zeros(len(dia_sources), dtype=bool)
        n_updated_dia_objects = 0

        scores = np.asarray(score_struct.scores)
        obj_idxs = np.asarray(score_struct.obj_idxs)
        obj_ids = np.asarray(score_struct.obj_ids)

        # We sort from best match to worst so that each DIAObject is claimed
        # by the DIASource closest to it. By sorting this way, non-finite
        # scores (sources with no candidate within the matching radius) are
        # placed at the end of the array.
        for score_idx in scores.argsort(axis=None):
            if not np.isfinite(scores[score_idx]):
                # Thanks to the sorting every remaining score is also
                # non-finite, so no further matches are possible.
                break
            dia_obj_idx = obj_idxs[score_idx]
            if used_dia_object[dia_obj_idx]:
                continue
            used_dia_object[dia_obj_idx] = True
            used_dia_source[score_idx] = True
            n_updated_dia_objects += 1

        # Work on a copy so the caller's frame is untouched and the write
        # cannot trip pandas' chained-assignment check on filtered frames.
        # Assign with a positional boolean mask: the score arrays are
        # positional, while the frame's index labels may have gaps.
        dia_sources = dia_sources.copy()
        if n_updated_dia_objects > 0:
            dia_sources.loc[used_dia_source, "diaObjectId"] = \
                obj_ids[used_dia_source]

        return pipeBase.Struct(
            diaSources=dia_sources,
            nUpdatedDiaObjects=n_updated_dia_objects,
            nUnassociatedDiaObjects=(n_previous_dia_objects
                                     - n_updated_dia_objects))