Coverage for python / lsst / meas / algorithms / findGlintTrails.py: 20%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 08:25 +0000

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["FindGlintTrailsConfig", "FindGlintTrailsTask", "GlintTrailParameters"] 

23 

24import collections 

25import dataclasses 

26import math 

27 

28import numpy as np 

29import scipy.spatial 

30import sklearn.linear_model 

31 

32import lsst.afw.table 

33import lsst.pex.config 

34import lsst.pipe.base 

35 

36 

class FindGlintTrailsConfig(lsst.pex.config.Config):
    """Configuration for `FindGlintTrailsTask`."""

    # Search radius passed to the catalog self-match that finds neighbors
    # of each anchor source.
    radius = lsst.pex.config.Field(
        dtype=float,
        default=500,
        doc="Radius to search for glint trail candidates from each source (pixels).",
    )
    # Three points is the fewest for which "lies on a line" is a non-trivial
    # statement, hence the lower bound enforced by ``check``.
    min_points = lsst.pex.config.Field(
        dtype=int,
        default=5,
        check=lambda value: value >= 3,
        doc="Minimum number of points to be considered a possible glint trail.",
    )
    # Used both as the RANSAC residual threshold and as the cut on the
    # fitted-line error; candidate searches use twice this value.
    threshold = lsst.pex.config.Field(
        dtype=float,
        default=15.0,
        doc="Maximum root mean squared deviation from a straight line (pixels).",
    )
    seed = lsst.pex.config.Field(
        dtype=int,
        default=42,
        doc="Random seed for RANSAC fitter, to ensure stable fitting.",
    )
    # Sources with any of these flag columns set are removed before the
    # trail search starts.
    bad_flags = lsst.pex.config.ListField[str](
        default=["ip_diffim_DipoleFit_classification", "is_negative"],
        doc="Do not fit sources that have these flags set.",
    )

65 

66 

67@dataclasses.dataclass(frozen=True, kw_only=True) 

68class GlintTrailParameters: 

69 """Holds values from the line fit to a single glint trail.""" 

70 slope: float 

71 intercept: float 

72 stderr: float 

73 length: float # pixels 

74 angle: float # radians, from +X axis 

75 

76 

