Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 26%

409 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-10 03:09 -0700

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names exported by this module (what ``from ... import *`` exposes).
__all__ = [
    'DiffMatchedTractCatalogConfig', 'DiffMatchedTractCatalogTask', 'MatchedCatalogFluxesConfig',
    'MatchType', 'MeasurementType', 'SourceType',
    'Statistic', 'Median', 'SigmaIQR', 'SigmaMAD', 'Percentile',
]

27 

28import lsst.afw.geom as afwGeom 

29from lsst.meas.astrom.matcher_probabilistic import ( 

30 ComparableCatalog, ConvertCatalogCoordinatesConfig, 

31) 

32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35import lsst.pipe.base.connectionTypes as cT 

36from lsst.skymap import BaseSkyMap 

37 

38from abc import ABCMeta, abstractmethod 

39from astropy.stats import mad_std 

40import astropy.units as u 

41from dataclasses import dataclass 

42from decimal import Decimal 

43from enum import Enum 

44import numpy as np 

45import pandas as pd 

46from scipy.stats import iqr 

47from typing import Dict, Sequence 

48 

49 

def is_sequence_set(x: Sequence):
    """Return whether no element of ``x`` appears more than once.

    Intended as a ``listCheck`` for config list fields whose entries must be
    unique. Elements must be hashable.
    """
    n_items = len(x)
    n_unique = len(set(x))
    return n_unique == n_items

52 

53 

def is_percentile(x: str):
    """Return whether ``x`` parses as a number in the closed range [0, 100].

    This is used as an ``itemCheck`` for config fields, so it returns False
    rather than raising for input that is not a valid number. Such input
    includes unparseable strings and NaN, whose ordering comparisons raise
    `decimal.InvalidOperation`. Config validation then reports an ordinary
    field error.

    Parameters
    ----------
    x : `str`
        The candidate percentile, e.g. ``'15.866'``.

    Returns
    -------
    valid : `bool`
        True if ``0 <= x <= 100``.
    """
    from decimal import InvalidOperation

    try:
        return 0 <= Decimal(x) <= 100
    except InvalidOperation:
        # Raised both for unparseable strings and for comparisons with NaN
        return False

56 

57 

# Default values for the templated dataset type names used by
# DiffMatchedTractCatalogConnections.
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)

63 

64 

class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Butler connections for DiffMatchedTractCatalogTask.

    The task reads a reference catalog, a target catalog, and the two
    match catalogs linking them, all per tract. It writes a joined matched
    catalog and a table of aggregated difference statistics. Dataset type
    names are built from the ``name_input_cat_ref`` and
    ``name_input_cat_target`` templates.
    """
    # deferLoad=True: only handles are fetched here; runQuantum reads just the
    # configured columns via ``.get(parameters={'columns': ...})``.
    cat_ref = cT.Input(
        doc="Reference object catalog to match from",
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_target = cT.Input(
        doc="Target object catalog to match",
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    # Used to look up the tract WCS for coordinate conversion.
    skymap = cT.Input(
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    cat_match_ref = cT.Input(
        doc="Reference match catalog with indices of target matches",
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_match_target = cT.Input(
        doc="Target match catalog with indices of references matches",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    # Column names only; runQuantum checks whether 'match_candidate' exists
    # before requesting it from cat_match_target.
    columns_match_target = cT.Input(
        doc="Target match catalog columns",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
    )
    cat_matched = cT.Output(
        doc="Catalog with reference and target columns for joined sources",
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )
    diff_matched = cT.Output(
        doc="Table with aggregated counts, difference and chi statistics",
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )

122 

123 

class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column names for one band of a reference/target catalog pair.

    The target may have several flux measurements (models) per band, each
    paired by position with an error column.
    """
    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns needed for this band (ideally an ordered set)."""
        columns = []
        columns.append(self.column_ref_flux)
        return columns

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns needed for this band.

        All flux columns come first, in order, followed by any error columns
        not already listed (ideally an ordered set).
        """
        columns = list(self.columns_target_flux)
        for column_err in self.columns_target_flux_err:
            if column_err not in columns:
                columns.append(column_err)
        return columns

