Coverage for python/lsst/meas/astrom/matcher_probabilistic.py: 29%

149 statements  

# This file is part of meas_astrom.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ['MatchProbabilisticConfig', 'MatcherProbabilistic']

import lsst.pex.config as pexConfig

from dataclasses import dataclass
import logging
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
import time
from typing import Set

logger_default = logging.getLogger(__name__)


def _mul_column(column: np.array, value: float):
    if value is not None and value != 1:
        column *= value
    return column


def _radec_to_xyz(ra, dec):
    """Convert input ra/dec coordinates to spherical unit vectors.

    Parameters
    ----------
    ra, dec: `numpy.ndarray`
        Arrays of right ascension/declination in degrees.

    Returns
    -------
    vectors : `numpy.ndarray`, (N, 3)
        Output unit vectors.
    """
    if ra.size != dec.size:
        raise ValueError('ra and dec must be same size')
    ras = np.radians(ra)
    decs = np.radians(dec)
    vectors = np.empty((ras.size, 3))

    sin_dec = np.sin(np.pi / 2 - decs)
    vectors[:, 0] = sin_dec * np.cos(ras)
    vectors[:, 1] = sin_dec * np.sin(ras)
    vectors[:, 2] = np.cos(np.pi / 2 - decs)

    return vectors
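
# For example, _radec_to_xyz(np.array([90.]), np.array([0.])) returns
# approximately [[0., 1., 0.]], and ra=0, dec=90 maps to roughly [0., 0., 1.].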



class MatchProbabilisticConfig(pexConfig.Config):
    """Configuration for the MatchProbabilistic matcher.
    """
    column_order = pexConfig.Field(
        dtype=str,
        default=None,
        optional=True,
        doc="Column to order reference match candidates by. Derived from columns_ref_flux if not set.",
    )
    column_ref_coord1 = pexConfig.Field(
        dtype=str,
        default='ra',
        doc='The reference table column for the first spatial coordinate (usually x or ra).',
    )
    column_ref_coord2 = pexConfig.Field(
        dtype=str,
        default='dec',
        doc='The reference table column for the second spatial coordinate (usually y or dec).'
            ' Units must match column_ref_coord1.',
    )
    column_target_coord1 = pexConfig.Field(
        dtype=str,
        default='coord_ra',
        doc='The target table column for the first spatial coordinate (usually x or ra).'
            ' Units must match column_ref_coord1.',
    )
    column_target_coord2 = pexConfig.Field(
        dtype=str,
        default='coord_dec',
        doc='The target table column for the second spatial coordinate (usually y or dec).'
            ' Units must match column_ref_coord2.',
    )

    @property
    def columns_in_ref(self) -> Set[str]:
        columns_all = [x for x in self.columns_ref_flux]
        for columns in (
            [self.column_ref_coord1, self.column_ref_coord2],
            self.columns_ref_meas,
        ):
            columns_all.extend(columns)

        return set(columns_all)

    @property
    def columns_in_target(self) -> Set[str]:
        columns_all = [self.column_target_coord1, self.column_target_coord2]
        for columns in (
            self.columns_target_meas,
            self.columns_target_err,
            self.columns_target_select_false,
            self.columns_target_select_true,
        ):
            columns_all.extend(columns)
        return set(columns_all)
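
    # For illustration (with hypothetical measurement column names), a config
    # with columns_ref_meas=['flux_r'], columns_target_meas=['r_flux'],
    # columns_target_err=['r_fluxErr'], empty columns_ref_flux, and the default
    # coordinate and selection columns gives:
    #   columns_in_ref == {'ra', 'dec', 'flux_r'}
    #   columns_in_target == {'coord_ra', 'coord_dec', 'r_flux', 'r_fluxErr',
    #                         'detect_isPrimary', 'merge_peak_sky'}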


    columns_ref_flux = pexConfig.ListField(
        dtype=str,
        default=[],
        listCheck=lambda x: len(set(x)) == len(x),
        optional=True,
        doc="List of reference flux columns to nansum into a total flux for ordering if column_order is None",
    )
    columns_ref_meas = pexConfig.ListField(
        dtype=str,
        doc='The reference table columns to compute match likelihoods from '
            '(usually centroids and fluxes/magnitudes)',
    )
    columns_target_meas = pexConfig.ListField(
        dtype=str,
        doc='Target table columns with measurements corresponding to columns_ref_meas',
    )
    columns_target_err = pexConfig.ListField(
        dtype=str,
        doc='Target table columns with standard errors (sigma) corresponding to columns_ref_meas',
    )
    columns_target_select_true = pexConfig.ListField(
        dtype=str,
        default=('detect_isPrimary',),
        doc='Target table columns to require to be True for selecting match candidates',
    )
    columns_target_select_false = pexConfig.ListField(
        dtype=str,
        default=('merge_peak_sky',),
        doc='Target table columns to require to be False for selecting match candidates',
    )

    coords_spherical = pexConfig.Field(
        dtype=bool,
        default=True,
        doc='Whether column_*_coord[12] are spherical coordinates (ra/dec) or not (pixel x/y)',
    )
    coords_ref_factor = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc='Multiplicative factor for reference catalog coordinates.'
            ' If coords_spherical is true, this must be the number of degrees per unit increment of '
            'column_ref_coord[12]. Otherwise, it must convert the coordinate to the same units'
            ' as the target coordinates.',
    )
    coords_target_factor = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc='Multiplicative factor for target catalog coordinates.'
            ' If coords_spherical is true, this must be the number of degrees per unit increment of '
            'column_target_coord[12]. Otherwise, it must convert the coordinate to the same units'
            ' as the reference coordinates.',
    )
    coords_ref_to_convert = pexConfig.DictField(
        default=None,
        keytype=str,
        itemtype=str,
        dictCheck=lambda x: len(x) == 2,
        doc='Dict mapping sky coordinate columns to be converted to tract pixel columns',
    )
    mag_brightest_ref = pexConfig.Field(
        dtype=float,
        default=-np.inf,
        doc='Bright magnitude cutoff for selecting reference sources to match.'
            ' Ignored if column_order is None.'
    )
    mag_faintest_ref = pexConfig.Field(
        dtype=float,
        default=np.inf,
        doc='Faint magnitude cutoff for selecting reference sources to match.'
            ' Ignored if column_order is None.'
    )
    mag_zeropoint_ref = pexConfig.Field(
        dtype=float,
        default=31.4,
        doc='Magnitude zeropoint for computing reference source magnitudes for selection.'
            ' Ignored if column_order is None.'
    )
    match_dist_max = pexConfig.Field(
        dtype=float,
        default=0.5,
        doc='Maximum match distance. Units must be arcseconds if coords_spherical, '
            'or else match those of column_*_coord[12] multiplied by coords_*_factor.',
    )
    match_n_max = pexConfig.Field(
        dtype=int,
        default=10,
        optional=True,
        doc='Maximum number of spatial matches to consider (in ascending distance order).',
    )
    match_n_finite_min = pexConfig.Field(
        dtype=int,
        default=3,
        optional=True,
        doc='Minimum number of columns with a finite value to measure match likelihood',
    )

    order_ascending = pexConfig.Field(
        dtype=bool,
        default=False,
        optional=True,
        doc='Whether to order reference match candidates in ascending order of column_order '
            '(should be False if the column is a flux and True if it is a magnitude).',
    )



@dataclass
class CatalogExtras:
    """Store frequently-referenced (meta)data relevant for matching a catalog.

    Parameters
    ----------
    catalog : `pandas.DataFrame`
        A pandas catalog to store extra information for.
    select : `numpy.array`
        A numpy boolean array of the same length as catalog to be used for
        target selection.
    coordinate_factor : `float`
        An optional multiplicative factor for the catalog's coordinate columns.
    """
    n: int
    indices: np.array
    select: np.array

    coordinate_factor: float = None

    def __init__(self, catalog: pd.DataFrame, select: np.array = None, coordinate_factor: float = None):
        self.n = len(catalog)
        self.select = select
        self.indices = np.flatnonzero(select) if select is not None else None
        self.coordinate_factor = coordinate_factor
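
# For example (illustrative): for a 3-row catalog,
# CatalogExtras(catalog, select=np.array([True, False, True]))
# yields n == 3 and indices == array([0, 2]).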



class MatcherProbabilistic:
    """A probabilistic, greedy catalog matcher.

    Parameters
    ----------
    config: `MatchProbabilisticConfig`
        A configuration instance.
    """
    config: MatchProbabilisticConfig

    def __init__(
        self,
        config: MatchProbabilisticConfig,
    ):
        self.config = config

    def match(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        select_ref: np.array = None,
        select_target: np.array = None,
        logger: logging.Logger = None,
        logging_n_rows: int = None,
    ):
        """Match catalogs.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to match in order of a given column (i.e. greedily).
        catalog_target : `pandas.DataFrame`
            A target catalog for matching sources from `catalog_ref`. Must contain measurements with errors.
        select_ref : `numpy.array`
            A boolean array of the same length as `catalog_ref` selecting the sources that can be matched.
        select_target : `numpy.array`
            A boolean array of the same length as `catalog_target` selecting the sources that can be matched.
        logger : `logging.Logger`
            A Logger for logging.
        logging_n_rows : `int`
            The number of sources to match before printing a log message.

        Returns
        -------
        catalog_out_ref : `pandas.DataFrame`
            A catalog of identical length to `catalog_ref`, containing match information for rows selected by
            `select_ref` (including the matching row index in `catalog_target`).
        catalog_out_target : `pandas.DataFrame`
            A catalog of identical length to `catalog_target`, containing the indices of matching rows in
            `catalog_ref`.
        exceptions : `dict` [`int`, `Exception`]
            A dictionary, keyed by `catalog_ref` row number, of the first exception caught when matching
            that row.
        """

        if logger is None:
            logger = logger_default

        config = self.config
        # Set up objects with frequently-used attributes like selection bool array
        extras_ref, extras_target = (
            CatalogExtras(catalog, select=select, coordinate_factor=coord_factor)
            for catalog, select, coord_factor in zip(
                (catalog_ref, catalog_target),
                (select_ref, select_target),
                (config.coords_ref_factor, config.coords_target_factor),
            )
        )
        n_ref_match, n_target_match = (len(x) for x in (extras_ref.indices, extras_target.indices))

        match_dist_max = config.match_dist_max
        if config.coords_spherical:
            match_dist_max = np.radians(match_dist_max/3600.)

        # Retrieve coordinates and multiply them by scaling factors
        (coord1_ref, coord2_ref), (coord1_target, coord2_target) = (
            # Confused as to why this needs to be a list to work properly
            [
                _mul_column(catalog.loc[extras.select, column].values, extras.coordinate_factor)
                for column in columns
            ]
            for catalog, extras, columns in (
                (catalog_ref, extras_ref, (config.column_ref_coord1, config.column_ref_coord2)),
                (catalog_target, extras_target, (config.column_target_coord1, config.column_target_coord2)),
            )
        )

        # Convert ra/dec sky coordinates to spherical vectors for accurate distances
        if config.coords_spherical:
            vec_ref = _radec_to_xyz(coord1_ref, coord2_ref)
            vec_target = _radec_to_xyz(coord1_target, coord2_target)
        else:
            vec_ref = np.vstack((coord1_ref, coord2_ref))
            vec_target = np.vstack((coord1_target, coord2_target))

        columns_ref_meas = config.columns_ref_meas
        if config.coords_ref_to_convert:
            columns_ref_meas = [config.coords_ref_to_convert.get(column, column)
                                for column in config.columns_ref_meas]

        # Generate K-d tree to compute distances
        logger.info('Generating cKDTree with match_n_max=%d', config.match_n_max)
        tree_obj = cKDTree(vec_target)

        scores, idxs_target = tree_obj.query(
            vec_ref,
            distance_upper_bound=match_dist_max,
            k=config.match_n_max,
        )
        n_matches = np.sum(idxs_target != n_target_match, axis=1)
        n_matched_max = np.sum(n_matches == config.match_n_max)
        if n_matched_max > 0:
            logger.warning(
                '%d/%d (%.2f%%) true objects have n_matches=n_match_max(%d)',
                n_matched_max, n_ref_match, 100.*n_matched_max/n_ref_match, config.match_n_max
            )

        # Pre-allocate outputs
        target_row_match = np.full(extras_target.n, np.nan, dtype=np.int64)
        ref_candidate_match = np.zeros(extras_ref.n, dtype=bool)
        ref_row_match = np.full(extras_ref.n, np.nan, dtype=int)
        ref_match_count = np.zeros(extras_ref.n, dtype=int)
        ref_match_meas_finite = np.zeros(extras_ref.n, dtype=int)
        ref_chisq = np.full(extras_ref.n, np.nan, dtype=float)

        # If no order is specified, take nansum of all flux columns for a 'total flux'
        # Note: it won't actually be a total flux if bands overlap significantly
        # (or it might define a filter with >100% efficiency)
        column_order = (
            catalog_ref.loc[extras_ref.indices, config.column_order].values
            if config.column_order is not None else
            np.nansum(catalog_ref.loc[extras_ref.indices, config.columns_ref_flux].values, axis=1)
        )
        order = np.argsort(column_order if config.order_ascending else -column_order)

        indices = extras_ref.indices[order]
        idxs_target = idxs_target[order]
        n_indices = len(indices)

        data_ref = catalog_ref.loc[indices, columns_ref_meas]
        data_target = catalog_target.loc[extras_target.select, config.columns_target_meas]
        errors_target = catalog_target.loc[extras_target.select, config.columns_target_err]

        exceptions = {}
        matched_target = set()

        t_begin = time.process_time()

        logger.info('Matching n_indices=%d/%d', n_indices, len(catalog_ref))
        for index_n, index_row in enumerate(indices):
            ref_candidate_match[index_row] = True
            found = idxs_target[index_n, :]
            # Select match candidates from nearby sources not already matched
            found = [x for x in found[found != n_target_match] if x not in matched_target]
            n_found = len(found)
            if n_found > 0:
                # This is an ndarray of n_found rows x len(data_ref/target) columns
                chi = (
                    (data_target.iloc[found].values - data_ref.iloc[index_n].values)
                    / errors_target.iloc[found].values
                )
                finite = np.isfinite(chi)
                n_finite = np.sum(finite, axis=1)
                # Require some number of finite chi_sq to match
                chisq_good = n_finite >= config.match_n_finite_min
                if np.any(chisq_good):
                    try:
                        chisq_sum = np.zeros(n_found, dtype=float)
                        chisq_sum[chisq_good] = np.nansum(chi[chisq_good, :] ** 2, axis=1)
                        idx_chisq_min = np.nanargmin(chisq_sum / n_finite)
                        idx_match = found[idx_chisq_min]
                        ref_match_meas_finite[index_row] = n_finite[idx_chisq_min]
                        ref_match_count[index_row] = len(chisq_good)
                        ref_chisq[index_row] = chisq_sum[idx_chisq_min]
                        row_target = extras_target.indices[idx_match]
                        ref_row_match[index_row] = row_target
                        target_row_match[row_target] = index_row
                        matched_target.add(idx_match)
                    except Exception as error:
                        # Can't foresee any exceptions, but they shouldn't prevent
                        # matching subsequent sources
                        exceptions[index_row] = error

            if logging_n_rows and ((index_n + 1) % logging_n_rows == 0):
                t_elapsed = time.process_time() - t_begin
                logger.info(
                    'Processed %d/%d in %.2fs at sort value=%.3f',
                    index_n + 1, n_indices, t_elapsed, column_order[order[index_n]],
                )

        catalog_out_ref = pd.DataFrame({
            'match_candidate': ref_candidate_match,
            'match_row': ref_row_match,
            'match_count': ref_match_count,
            'match_chisq': ref_chisq,
            'match_n_chisq_finite': ref_match_meas_finite,
        })

        catalog_out_target = pd.DataFrame({
            'match_row': target_row_match,
        })

        return catalog_out_ref, catalog_out_target, exceptions
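

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the module API): match a
    # synthetic target catalog against a synthetic reference catalog. The
    # measurement column names 'flux' and 'flux_err' are hypothetical; the
    # coordinate columns are the config defaults ('ra'/'dec' for the reference
    # table, 'coord_ra'/'coord_dec' for the target table).
    config = MatchProbabilisticConfig()
    config.column_order = 'flux'
    config.columns_ref_meas = ['flux']
    config.columns_target_meas = ['flux']
    config.columns_target_err = ['flux_err']
    # Only one measurement column, so require a single finite chi to match
    config.match_n_finite_min = 1

    rng = np.random.default_rng(1)
    n = 100
    catalog_ref = pd.DataFrame({
        'ra': rng.uniform(0., 0.1, n),
        'dec': rng.uniform(0., 0.1, n),
        'flux': rng.uniform(10., 100., n),
    })
    # The targets are the same sources with small coordinate/flux perturbations
    catalog_target = pd.DataFrame({
        'coord_ra': catalog_ref['ra'] + rng.normal(0., 1e-5, n),
        'coord_dec': catalog_ref['dec'] + rng.normal(0., 1e-5, n),
        'flux': catalog_ref['flux'] + rng.normal(0., 1., n),
        'flux_err': np.full(n, 1.),
    })

    matcher = MatcherProbabilistic(config)
    select_all = np.ones(n, dtype=bool)
    catalog_out_ref, catalog_out_target, exceptions = matcher.match(
        catalog_ref, catalog_target, select_ref=select_all, select_target=select_all,
    )
    # match_row is initialized from NaN cast to int (a negative sentinel),
    # so matched rows are those with match_row >= 0
    print(f'{np.sum(catalog_out_ref["match_row"] >= 0)}/{n} sources matched')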