class FindGlintTrailsTask(lsst.pipe.base.Task):
    """Find glint trails in a catalog by searching for sources that lie in a
    line.

    Notes
    -----
    For each source ("anchor") in the input catalog that was not included in
    an earlier iteration as part of a trail:
     * Find all sources within a given radius.
     * For each pair of anchor and match, identify the other sources that
       could lie on the same line(s).
     * Take the longest set of such pairs as a candidate trail.
     * Fit a line to the identified pairs with the RANSAC algorithm.
     * Find all sources in the catalog that could lie on that line.
     * Refit a line to all of the matched sources.
     * If the error is below the threshold and the number of sources on the
       line is greater than the minimum, return the sources that were
       considered inliers during the fit, and the fit parameters.
    """

    ConfigClass = FindGlintTrailsConfig
    _DefaultName = "findGlintTrails"

    def run(self, catalog):
        """Find glint trails in a catalog.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Catalog to search for glint trails.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Results as a struct with attributes:

            ``trails``
                Catalog subsets containing sources in each trail that was found.
                (`list` [`lsst.afw.table.SourceCatalog`])
            ``trailed_ids``
                Ids of all the sources that were included in any fit trail.
                (`set` [`int`])
            ``parameters``
                Parameters of all the trails that were found.
                (`list` [`GlintTrailParameters`])
        """
        good_catalog = self._select_good_sources(catalog)

        # Group the self-matches by the id of their first ("anchor") record.
        matches = lsst.afw.table.matchXy(good_catalog, self.config.radius)
        per_id = collections.defaultdict(list)
        for match in matches:
            per_id[match.first["id"]].append(match)
        counts = {source_id: len(found) for source_id, found in per_id.items()}

        trails = []
        parameters = []
        trailed_ids = set()
        # Search starting with the source with the largest number of matches.
        for anchor_id in sorted(counts, key=counts.get, reverse=True):
            # Don't search this point if it has too few neighbors or was
            # already included in a trail.
            if counts[anchor_id] < self.config.min_points or anchor_id in trailed_ids:
                continue

            anchor = per_id[anchor_id][0].first
            self.log.debug("id=%d at %.1f,%.1f has %d matches within %d pixels.",
                           anchor_id,
                           anchor.getX(),
                           anchor.getY(),
                           counts[anchor_id],
                           self.config.radius)
            found = self._search_one(per_id[anchor_id], good_catalog)
            if found is None:
                continue
            trail, result = found

            # Check that we didn't already find this trail.
            n_new = len(set(trail["id"]).difference(trailed_ids))
            if n_new == 0:
                continue
            self.log.info("Found %.1f pixel length trail with %d points, "
                          "%d not in any other trail (slope=%.4f, intercept=%.2f)",
                          result.length, len(trail), n_new, result.slope, result.intercept)
            trails.append(trail)
            trailed_ids.update(trail["id"])
            parameters.append(result)

        self.log.info("Found %d glint trails containing %d total sources.",
                      len(trails), len(trailed_ids))
        return lsst.pipe.base.Struct(trails=trails,
                                     trailed_ids=trailed_ids,
                                     parameters=parameters)

    def _select_good_sources(self, catalog):
        """Return sources that could possibly be in a glint trail, i.e. ones
        that do not have bad flags set.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Original catalog to be selected from.

        Returns
        -------
        good_catalog : `lsst.afw.table.SourceCatalog`
            Catalog that has had bad sources removed.
        """
        # A source is bad if any one of the configured flag columns is set.
        bad = np.zeros(len(catalog), dtype=bool)
        for flag in self.config.bad_flags:
            bad |= catalog[flag]
        return catalog[~bad]

    def _search_one(self, matches, catalog):
        """Search one set of matches for a possible trail.

        Parameters
        ----------
        matches : `list` [`lsst.afw.table.Match`]
            Matches for one anchor source to search for lines.
        catalog : `lsst.afw.SourceCatalog`
            Catalog of all sources, to refit lines to.

        Returns
        -------
        trail, result : `tuple` or None
            If no trails matching the criteria are found, return None,
            otherwise return a tuple of the sources in the trail and the
            trail parameters.
        """
        # (dx, dy) offsets from the anchor to each matched source, keyed by
        # the matched source id. The positions in this dict are used below as
        # indexes into ``matches``.
        xy_deltas = {pair.second["id"]: (pair.second.getX() - pair.first.getX(),
                                         pair.second.getY() - pair.first.getY()) for pair in matches}
        offsets = list(xy_deltas.values())

        # 2x threshold to search more broadly; will be refined later.
        max_offset = 2 * self.config.threshold

        # Find all sets of pairs from this anchor that could lie on a line.
        components = collections.defaultdict(list)
        for i, (dx1, dy1) in enumerate(offsets):
            distance = math.hypot(dx1, dy1)
            if distance == 0:
                # This match sits exactly on the anchor, so the pair defines
                # no direction (and would divide by zero below). It can still
                # be picked up as a member of a line through another match.
                continue
            for j, (dx2, dy2) in enumerate(offsets):
                if i == j:
                    continue
                # |cross product| / |offset i| is the perpendicular distance
                # of source j from the line through the anchor and source i.
                if abs(dx1 * dy2 - dy1 * dx2) / distance < max_offset:
                    components[i].append(j)

        # There are no lines with at least 3 components.
        if not components:
            return None

        longest, members = max(components.items(), key=lambda item: len(item[1]))
        # +2 to account for the base source and the first pair.
        n_points = len(members) + 2
        if n_points < self.config.min_points:
            return None

        candidate = [longest] + members
        trail, result = self._other_points(n_points, candidate, matches, catalog)

        if trail is None or len(trail) < self.config.min_points:
            return None
        if result.stderr > self.config.threshold:
            self.log.info("Candidate trail with %d sources rejected with stderr %.6f > %.3f",
                          len(trail), result.stderr, self.config.threshold)
            return None
        return trail, result

    def _other_points(self, n_points, indexes, matches, catalog):
        """Find all catalog records that could lie on this line.

        Parameters
        ----------
        n_points : `int`
            Number of sources in this candidate trail.
        indexes : `list` [`int`]
            Indexes into matches on this candidate trail.
        matches : `list` [`lsst.afw.table.Match`]
            Matches for one anchor source to search for lines.
        catalog : `lsst.afw.SourceCatalog`
            Catalog of all sources, to refit lines to.

        Returns
        -------
        trail : `lsst.afw.table.SourceCatalog` or `None`
            Sources that are in the fitted trail; `None` if either fit
            failed or the preliminary fit had too few inliers.
        result : `GlintTrailParameters` or `None`
            Parameters of the fitted trail; `None` whenever ``trail`` is.
        """

        def extract(fitter, x, y, prefix=""):
            """Extract values from the fit and log and return them."""
            inliers = fitter.inlier_mask_
            x_in = x[inliers]
            y_in = y[inliers]
            predicted = fitter.predict(x_in).flatten()
            # NOTE(review): this is the root of the *summed* squared
            # residuals, not the root *mean* square that the ``threshold``
            # config doc describes; left as-is because changing it would
            # alter which trails are accepted.
            stderr = math.sqrt(((predicted - y_in.flatten())**2).sum())
            m, b = fitter.estimator_.coef_[0][0], fitter.estimator_.intercept_[0]
            # Report the total from the unmasked input; the masked arrays
            # only contain inliers.
            self.log.debug("%s fit: score=%.6f, stderr=%.6f, inliers/total=%d/%d",
                           prefix, fitter.score(x_in, y_in), stderr, inliers.sum(), len(inliers))
            # Simple O(N^2) search for longest distance; there will never be
            # enough points in a trail for a "faster" approach to be worth it.
            length = max(scipy.spatial.distance.pdist(np.hstack((x_in, y_in))), default=0)
            angle = math.atan(m)
            return GlintTrailParameters(slope=m, intercept=b, stderr=stderr, length=length, angle=angle)

        # min_samples=2 is necessary here for some sets of only 5 matches,
        # otherwise we sometimes get "UndefinedMetricWarning: R^2 score is not
        # well-defined with less than two samples" from RANSAC.
        fitter = sklearn.linear_model.RANSACRegressor(residual_threshold=self.config.threshold,
                                                      loss="squared_error",
                                                      random_state=self.config.seed,
                                                      min_samples=2)

        # Row 0 is the anchor; the remaining rows are the candidate matches.
        # The (-1,1) shape is to keep sklearn happy.
        x = np.empty(n_points).reshape(-1, 1)
        x[0] = matches[0].first.getX()
        x[1:, 0] = [matches[i].second.getX() for i in indexes]
        y = np.empty(n_points).reshape(-1, 1)
        y[0] = matches[0].first.getY()
        y[1:, 0] = [matches[i].second.getY() for i in indexes]

        try:
            fitter.fit(x, y)
        except ValueError:
            self.log.info("Glint trail interpolation could not find a valid fit.")
            return None, None
        result = extract(fitter, x, y, prefix="preliminary")

        # Reject trails that have too many outliers after the first fit.
        if (n_inliers := int(fitter.inlier_mask_.sum())) < self.config.min_points:
            self.log.debug("Candidate trail rejected with %d < %d points.",
                           n_inliers, self.config.min_points)
            return None, None

        # Find all points that are close to this line and refit with them.
        x = catalog["slot_Centroid_x"]
        y = catalog["slot_Centroid_y"]
        # Perpendicular distance of every catalog source from the line.
        dist = abs(result.intercept + result.slope * x - y) / math.sqrt(1 + result.slope**2)
        # 2x threshold to search more broadly: outlier rejection may change
        # the line parameters some and we want to grab all candidates here.
        candidates = (dist < 2 * self.config.threshold).flatten()
        # min_samples>2 should make the fit more stable.
        fitter = sklearn.linear_model.RANSACRegressor(residual_threshold=self.config.threshold,
                                                      loss="squared_error",
                                                      random_state=self.config.seed,
                                                      min_samples=3)
        # The (-1,1) shape is to keep sklearn happy.
        x = x[candidates].reshape(-1, 1)
        y = y[candidates].reshape(-1, 1)
        try:
            fitter.fit(x, y)
        except ValueError:
            # Treat a failed refit like a failed preliminary fit, so one bad
            # candidate does not abort the search over the whole catalog.
            self.log.info("Glint trail refit could not find a valid fit.")
            return None, None
        result = extract(fitter, x, y, prefix="final")

        return catalog[candidates][fitter.inlier_mask_], result