Coverage for python / lsst / pipe / tasks / diff_matched_tract_catalog.py: 25%

185 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:17 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = [
    'DiffMatchedTractCatalogConfig', 'DiffMatchedTractCatalogTask', 'MatchedCatalogFluxesConfig',
]

25 

26import lsst.afw.geom as afwGeom 

27from lsst.meas.astrom.matcher_probabilistic import ConvertCatalogCoordinatesConfig 

28from lsst.meas.astrom.match_probabilistic_task import radec_to_xy 

29import lsst.pex.config as pexConfig 

30import lsst.pipe.base as pipeBase 

31import lsst.pipe.base.connectionTypes as cT 

32from lsst.skymap import BaseSkyMap 

33from lsst.daf.butler import DatasetProvenance 

34 

35import astropy.table 

36import astropy.units as u 

37import numpy as np 

38from smatch.matcher import sphdist 

39from typing import Sequence 

40 

41 

def is_sequence_set(x: Sequence):
    """Return True if the sequence ``x`` contains no duplicate elements.

    Used as a ``listCheck`` for config list fields; elements must be hashable.
    """
    return len(set(x)) == len(x)

44 

45 

# Default dataset type name templates shared by the connections class below.
DiffMatchedTractCatalogBaseTemplates = {
    "name_input_cat_ref": "truth_summary",
    "name_input_cat_target": "objectTable_tract",
    "name_skymap": BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
}

51 

52 

class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Connections for DiffMatchedTractCatalogTask: the reference and target
    catalogs, their precomputed match tables, the skymap, and the merged
    matched-catalog output.
    """
    cat_ref = cT.Input(
        doc="Reference object catalog to match from",
        name="{name_input_cat_ref}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_target = cT.Input(
        doc="Target object catalog to match",
        name="{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    skymap = cT.Input(
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    cat_match_ref = cT.Input(
        doc="Reference match catalog with indices of target matches",
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_match_target = cT.Input(
        doc="Target match catalog with indices of references matches",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    columns_match_target = cT.Input(
        doc="Target match catalog columns",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="ArrowColumnList",
        dimensions=("tract", "skymap"),
    )
    cat_matched = cT.Output(
        doc="Catalog with reference and target columns for joined sources",
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "skymap"),
    )

    def __init__(self, *, config=None):
        # An unsharded ("none") catalog carries no tract/skymap dimensions, so
        # the corresponding input connection must be rebuilt with empty
        # dimensions for the butler to resolve it; everything else is copied
        # from the class-level default. Any other non-tract sharding scheme is
        # unsupported.
        if config.refcat_sharding_type != "tract":
            if config.refcat_sharding_type == "none":
                old = self.cat_ref
                self.cat_ref = cT.Input(
                    doc=old.doc,
                    name=old.name,
                    storageClass=old.storageClass,
                    dimensions=(),
                    deferLoad=old.deferLoad,
                )
            else:
                raise NotImplementedError(f"{config.refcat_sharding_type=} not implemented")
        if config.target_sharding_type != "tract":
            if config.target_sharding_type == "none":
                old = self.cat_target
                self.cat_target = cT.Input(
                    doc=old.doc,
                    name=old.name,
                    storageClass=old.storageClass,
                    dimensions=(),
                    deferLoad=old.deferLoad,
                )
            else:
                raise NotImplementedError(f"{config.target_sharding_type=} not implemented")

130 

131 

class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Configuration of the flux columns to carry through for one band of a
    matched catalog: one reference flux column and the corresponding target
    flux (and flux-error) columns.
    """
    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    # this should be an orderedset
    @property
    def columns_in_ref(self) -> list[str]:
        """The reference catalog columns needed by this config."""
        return [self.column_ref_flux]

    # this should also be an orderedset
    @property
    def columns_in_target(self) -> list[str]:
        """The target catalog columns needed by this config, de-duplicated
        while preserving order (flux columns first, then any error columns
        not already listed).
        """
        columns = list(self.columns_target_flux)
        columns.extend(col for col in self.columns_target_flux_err if col not in columns)
        return columns

159 

160 