151 

152 

class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for DiffMatchedTractCatalogTask.

    Controls which reference/target columns are read and copied, how target
    sources are selected, and how magnitude-binned difference statistics are
    computed.
    """
    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field[str](
        default='is_pointsource',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field[bool](
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field[str](
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc="Whether to include unmatched rows in the matched table",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns to read, deduplicated in first-seen order.

        These are the reference coordinates, the extendedness column,
        ``columns_ref_copy`` and each band's reference flux column.
        """
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
                       self.column_ref_extended]
        for column_lists in (
            (
                self.columns_ref_copy,
            ),
            (x.columns_in_ref for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(column_list)

        # dict preserves insertion order, so this is an ordered deduplication
        return list({column: None for column in columns_all}.keys())

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns to read, deduplicated in first-seen order.

        These are the target coordinates, the extendedness column, any
        converted coordinate columns, the coordinate error, selection and
        copy columns, and each band's flux and flux error columns.
        """
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
                       self.column_target_extended]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        for column_lists in (
            (
                self.columns_target_coord_err,
                self.columns_target_select_false,
                self.columns_target_select_true,
                self.columns_target_copy,
            ),
            (x.columns_in_target for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    # No default: must be set to exactly two distinct column names
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    # dtype is implied by the Field[float] type parameter
    extendedness_cut = pexConfig.Field[float](
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

    def validate(self):
        """Validate the mag-to-nJy column conversions.

        Raises
        ------
        ValueError
            If a column to convert is not read from its catalog, or if a
            converted column's new name collides with a copied column. All
            problems are reported together.
        """
        super().validate()

        errors = []

        for name_columns_mag, columns_in, name_columns_copy in (
            ("columns_ref_mag_to_nJy", self.columns_in_ref, "columns_ref_copy"),
            ("columns_target_mag_to_nJy", self.columns_in_target, "columns_target_copy"),
        ):
            columns_mag = getattr(self, name_columns_mag)
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.{name_columns_mag} not found in {columns_in=}; did you"
                        f" forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value found in self.{name_columns_copy}={columns_copy};"
                        f" this will cause a collision. Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))

315 

316 

@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a kind of measurement (see `MeasurementType`)."""
    # Human-readable description of the measurement
    doc: str
    # Short identifier for the measurement, e.g. "diff"
    name: str

321 

322 

class MeasurementType(Enum):
    """Kinds of per-source comparison that statistics are computed on."""
    # Raw difference between measured and reference values
    DIFF = MeasurementTypeInfo(
        doc="difference (measured - reference)",
        name="diff",
    )
    # Difference divided by the measurement error
    CHI = MeasurementTypeInfo(
        doc="scaled difference (measured - reference)/error",
        name="chi",
    )

332 

333 

class Statistic(metaclass=ABCMeta):
    """A statistic that can be applied to a set of values.

    Subclasses must override `doc`, `name_short` and `value`. Because all
    three are abstract, a subclass cannot be instantiated until it
    implements every one of them.
    """
    @abstractmethod
    def doc(self) -> str:
        """A description of the statistic"""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """A short name for the statistic, e.g. for a table column name"""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """The value of the statistic for a set of input values.

        Parameters
        ----------
        values : `Collection` [`float`]
            A set of values to compute the statistic for.

        Returns
        -------
        statistic : `float`
            The value of the statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')

362 

363 

class Median(Statistic):
    """The median (50th percentile) of a set of values."""

    @classmethod
    def doc(cls) -> str:
        description = "Median"
        return description

    @classmethod
    def name_short(cls) -> str:
        short = "median"
        return short

    def value(self, values):
        median = np.median(values)
        return median

376 

377 

class SigmaIQR(Statistic):
    """Interquartile range rescaled to be a sigma estimate for normal data."""

    @classmethod
    def doc(cls) -> str:
        description = "Interquartile range divided by ~1.349 (sigma-equivalent)"
        return description

    @classmethod
    def name_short(cls) -> str:
        short = "sig_iqr"
        return short

    def value(self, values):
        # scale='normal' divides the IQR by ~1.349, the IQR of a unit Gaussian
        sigma = iqr(values, scale='normal')
        return sigma

390 

391 

class SigmaMAD(Statistic):
    """Median absolute deviation rescaled to be a sigma estimate for normal data."""

    @classmethod
    def doc(cls) -> str:
        description = "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"
        return description

    @classmethod
    def name_short(cls) -> str:
        short = "sig_mad"
        return short

    def value(self, values):
        # astropy's mad_std applies the ~1.4826 normal-consistency factor
        sigma = mad_std(values)
        return sigma

404 

405 

@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised on construction if ``percentile`` is outside [0, 100] (or NaN).
    """
    percentile: float

    def doc(self) -> str:
        return f"The {self.percentile}th percentile"

    def name_short(self) -> str:
        """Return a column-safe name, e.g. ``pctl50000`` for the 50th percentile.

        The percentile is expressed as a fraction with five decimals and the
        decimal point removed. A leading ``0`` integer part is dropped, so
        names for percentiles below 100 are unchanged, while 100 becomes
        ``pctl100000`` instead of colliding with 0 (``pctl00000``).
        """
        integer, fraction = f"{self.percentile/100:.5f}".split(".")
        return f"pctl{fraction}" if integer == "0" else f"pctl{integer}{fraction}"

    def value(self, values):
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        # Written as a negated conjunction so that NaN is also rejected
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')

429 

430 

431def _get_stat_name(*args): 

432 return '_'.join(args) 

433 

434 

435def _get_column_name(band, *args): 

436 return f"{band}_{_get_stat_name(*args)}" 

437 

438 

def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics of measured-minus-reference values and store them in a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,) or None
        Errors (standard deviations) on ``values_target``. If None, no chi
        statistics are computed.
    row : `numpy.ndarray`, (1, C)
        A numpy array with pre-assigned column names.
    stats : `Dict` [`str`, `Statistic`]
        The statistics to evaluate, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        Column suffixes keyed by measurement type. Types for which no values
        were computed are skipped.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip statistics on differences. The differences are still
        computed, as the numerator of chi.

    Returns
    -------
    row_with_stats : `numpy.ndarray`, (1, C)
        The same ``row`` object, with ``{prefix}_{suffix}_{stat}`` entries
        assigned in place.

    Raises
    ------
    ValueError
        If the input lengths differ. This is only checked when there is at
        least one reference value.
    """
    n_ref = len(values_ref)
    if n_ref == 0:
        return row

    has_errors = errors_target is not None
    n_target = len(values_target)
    n_target_err = len(errors_target) if has_errors else n_ref
    if n_target != n_ref or n_target_err != n_ref:
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')

    diff = values_target - values_ref
    chi = diff/errors_target if has_errors else diff
    # Only finite entries are used; without errors this filters on the diffs themselves
    valid = np.isfinite(chi)

    values_by_type = {}
    if not skip_diff:
        values_by_type[MeasurementType.DIFF] = diff[valid]
    if has_errors:
        values_by_type[MeasurementType.CHI] = chi[valid]

    for measurement_type, suffix in suffixes.items():
        selected = values_by_type.get(measurement_type)
        if selected is None or len(selected) == 0:
            continue
        for stat_name, stat in stats.items():
            row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(selected)
    return row

490 

491 

@dataclass(frozen=True)
class SourceTypeInfo:
    """Immutable description of a class of sources (see `SourceType`)."""
    # Whether the class selects extended (True) or compact (False) sources;
    # None means no selection on extendedness (as used by SourceType.ALL)
    is_extended: bool | None
    # Short label; _get_columns embeds it in output column names
    label: str

496 

497 

class SourceType(Enum):
    """Classes of sources that statistics are split into."""
    # All sources, regardless of extendedness
    ALL = SourceTypeInfo(is_extended=None, label='all')
    # Extended sources only
    RESOLVED = SourceTypeInfo(is_extended=True, label='resolved')
    # Compact (non-extended) sources only
    UNRESOLVED = SourceTypeInfo(is_extended=False, label='unresolved')

502 

503 

class MatchType(Enum):
    """Categories of matches used for count columns (``n_{ref,target}_{value}``).

    NOTE(review): MATCH_RIGHT/MATCH_WRONG presumably count matches whose
    extendedness classification agrees/disagrees with the reference; the code
    that fills these counts is outside this view, so confirm before relying
    on this.
    """
    ALL = 'all'
    MATCH_RIGHT = 'match_right'
    MATCH_WRONG = 'match_wrong'

508 

509 

def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffixes_mag: Dict,
                 stats: Dict, target: ComparableCatalog, column_dist: str):
    """Build the ordered column name -> type mapping of the difference statistics table.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Flux column configuration keyed by band.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Suffixes for each `MeasurementType`, for general columns (e.g.
        coordinates and colors), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Statistics keyed by their column suffix.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Column types keyed by name, in table order.
    n_models : `int`
        The number of flux models measured per band (0 if there are no bands).

    Raises
    ------
    RuntimeError
        If any band has a different number of flux or flux error columns than
        the first band.

    Notes
    -----
    Presently, models must be identical for each band. A color column
    against the previous band is added for every band after the first. It is
    also added for the first band (against the last band) when there are
    more than two bands.
    """
    columns = {"bin": int, "mag_min": float, "mag_max": float}
    bands = list(bands_columns.keys())
    n_bands = len(bands)
    n_models = 0

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)
        if idx == 0:
            n_models = n_models_flux

        # TODO: Do equivalent validation earlier, in the config
        if n_models_flux != n_models or n_models_err != n_models:
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        prefixes_flux = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        if idx > 0 or n_bands > 2:
            # For idx == 0, bands[-1] pairs the first band with the last one
            prefixes_flux.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Counts for ALL would be redundant with the other source types
            if sourcetype is not SourceType.ALL:
                for itype in ('ref', 'target'):
                    for mtype in MatchType:
                        columns[_get_column_name(band, label, f'n_{itype}_{mtype.value}')] = int

            for item in (target.column_coord1, target.column_coord2, column_dist):
                columns.update(
                    (_get_column_name(band, label, item, suffix, stat), float)
                    for suffix in suffixes.values()
                    for stat in stats.keys()
                )

            for item in config_flux.columns_target_flux:
                for prefix_item, suffixes_col in prefixes_flux:
                    columns.update(
                        (_get_column_name(band, label, prefix_item, item, suffix, stat), float)
                        for suffix in suffixes_col.values()
                        for stat in stats.keys()
                    )

    return columns, n_models

589 

590 

591class DiffMatchedTractCatalogTask(pipeBase.PipelineTask): 

592 """Load subsets of matched catalogs and output a merged catalog of matched sources. 

593 """ 

594 ConfigClass = DiffMatchedTractCatalogConfig 

595 _DefaultName = "DiffMatchedTractCatalog" 

596 

597 def runQuantum(self, butlerQC, inputRefs, outputRefs): 

598 inputs = butlerQC.get(inputRefs) 

599 skymap = inputs.pop("skymap") 

600 

601 columns_match_target = ['match_row'] 

602 if 'match_candidate' in inputs['columns_match_target']: 

603 columns_match_target.append('match_candidate') 

604 

605 outputs = self.run( 

606 catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}), 

607 catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}), 

608 catalog_match_ref=inputs['cat_match_ref'].get( 

609 parameters={'columns': ['match_candidate', 'match_row']}, 

610 ), 

611 catalog_match_target=inputs['cat_match_target'].get( 

612 parameters={'columns': columns_match_target}, 

613 ), 

614 wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs, 

615 ) 

616 butlerQC.put(outputs, outputRefs) 

617 

618 def run( 

619 self, 

620 catalog_ref: pd.DataFrame, 

621 catalog_target: pd.DataFrame, 

622 catalog_match_ref: pd.DataFrame, 

623 catalog_match_target: pd.DataFrame, 

624 wcs: afwGeom.SkyWcs = None, 

625 ) -> pipeBase.Struct: 

626 """Load matched reference and target (measured) catalogs, measure summary statistics, and output 

627 a combined matched catalog with columns from both inputs. 

628 

629 Parameters 

630 ---------- 

631 catalog_ref : `pandas.DataFrame` 

632 A reference catalog to diff objects/sources from. 

633 catalog_target : `pandas.DataFrame` 

634 A target catalog to diff reference objects/sources to. 

635 catalog_match_ref : `pandas.DataFrame` 

636 A catalog with match indices of target sources and selection flags 

637 for each reference source. 

638 catalog_match_target : `pandas.DataFrame` 

639 A catalog with selection flags for each target source. 

640 wcs : `lsst.afw.image.SkyWcs` 

641 A coordinate system to convert catalog positions to sky coordinates, 

642 if necessary. 

643 

644 Returns 

645 ------- 

646 retStruct : `lsst.pipe.base.Struct` 

647 A struct with output_ref and output_target attribute containing the 

648 output matched catalogs. 

649 """ 

650 # Would be nice if this could refer directly to ConfigClass 

651 config: DiffMatchedTractCatalogConfig = self.config 

652 

653 select_ref = catalog_match_ref['match_candidate'].values 

654 # Add additional selection criteria for target sources beyond those for matching 

655 # (not recommended, but can be done anyway) 

656 select_target = (catalog_match_target['match_candidate'].values 

657 if 'match_candidate' in catalog_match_target.columns 

658 else np.ones(len(catalog_match_target), dtype=bool)) 

659 for column in config.columns_target_select_true: 

660 select_target &= catalog_target[column].values 

661 for column in config.columns_target_select_false: 

662 select_target &= ~catalog_target[column].values 

663 

664 ref, target = config.coord_format.format_catalogs( 

665 catalog_ref=catalog_ref, catalog_target=catalog_target, 

666 select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy, 

667 return_converted_columns=config.coord_format.coords_ref_to_convert is not None, 

668 ) 

669 cat_ref = ref.catalog 

670 cat_target = target.catalog 

671 n_target = len(cat_target) 

672 

673 if config.include_unmatched: 

674 for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)): 

675 cat_add['match_candidate'] = cat_match['match_candidate'].values 

676 

677 match_row = catalog_match_ref['match_row'].values 

678 matched_ref = match_row >= 0 

679 matched_row = match_row[matched_ref] 

680 matched_target = np.zeros(n_target, dtype=bool) 

681 matched_target[matched_row] = True 

682 

683 # Add/compute distance columns 

684 coord1_target_err, coord2_target_err = config.columns_target_coord_err 

685 column_dist, column_dist_err = 'match_distance', 'match_distanceErr' 

686 dist = np.full(n_target, np.nan) 

687 

688 dist[matched_row] = np.hypot( 

689 target.coord1[matched_row] - ref.coord1[matched_ref], 

690 target.coord2[matched_row] - ref.coord2[matched_ref], 

691 ) 

692 dist_err = np.full(n_target, np.nan) 

693 dist_err[matched_row] = np.hypot(cat_target.iloc[matched_row][coord1_target_err].values, 

694 cat_target.iloc[matched_row][coord2_target_err].values) 

695 cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err 

696 

697 # Create a matched table, preserving the target catalog's named index (if it has one) 

698 cat_left = cat_target.iloc[matched_row] 

699 has_index_left = cat_left.index.name is not None 

700 cat_right = cat_ref[matched_ref].reset_index() 

701 cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns] 

702 cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1) 

703 

704 if config.include_unmatched: 

705 # Create an unmatched table with the same schema as the matched one 

706 # ... but only for objects with no matches (for completeness/purity) 

707 # and that were selected for matching (or inclusion via config) 

708 cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False) 

709 cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns) 

710 match_row_target = catalog_match_target['match_row'].values 

711 cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index( 

712 drop=not has_index_left) 

713 # See https://github.com/pandas-dev/pandas/issues/46662 

714 # astropy masked columns would handle this much more gracefully 

715 # Unfortunately, that would require storageClass migration 

716 # So we use pandas "extended" nullable types for now 

717 for cat_i in (cat_left, cat_right): 

718 for colname in cat_i.columns: 

719 column = cat_i[colname] 

720 dtype = str(column.dtype) 

721 if dtype == "bool": 

722 cat_i[colname] = column.astype("boolean") 

723 elif dtype.startswith("int"): 

724 cat_i[colname] = column.astype(f"Int{dtype[3:]}") 

725 elif dtype.startswith("uint"): 

726 cat_i[colname] = column.astype(f"UInt{dtype[3:]}") 

727 cat_unmatched = pd.concat(objs=(cat_left, cat_right)) 

728 

729 for columns_convert_base, prefix in ( 

730 (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref), 

731 (config.columns_target_mag_to_nJy, ""), 

732 ): 

733 if columns_convert_base: 

734 columns_convert = { 

735 f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items() 

736 } if prefix else columns_convert_base 

737 for cat_convert in (cat_matched, cat_unmatched): 

738 cat_convert.rename(columns=columns_convert, inplace=True) 

739 for column_flux in columns_convert.values(): 

740 cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux]) 

741 

742 # TODO: Deprecate all matched difference output in DM-43831 (per RFC-1008) 

743 

744 # Slightly smelly hack for when a column (like distance) is already relative to truth 

745 column_dummy = 'dummy' 

746 cat_ref[column_dummy] = np.zeros_like(ref.coord1) 

747 

748 # Add a boolean column for whether a match is classified correctly 

749 # TODO: remove the assumption of a boolean column 

750 extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted) 

751 

752 extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut 

753 

754 # Define difference/chi columns and statistics thereof 

755 suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'} 

756 # Skip diff for fluxes - covered by mags 

757 suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]} 

758 # Skip chi for magnitudes, which have strange errors 

759 suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]} 

760 stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)} 

761 

762 for percentile in self.config.percentiles: 

763 stat = Percentile(percentile=float(Decimal(percentile))) 

764 stats[stat.name_short()] = stat 

765 

766 # Get dict of column names 

767 columns, n_models = _get_columns( 

768 bands_columns=config.columns_flux, 

769 suffixes=suffixes, 

770 suffixes_flux=suffixes_flux, 

771 suffixes_mag=suffixes_mag, 

772 stats=stats, 

773 target=target, 

774 column_dist=column_dist, 

775 ) 

776 

777 # Setup numpy table 

778 n_bins = config.mag_num_bins 

779 data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()]) 

780 data['bin'] = np.arange(n_bins) 

781 

782 # Setup bins 

783 bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref, 

784 num=n_bins + 1) 

785 data['mag_min'] = bins_mag[:-1] 

786 data['mag_max'] = bins_mag[1:] 

787 bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins)) 

788 

789 # Define temporary columns for intermediate storage 

790 column_mag_temp = 'mag_temp' 

791 column_color_temp = 'color_temp' 

792 column_color_err_temp = 'colorErr_temp' 

793 flux_err_frac_prev = [None]*n_models 

794 mag_prev = [None]*n_models 

795 

796 columns_target = { 

797 target.column_coord1: ( 

798 ref.column_coord1, target.column_coord1, coord1_target_err, False, 

799 ), 

800 target.column_coord2: ( 

801 ref.column_coord2, target.column_coord2, coord2_target_err, False, 

802 ), 

803 column_dist: (column_dummy, column_dist, column_dist_err, False), 

804 } 

805 

806 # Cheat a little and do the first band last so that the color is 

807 # based on the last band 

808 band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()] 

809 n_bands = len(band_fluxes) 

810 if n_bands > 0: 

811 band_fluxes.append(band_fluxes[0]) 

812 flux_err_frac_first = None 

813 mag_first = None 

814 mag_ref_first = None 

815 

816 band_prev = None 

817 for idx_band, (band, config_flux) in enumerate(band_fluxes): 

818 if idx_band == n_bands: 

819 # These were already computed earlier 

820 mag_ref = mag_ref_first 

821 flux_err_frac = flux_err_frac_first 

822 mag_model = mag_first 

823 else: 

824 mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref 

825 flux_err_frac = [None]*n_models 

826 mag_model = [None]*n_models 

827 

828 if idx_band > 0: 

829 cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref 

830 

831 cat_ref[column_mag_temp] = mag_ref 

832 

833 select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi) 

834 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag)] 

835 

836 # Iterate over multiple models, compute their mags and colours (if there's a previous band) 

837 for idx_model in range(n_models): 

838 column_target_flux = config_flux.columns_target_flux[idx_model] 

839 column_target_flux_err = config_flux.columns_target_flux_err[idx_model] 

840 

841 flux_target = cat_target[column_target_flux] 

842 mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target 

843 if config.mag_ceiling_target is not None: 

844 mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target 

845 mag_model[idx_model] = mag_target 

846 

847 # These are needed for computing magnitude/color "errors" (which are a sketchy concept) 

848 flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target 

849 

850 # Stop if idx == 0: The rest will be picked up at idx == n_bins 

851 if idx_band > 0: 

852 # Keep these mags tabulated for convenience 

853 column_mag_temp_model = f'{column_mag_temp}{idx_model}' 

854 cat_target[column_mag_temp_model] = mag_target 

855 

856 columns_target[f'flux_{column_target_flux}'] = ( 

857 config_flux.column_ref_flux, 

858 column_target_flux, 

859 column_target_flux_err, 

860 True, 

861 ) 

862 # Note: magnitude errors are generally problematic and not worth aggregating 

863 columns_target[f'mag_{column_target_flux}'] = ( 

864 column_mag_temp, column_mag_temp_model, None, False, 

865 ) 

866 

867 # No need for colors if this is the last band and there are only two bands 

868 # (because it would just be the negative of the first color) 

869 skip_color = (idx_band == n_bands) and (n_bands <= 2) 

870 if not skip_color: 

871 column_color_temp_model = f'{column_color_temp}{idx_model}' 

872 column_color_err_temp_model = f'{column_color_err_temp}{idx_model}' 

873 

874 # e.g. if order is ugrizy, first color will be u - g 

875 cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model] 

876 

877 # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors 

878 cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot( 

879 flux_err_frac[idx_model], flux_err_frac_prev[idx_model]) 

880 columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = ( 

881 column_color_temp, 

882 column_color_temp_model, 

883 column_color_err_temp_model, 

884 False, 

885 ) 

886 

887 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag): 

888 row = data[idx_bin] 

889 # Reference sources only need to be counted once 

890 if idx_model == 0: 

891 select_ref_bin = select_ref_bins[idx_bin] 

892 select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi) 

893 

894 for sourcetype in SourceType: 

895 sourcetype_info = sourcetype.value 

896 is_extended = sourcetype_info.is_extended 

897 # Counts filtered by match selection and magnitude bin 

898 select_ref_sub = select_ref_bin.copy() 

899 select_target_sub = select_target_bin.copy() 

900 if is_extended is not None: 

901 is_extended_ref = (extended_ref == is_extended) 

902 select_ref_sub &= is_extended_ref 

903 if idx_model == 0: 

904 n_ref_sub = np.count_nonzero(select_ref_sub) 

905 row[_get_column_name(band, sourcetype_info.label, 'n_ref', 

906 MatchType.ALL.value)] = n_ref_sub 

907 select_target_sub &= (extended_target == is_extended) 

908 n_target_sub = np.count_nonzero(select_target_sub) 

909 row[_get_column_name(band, sourcetype_info.label, 'n_target', 

910 MatchType.ALL.value)] = n_target_sub 

911 

912 # Filter matches by magnitude bin and true class 

913 match_row_bin = match_row.copy() 

914 match_row_bin[~select_ref_sub] = -1 

915 match_good = match_row_bin >= 0 

916 

917 n_match = np.count_nonzero(match_good) 

918 

919 # Same for counts of matched target sources (for e.g. purity) 

920 

921 if n_match > 0: 

922 rows_matched = match_row_bin[match_good] 

923 subset_target = cat_target.iloc[rows_matched] 

924 if (is_extended is not None) and (idx_model == 0): 

925 right_type = extended_target[rows_matched] == is_extended 

926 n_total = len(right_type) 

927 n_right = np.count_nonzero(right_type) 

928 row[_get_column_name(band, sourcetype_info.label, 'n_ref', 

929 MatchType.MATCH_RIGHT.value)] = n_right 

930 row[_get_column_name( 

931 band, 

932 sourcetype_info.label, 

933 'n_ref', 

934 MatchType.MATCH_WRONG.value, 

935 )] = n_total - n_right 

936 

937 # compute stats for this bin, for all columns 

938 for column, (column_ref, column_target, column_err_target, skip_diff) \ 

939 in columns_target.items(): 

940 values_ref = cat_ref[column_ref][match_good].values 

941 errors_target = ( 

942 subset_target[column_err_target].values 

943 if column_err_target is not None 

944 else None 

945 ) 

946 compute_stats( 

947 values_ref, 

948 subset_target[column_target].values, 

949 errors_target, 

950 row, 

951 stats, 

952 suffixes, 

953 prefix=f'{band}_{sourcetype_info.label}_{column}', 

954 skip_diff=skip_diff, 

955 ) 

956 

957 # Count matched target sources with *measured* mags within bin 

958 # Used for e.g. purity calculation 

959 # Should be merged with above code if there's ever a need for 

960 # measuring stats on this source selection 

961 select_target_sub &= matched_target 

962 

963 if is_extended is not None and (np.count_nonzero(select_target_sub) > 0): 

964 n_total = np.count_nonzero(select_target_sub) 

965 right_type = np.zeros(n_target, dtype=bool) 

966 right_type[match_row[matched_ref & is_extended_ref]] = True 

967 right_type &= select_target_sub 

968 n_right = np.count_nonzero(right_type) 

969 row[_get_column_name(band, sourcetype_info.label, 'n_target', 

970 MatchType.MATCH_RIGHT.value)] = n_right 

971 row[_get_column_name(band, sourcetype_info.label, 'n_target', 

972 MatchType.MATCH_WRONG.value)] = n_total - n_right 

973 

974 # delete the flux/color columns since they change with each band 

975 for prefix in ('flux', 'mag'): 

976 del columns_target[f'{prefix}_{column_target_flux}'] 

977 if not skip_color: 

978 del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] 

979 

980 # keep values needed for colors 

981 flux_err_frac_prev = flux_err_frac 

982 mag_prev = mag_model 

983 band_prev = band 

984 if idx_band == 0: 

985 flux_err_frac_first = flux_err_frac 

986 mag_first = mag_model 

987 mag_ref_first = mag_ref 

988 

989 if config.include_unmatched: 

990 cat_matched = pd.concat((cat_matched, cat_unmatched)) 

991 

992 retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data)) 

993 return retStruct