Coverage for python/lsst/ap/association/association.py: 31%

97 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-24 11:22 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""A simple implementation of source association task for ap_verify. 

23""" 

24 

25__all__ = ["AssociationConfig", "AssociationTask"] 

26 

27import numpy as np 

28import pandas as pd 

29from scipy.spatial import cKDTree 

30 

31import lsst.geom as geom 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34from lsst.utils.timer import timeMethod 

35from .trailedSourceFilter import TrailedSourceFilterTask 

36 

37# Enforce an error for unsafe column/array value setting in pandas. 

38pd.options.mode.chained_assignment = 'raise' 

39 

40 

class AssociationConfig(pexConfig.Config):
    """Config class for AssociationTask.
    """

    # Matching radius; converted to an `lsst.geom.Angle` by the task before
    # it is handed to the scoring step.
    maxDistArcSeconds = pexConfig.Field(
        dtype=float,
        doc="Maximum distance in arcseconds to test for a DIASource to be a "
            "match to a DIAObject.",
        default=1.0,
    )

    # Subtask that is only instantiated when ``doTrailedSourceFilter`` is set.
    trailedSourceFilter = pexConfig.ConfigurableField(
        target=TrailedSourceFilterTask,
        doc="Subtask to remove long trailed sources based on catalog source "
            "morphological measurements.",
    )

    # Switch controlling whether ``trailedSourceFilter`` is created and run.
    doTrailedSourceFilter = pexConfig.Field(
        doc="Run trailedSourceFilter to remove long trailed sources from "
            "output catalog.",
        dtype=bool,
        default=True,
    )

64 

65 

class AssociationTask(pipeBase.Task):
    """Associate DIASources into existing DIAObjects.

    This task performs the association of detected DIASources in a visit
    with the previous DIAObjects detected over time. DIASources that cannot
    be associated with a previously detected DIAObject are returned
    separately as unassociated sources.
    """

    ConfigClass = AssociationConfig
    _DefaultName = "association"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The filter subtask only exists when it is enabled in the config.
        if self.config.doTrailedSourceFilter:
            self.makeSubtask("trailedSourceFilter")

    @timeMethod
    def run(self,
            diaSources,
            diaObjects,
            exposure_time=None):
        """Associate the new DiaSources with existing DiaObjects.

        Parameters
        ----------
        diaSources : `pandas.DataFrame`
            New DIASources to be associated with existing DIAObjects.
        diaObjects : `pandas.DataFrame`
            Existing diaObjects from the Apdb.
        exposure_time : `float`, optional
            Exposure time from difference image. Passed through to the
            trailed source filter when that filter is enabled.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``matchedDiaSources`` : DiaSources that were matched. Matched
              Sources have their diaObjectId updated and set to the id of the
              diaObject they were matched to. (`pandas.DataFrame`)
            - ``unAssocDiaSources`` : DiaSources that were not matched.
              Unassociated sources keep a diaObjectId of 0 as they
              were not associated with any existing DiaObjects.
              (`pandas.DataFrame`)
            - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
              matched to new DiaSources. (`int`)
            - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
            - ``longTrailedSources`` : DiaSources removed by the trailed
              source filter; an empty table with the DiaSource columns when
              the filter is disabled. (`pandas.DataFrame`)
        """
        diaSources = self.check_dia_source_radec(diaSources)

        if self.config.doTrailedSourceFilter:
            diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
            diaSources = diaTrailedResult.diaSources
            longTrailedSources = diaTrailedResult.longTrailedDiaSources

            # Pass the count as a logger argument so the message is only
            # formatted when the record is actually emitted.
            self.log.info("%i DiaSources exceed max_trail_length, dropping from source "
                          "catalog.", len(longTrailedSources))
            self.metadata.add("num_filtered", len(longTrailedSources))
        else:
            longTrailedSources = pd.DataFrame(columns=diaSources.columns)

        # Nothing to match against: every source is unassociated.
        if len(diaObjects) == 0:
            return pipeBase.Struct(
                matchedDiaSources=pd.DataFrame(columns=diaSources.columns),
                unAssocDiaSources=diaSources,
                nUpdatedDiaObjects=0,
                nUnassociatedDiaObjects=0,
                longTrailedSources=longTrailedSources)

        matchResult = self.associate_sources(diaObjects, diaSources)

        # A diaObjectId of 0 marks a source that was not associated.
        mask = matchResult.diaSources["diaObjectId"] != 0

        return pipeBase.Struct(
            matchedDiaSources=matchResult.diaSources[mask].reset_index(drop=True),
            unAssocDiaSources=matchResult.diaSources[~mask].reset_index(drop=True),
            nUpdatedDiaObjects=matchResult.nUpdatedDiaObjects,
            nUnassociatedDiaObjects=matchResult.nUnassociatedDiaObjects,
            longTrailedSources=longTrailedSources)

    def check_dia_source_radec(self, dia_sources):
        """Check that all DiaSources have non-NaN values for RA/DEC.

        If one or more DiaSources are found to have NaN values, throw a
        warning to the log with the ids of the offending sources. Drop them
        from the table.

        Parameters
        ----------
        dia_sources : `pandas.DataFrame`
            Input DiaSources to check for NaN values. Must contain the
            columns ``ra``, ``dec`` and ``diaSourceId``.

        Returns
        -------
        trimmed_sources : `pandas.DataFrame`
            DataFrame of DiaSources trimmed of all entries with NaN values for
            RA/DEC. The input object itself is returned when nothing needs
            to be dropped; otherwise an independent copy is returned with
            the original index labels of the surviving rows.
        """
        nan_mask = (dia_sources.loc[:, "ra"].isnull()
                    | dia_sources.loc[:, "dec"].isnull())
        if nan_mask.any():
            # Select the offending ids with the boolean mask itself so the
            # lookup is correct whatever the index labels of the table are.
            for source_id in dia_sources.loc[nan_mask, "diaSourceId"]:
                self.log.warning(
                    "DiaSource %i has NaN value for RA/DEC, "
                    "dropping from association.", source_id)
            # Take an explicit copy: ``match`` later writes into this table
            # and this module configures pandas to raise on writes to a
            # slice of another DataFrame.
            dia_sources = dia_sources[~nan_mask].copy()
        return dia_sources

    @timeMethod
    def associate_sources(self, dia_objects, dia_sources):
        """Associate the input DIASources with the catalog of DIAObjects.

        DiaObject DataFrame must be indexed on ``diaObjectId``.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            Catalog of DIAObjects to attempt to associate the input
            DIASources into.
        dia_sources : `pandas.DataFrame`
            DIASources to associate into the DIAObjectCollection. The
            ``diaObjectId`` column of matched rows is updated in place.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSources`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
              associated. (`int`)
            - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        scores = self.score(
            dia_objects, dia_sources,
            self.config.maxDistArcSeconds * geom.arcseconds)
        match_result = self.match(dia_objects, dia_sources, scores)

        return match_result

    @timeMethod
    def score(self, dia_objects, dia_sources, max_dist):
        """Compute a quality score for each dia_source/dia_object pair
        between this catalog of DIAObjects and the input DIASource catalog.

        ``max_dist`` sets the maximum separation to consider a
        dia_source a possible match to a dia_object. If the pair is
        beyond this distance no score is computed.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A contiguous catalog of DIAObjects to score against dia_sources.
            Its index values are reported as the matched object ids.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources to "score" based on distance
            and (in the future) other metrics.
        max_dist : `lsst.geom.Angle`
            Maximum allowed distance to compute a score for a given DIAObject
            DIASource pair.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components:

            - ``scores``: match quality of each DIASource with its nearest
              DIAObject; smaller is better (array-like of `float`).
            - ``obj_idxs``: positions of the matched DIAObjects in the
              catalog. (array-like of `int`)
            - ``obj_ids``: ids (index values) of the matched DIAObjects
              (array-like of `int`).

            Sources with no DIAObject within ``max_dist`` have a score of
            INF and an ``obj_ids`` entry of 0. Their ``obj_idxs`` entry is
            -1 when the DIAObject catalog is empty and otherwise whatever
            the kd-tree query reports for a missing neighbor (for
            `scipy.spatial.cKDTree` this is the number of points in the
            tree), so only entries with a finite score should be used.
        """
        scores = np.full(len(dia_sources), np.inf, dtype=np.float64)
        obj_idxs = np.full(len(dia_sources), -1, dtype=np.int64)
        obj_ids = np.full(len(dia_sources), 0, dtype=np.int64)

        if len(dia_objects) == 0:
            return pipeBase.Struct(
                scores=scores,
                obj_idxs=obj_idxs,
                obj_ids=obj_ids)

        spatial_tree = self._make_spatial_tree(dia_objects)

        max_dist_rad = max_dist.asRadians()

        vectors = self._radec_to_xyz(dia_sources)

        # The tree holds 3-d unit vectors, so the returned distances are
        # straight-line distances between points on the unit sphere; the
        # angular limit in radians is used directly as the upper bound.
        scores, obj_idxs = spatial_tree.query(
            vectors,
            distance_upper_bound=max_dist_rad)
        # Only sources that found a neighbor get a real object id.
        has_match = np.isfinite(scores)
        obj_ids[has_match] = dia_objects.index.to_numpy()[obj_idxs[has_match]]

        return pipeBase.Struct(
            scores=scores,
            obj_idxs=obj_idxs,
            obj_ids=obj_ids)

    def _make_spatial_tree(self, dia_objects):
        """Create a searchable kd-tree from the input dia_object positions.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            A catalog of DIAObjects to create the tree from.

        Returns
        -------
        kd_tree : `scipy.spatial.cKDTree`
            Searchable kd-tree created from the positions of the DIAObjects.
        """
        vectors = self._radec_to_xyz(dia_objects)
        return cKDTree(vectors)

    def _radec_to_xyz(self, catalog):
        """Convert input ra/dec coordinates to spherical unit-vectors.

        Parameters
        ----------
        catalog : `pandas.DataFrame`
            Catalog to produce spherical unit-vector from. The ``ra`` and
            ``dec`` columns are passed through `numpy.radians`, i.e. they
            are treated as degrees.

        Returns
        -------
        vectors : `numpy.ndarray`, (N, 3)
            Output unit-vectors
        """
        ras = np.radians(catalog["ra"])
        decs = np.radians(catalog["dec"])
        vectors = np.empty((len(ras), 3))

        # pi/2 - dec is the polar angle measured from the +z axis.
        sin_dec = np.sin(np.pi / 2 - decs)
        vectors[:, 0] = sin_dec * np.cos(ras)
        vectors[:, 1] = sin_dec * np.sin(ras)
        vectors[:, 2] = np.cos(np.pi / 2 - decs)

        return vectors

    @timeMethod
    def match(self, dia_objects, dia_sources, score_struct):
        """Match DIAsources to DiaObjects given a score.

        Parameters
        ----------
        dia_objects : `pandas.DataFrame`
            Catalog of DIAObjects to associate to DIASources.
        dia_sources : `pandas.DataFrame`
            A contiguous catalog of dia_sources for which the set of scores
            has been computed with ``score``. Must contain a
            ``diaObjectId`` column, which is updated in place for the
            matched rows.
        score_struct : `lsst.pipe.base.Struct`
            Results struct with components:

            - ``"scores"``: match quality of each DIASource; smaller is
              better (array-like of `float`).
            - ``"obj_ids"``: ids of the candidate DIAObjects
              (array-like of `int`).
            - ``"obj_idxs"``: positions of the candidate DIAObjects in the
              catalog. (array-like of `int`)

            Entries are ordered like the rows of ``dia_sources``. Sources
            with a non-finite score are treated as having no candidate and
            their other entries are ignored.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``"diaSources"`` : Full set of diaSources both matched and not.
              (`pandas.DataFrame`)
            - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
              associated. (`int`)
            - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
              not matched a new DiaSource. (`int`)
        """
        n_previous_dia_objects = len(dia_objects)
        used_dia_object = np.zeros(n_previous_dia_objects, dtype=bool)
        matched_src_positions = []
        matched_obj_ids = []

        # We sort from best match to worst to effectively perform a
        # "handshake" match where both the DIASources and DIAObjects agree
        # on their best match. By sorting this way, non-finite scores (those
        # sources that have no candidate DIAObject) are placed at the end of
        # the array.
        score_args = score_struct.scores.argsort(axis=None)
        for score_idx in score_args:
            if not np.isfinite(score_struct.scores[score_idx]):
                # Thanks to the sorting, the rest of the sources have no
                # candidate either, so there is nothing left to match.
                break
            dia_obj_idx = score_struct.obj_idxs[score_idx]
            if used_dia_object[dia_obj_idx]:
                # A better-scoring source already claimed this DIAObject.
                continue
            used_dia_object[dia_obj_idx] = True
            matched_src_positions.append(score_idx)
            matched_obj_ids.append(score_struct.obj_ids[score_idx])

        if matched_src_positions:
            # The score arrays are ordered by row position, so write by
            # position; label-based access would hit the wrong row (or
            # append a new one) whenever the index is not 0..N-1.
            obj_id_col = dia_sources.columns.get_loc("diaObjectId")
            dia_sources.iloc[matched_src_positions, obj_id_col] = np.array(
                matched_obj_ids)

        n_updated_dia_objects = len(matched_src_positions)

        return pipeBase.Struct(
            diaSources=dia_sources,
            nUpdatedDiaObjects=n_updated_dia_objects,
            nUnassociatedDiaObjects=(n_previous_dia_objects
                                     - n_updated_dia_objects))