class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for DiffMatchedTractCatalogTask, controlling which
    columns are read from, copied from, and/or converted in the reference
    and target catalogs when building the merged matched catalog.
    """
    column_match_candidate_ref = pexConfig.Field[str](
        default='match_candidate',
        doc='The column name for the boolean field identifying reference objects'
            ' that were used for matching',
        optional=True,
    )
    column_match_candidate_target = pexConfig.Field[str](
        default='match_candidate',
        doc='The column name for the boolean field identifying target objects'
            ' that were used for matching',
        optional=True,
    )
    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_matched_prefix_target = pexConfig.Field[str](
        default='',
        doc='The prefix for matched columns copied from the target catalog',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc='Whether to include unmatched rows in the matched table',
    )
    filter_on_match_candidate = pexConfig.Field[bool](
        default=False,
        doc='Whether to use provided column_match_candidate_[ref/target] to'
            ' exclude rows from the output table. If False, any provided'
            ' columns will be copied instead.'
    )
    prefix_best_coord = pexConfig.Field[str](
        default=None,
        doc="A string prefix for ra/dec coordinate columns generated from the reference coordinate if "
            "available, and target otherwise. Ignored if None or include_unmatched is False.",
        optional=True,
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """All reference catalog columns needed by this config, de-duplicated
        while preserving first-seen order.
        """
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2]
        for column_lists in (
            (
                self.columns_ref_copy,
            ),
            (x.columns_in_ref for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(column_list)

        # dict keys preserve insertion order, giving an ordered de-duplication.
        return list({column: None for column in columns_all}.keys())

    @property
    def columns_in_target(self) -> list[str]:
        """All target catalog columns needed by this config, de-duplicated
        while preserving first-seen order.
        """
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2]
        # Converted reference coordinates are written into target-named columns.
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        for column_lists in (
            (
                self.columns_target_coord_err,
                self.columns_target_select_false,
                self.columns_target_select_true,
                self.columns_target_copy,
            ),
            (x.columns_in_target for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        # Exactly two distinct columns: one error per coordinate axis.
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    refcat_sharding_type = pexConfig.ChoiceField[str](
        doc="The type of sharding (spatial splitting) for the reference catalog",
        allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
        default="tract",
    )
    target_sharding_type = pexConfig.ChoiceField[str](
        doc="The type of sharding (spatial splitting) for the target catalog",
        allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
        default="tract",
    )

    def validate(self):
        """Validate that the mag-to-nJy conversion mappings are consistent
        with the configured copy columns, accumulating all problems before
        raising a single ValueError.
        """
        super().validate()

        errors = []

        for columns_mag, columns_in, name_columns_copy in (
            (self.columns_ref_mag_to_nJy, self.columns_in_ref, "columns_ref_copy"),
            (self.columns_target_mag_to_nJy, self.columns_in_target, "columns_target_copy"),
        ):
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                # The source mag column must actually be read in...
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.columns_mag_to_nJy not found in {columns_in=}; did you"
                        f" forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                # ...and the new flux column name must not collide with a copied column.
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value found in self.{name_columns_copy}={columns_copy}"
                        f" this will cause a collision. Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))

310 

311 

class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
    """Load subsets of matched catalogs and output a merged catalog of matched sources.
    """
    ConfigClass = DiffMatchedTractCatalogConfig
    _DefaultName = "DiffMatchedTractCatalog"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        skymap = inputs.pop("skymap")

        # Only load the match-candidate flag columns that are configured; for
        # the target, also require that the column exists in the match table.
        columns_match_ref = ['match_row']
        if (column := self.config.column_match_candidate_ref) is not None:
            columns_match_ref.append(column)

        columns_match_target = ['match_row']
        if (column := self.config.column_match_candidate_target) is not None and (
            column in inputs['columns_match_target']
        ):
            columns_match_target.append(column)

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
            catalog_match_ref=inputs['cat_match_ref'].get(parameters={'columns': columns_match_ref}),
            catalog_match_target=inputs['cat_match_target'].get(parameters={'columns': columns_match_target}),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: astropy.table.Table,
        catalog_target: astropy.table.Table,
        catalog_match_ref: astropy.table.Table,
        catalog_match_target: astropy.table.Table,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Load matched reference and target (measured) catalogs, measure summary statistics, and output
        a combined matched catalog with columns from both inputs.

        Parameters
        ----------
        catalog_ref : `astropy.table.Table`
            A reference catalog to diff objects/sources from.
        catalog_target : `astropy.table.Table`
            A target catalog to diff reference objects/sources to.
        catalog_match_ref : `astropy.table.Table`
            A catalog with match indices of target sources and selection flags
            for each reference source.
        catalog_match_target : `astropy.table.Table`
            A catalog with selection flags for each target source.
        wcs : `lsst.afw.geom.SkyWcs`
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_matched attribute containing the output
            matched catalog.
        """
        # Would be nice if this could refer directly to ConfigClass
        config: DiffMatchedTractCatalogConfig = self.config

        # Strip any provenance from tables before merging to prevent
        # warnings from conflicts being issued by astropy.utils.merge during
        # vstack or hstack calls.
        DatasetProvenance.strip_provenance_from_flat_dict(catalog_ref.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(catalog_target.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(catalog_match_ref.meta)
        DatasetProvenance.strip_provenance_from_flat_dict(catalog_match_target.meta)

        # It would be nice to make this a Selector but those are
        # only available in analysis_tools for now
        select_ref, select_target = (
            (catalog[column] if column else np.ones(len(catalog), dtype=bool))
            for catalog, column in (
                (catalog_match_ref, self.config.column_match_candidate_ref),
                (catalog_match_target, self.config.column_match_candidate_target),
            )
        )
        # Add additional selection criteria for target sources beyond those for matching
        # (not recommended, but can be done anyway)
        # NOTE(review): when a candidate column is configured, select_target is
        # a view of that match-table column, so &= modifies catalog_match_target
        # in place — confirm callers do not reuse the match tables afterwards.
        for column in config.columns_target_select_true:
            select_target &= catalog_target[column]
        for column in config.columns_target_select_false:
            select_target &= ~catalog_target[column]

        ref, target = config.coord_format.format_catalogs(
            catalog_ref=catalog_ref, catalog_target=catalog_target,
            select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
        )
        cat_ref = ref.catalog
        cat_target = target.catalog
        n_target = len(cat_target)

        # If not filtering on them, carry the candidate flags through instead.
        if not config.filter_on_match_candidate:
            for cat_add, cat_match, column in (
                (cat_ref, catalog_match_ref, config.column_match_candidate_ref),
                (cat_target, catalog_match_target, config.column_match_candidate_target),
            ):
                if column is not None:
                    cat_add[column] = cat_match[column]

        # match_row >= 0 marks a matched reference row; the value is the index
        # of its matched row in the target catalog.
        match_row = catalog_match_ref['match_row']
        matched_ref = match_row >= 0
        matched_row = match_row[matched_ref]
        matched_target = np.zeros(n_target, dtype=bool)
        matched_target[matched_row] = True

        # Add/compute distance columns
        coord1_target_err, coord2_target_err = config.columns_target_coord_err
        column_dist, column_dist_err = 'match_distance', 'match_distanceErr'
        dist = np.full(n_target, np.nan)

        target_match_c1, target_match_c2 = (coord[matched_row] for coord in (target.coord1, target.coord2))
        target_ref_c1, target_ref_c2 = (coord[matched_ref] for coord in (ref.coord1, ref.coord2))

        dist_err = np.full(n_target, np.nan)
        # Great-circle separation for spherical coordinates; Euclidean otherwise.
        dist[matched_row] = sphdist(
            target_match_c1, target_match_c2, target_ref_c1, target_ref_c2
        ) if config.coord_format.coords_spherical else np.hypot(
            target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2,
        )
        cat_target_matched = cat_target[matched_row]
        # This will convert a masked array to an array filled with nans
        # wherever there are bad values (otherwise sphdist can raise)
        c1_err, c2_err = (
            np.ma.getdata(cat_target_matched[c_err]) for c_err in (coord1_target_err, coord2_target_err)
        )
        # Should probably explicitly add cosine terms if ref has errors too
        dist_err[matched_row] = sphdist(
            target_match_c1, target_match_c2, target_match_c1 + c1_err, target_match_c2 + c2_err
        ) if config.coord_format.coords_spherical else np.hypot(c1_err, c2_err)
        cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

        # Create a matched table, preserving the target catalog's named index (if it has one)
        cat_left = cat_target[matched_row]
        cat_right = cat_ref[matched_ref]
        if config.column_matched_prefix_target:
            cat_left.rename_columns(
                list(cat_left.columns),
                new_names=[f'{config.column_matched_prefix_target}{col}' for col in cat_left.columns],
            )
        if config.column_matched_prefix_ref:
            cat_right.rename_columns(
                list(cat_right.columns),
                new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
            )
        cat_matched = astropy.table.hstack((cat_left, cat_right))

        if config.include_unmatched:
            # Create an unmatched table with the same schema as the matched one
            # ... but only for objects with no matches (for completeness/purity)
            # and that were selected for matching (or inclusion via config)
            cat_right = astropy.table.Table(
                cat_ref[~matched_ref & select_ref]
            )
            cat_right.rename_columns(
                cat_right.colnames,
                [f"{config.column_matched_prefix_ref}{col}" for col in cat_right.colnames],
            )
            match_row_target = catalog_match_target['match_row']
            cat_left = cat_target[~(match_row_target >= 0) & select_target]
            cat_left.rename_columns(
                cat_left.colnames,
                [f"{config.column_matched_prefix_target}{col}" for col in cat_left.colnames],
            )
            # This may be slower than pandas but will, for example, create
            # masked columns for booleans, which pandas does not support.
            # See https://github.com/pandas-dev/pandas/issues/46662
            cat_unmatched = astropy.table.vstack([cat_left, cat_right])

        # Convert configured AB mag columns to nJy fluxes, renaming them.
        # NOTE(review): target columns use an empty prefix here even though
        # column_matched_prefix_target may be non-empty — confirm the intended
        # behavior when a target prefix is configured.
        for columns_convert_base, prefix in (
            (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
            (config.columns_target_mag_to_nJy, ""),
        ):
            if columns_convert_base:
                columns_convert = {
                    f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
                } if prefix else columns_convert_base
                to_convert = [cat_matched]
                if config.include_unmatched:
                    to_convert.append(cat_unmatched)
                for cat_convert in to_convert:
                    cat_convert.rename_columns(
                        tuple(columns_convert.keys()),
                        tuple(columns_convert.values()),
                    )
                    for column_flux in columns_convert.values():
                        cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])

        if config.include_unmatched:
            # This is probably less efficient than just doing an outer join originally; worth checking
            cat_matched = astropy.table.vstack([cat_matched, cat_unmatched])
            if (prefix_coord := config.prefix_best_coord) is not None:
                columns_coord_best = (
                    f"{prefix_coord}{col_coord}" for col_coord in (
                        ("ra", "dec") if config.coord_format.coords_spherical else ("coord1", "coord2")
                    )
                )
                for column_coord_best, column_coord_ref, column_coord_target in zip(
                    columns_coord_best,
                    (config.coord_format.column_ref_coord1, config.coord_format.column_ref_coord2),
                    (config.coord_format.column_target_coord1, config.coord_format.column_target_coord2),
                ):
                    column_full_ref = f'{config.column_matched_prefix_ref}{column_coord_ref}'
                    column_full_target = f'{config.column_matched_prefix_target}{column_coord_target}'
                    values = cat_matched[column_full_ref]
                    unit = values.unit
                    values_bad = np.ma.masked_invalid(values).mask
                    # Cast to an unmasked array - there will be no bad values
                    values = np.array(values)
                    # Fall back to the target coordinate where the reference
                    # value is missing/invalid.
                    values[values_bad] = cat_matched[column_full_target][values_bad]
                    cat_matched[column_coord_best] = values
                    cat_matched[column_coord_best].unit = unit
                    # Fixed: interpolate the per-column name, not the (already
                    # consumed) columns_coord_best generator object.
                    cat_matched[column_coord_best].description = (
                        f"Best {column_coord_best} value from {column_full_ref} if available"
                        f" else {column_full_target}"
                    )

        retStruct = pipeBase.Struct(cat_matched=cat_matched)
        return retStruct