Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 26%

412 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:06 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names exported by ``from <module> import *``.
__all__ = [
    # Task and configuration classes
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    # Enumerations
    'MatchType',
    'MeasurementType',
    'SourceType',
    # Statistics
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]

27 

28import lsst.afw.geom as afwGeom 

29from lsst.meas.astrom.matcher_probabilistic import ( 

30 ComparableCatalog, ConvertCatalogCoordinatesConfig, 

31) 

32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35import lsst.pipe.base.connectionTypes as cT 

36from lsst.skymap import BaseSkyMap 

37 

38from abc import ABCMeta, abstractmethod 

39from astropy.stats import mad_std 

40import astropy.units as u 

41from dataclasses import dataclass 

42from decimal import Decimal 

43from enum import Enum 

44import numpy as np 

45import pandas as pd 

46from scipy.stats import iqr 

47from smatch.matcher import sphdist 

48from typing import Sequence 

49 

50 

def is_sequence_set(x: Sequence):
    """Return whether every element of ``x`` is distinct.

    Parameters
    ----------
    x : `Sequence`
        The sequence to check; its elements must be hashable.

    Returns
    -------
    is_set : `bool`
        True if ``x`` has no repeated elements (an empty sequence counts).
    """
    n_unique = len(set(x))
    return n_unique == len(x)

53 

54 

def is_percentile(x: str):
    """Return whether ``x`` parses to a number within [0, 100].

    Parameters
    ----------
    x : `str`
        A decimal number in string form, as accepted by `decimal.Decimal`.

    Returns
    -------
    valid : `bool`
        True if ``0 <= Decimal(x) <= 100``.
    """
    value = Decimal(x)
    return (value >= 0) and (value <= 100)

57 

58 

# Default values for the name templates used by the connections below.
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)

64 

65 

class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Connections for diffing a matched reference and target catalog.

    All catalogs are per-tract DataFrames; dataset names are built from the
    ``name_input_cat_ref`` and ``name_input_cat_target`` templates.
    """

    # --- Inputs: the two catalogs being compared ---
    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference object catalog to match from",
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target object catalog to match",
    )
    skymap = cT.Input(
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
    )

    # --- Inputs: match tables ---
    cat_match_ref = cT.Input(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference match catalog with indices of target matches",
    )
    cat_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target match catalog with indices of references matches",
    )
    # Only the column index is loaded here, so the task can check which
    # columns exist before reading the (deferred) target match catalog.
    columns_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
        doc="Target match catalog columns",
    )

    # --- Outputs ---
    cat_matched = cT.Output(
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Catalog with reference and target columns for joined sources",
    )
    diff_matched = cT.Output(
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Table with aggregated counts, difference and chi statistics",
    )

123 

124 

class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column names for one band of a matched reference/target pair."""

    column_ref_flux = pexConfig.Field(
        doc='Reference catalog flux column name',
        dtype=str,
    )
    columns_target_flux = pexConfig.ListField(
        doc="List of target catalog flux column names",
        dtype=str,
        listCheck=is_sequence_set,
    )
    columns_target_flux_err = pexConfig.ListField(
        doc="List of target catalog flux error column names",
        dtype=str,
        listCheck=is_sequence_set,
    )

    # this should be an orderedset
    @property
    def columns_in_ref(self) -> list[str]:
        """The reference catalog columns this config reads."""
        column = self.column_ref_flux
        return [column]

    # this should also be an orderedset
    @property
    def columns_in_target(self) -> list[str]:
        """The target catalog columns this config reads.

        Flux columns come first, followed by any flux error columns that
        are not already listed.
        """
        columns = list(self.columns_target_flux)
        for column_err in self.columns_target_flux_err:
            if column_err not in columns:
                columns.append(column_err)
        return columns

152 

153 

class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for diffing matched reference and target catalogs.

    Besides the fields, this class exposes `columns_in_ref` and
    `columns_in_target`, the de-duplicated lists of columns that need to be
    read from each input catalog, and validates that the mag-to-nJy
    conversion settings are consistent with those lists.
    """

    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field[str](
        default='is_pointsource',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field[bool](
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field[str](
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc="Whether to include unmatched rows in the matched table",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns to load, in first-seen order.

        The list holds the two coordinate columns, the extendedness column,
        the columns to copy and each band's reference flux column, with
        duplicates removed.
        """
        columns_all = [
            self.coord_format.column_ref_coord1,
            self.coord_format.column_ref_coord2,
            self.column_ref_extended,
        ]
        columns_all.extend(self.columns_ref_copy)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_ref)

        # dict keys keep insertion order, so this de-duplicates stably
        return list(dict.fromkeys(columns_all))

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns to load, in first-seen order.

        The list holds the two coordinate columns, the extendedness column,
        any columns named by ``coord_format.coords_ref_to_convert``, the
        coordinate error, selection and copy columns, and each band's flux
        and flux error columns, with duplicates removed.
        """
        columns_all = [
            self.coord_format.column_target_coord1,
            self.coord_format.column_target_coord2,
            self.column_target_extended,
        ]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        column_lists = [
            self.columns_target_coord_err,
            self.columns_target_select_false,
            self.columns_target_select_true,
            self.columns_target_copy,
        ]
        column_lists.extend(config_flux.columns_in_target for config_flux in self.columns_flux.values())
        for column_list in column_lists:
            # The membership test runs lazily while the list is extended,
            # so repeats within a single column_list are also dropped.
            columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        # Exactly two distinct column names are required
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field[float](
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

    def validate(self):
        """Validate the config, including the mag-to-nJy conversions.

        Raises
        ------
        ValueError
            Raised if a key of ``columns_ref_mag_to_nJy`` or
            ``columns_target_mag_to_nJy`` is not among the columns to be
            loaded, or if a new (value) column name is already in the
            corresponding list of columns to copy. All problems found are
            reported together, one per line.
        """
        super().validate()

        errors = []

        for name_columns_mag, columns_in, name_columns_copy in (
            ("columns_ref_mag_to_nJy", self.columns_in_ref, "columns_ref_copy"),
            ("columns_target_mag_to_nJy", self.columns_in_target, "columns_target_copy"),
        ):
            columns_mag = getattr(self, name_columns_mag)
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.{name_columns_mag} not found in {columns_in=}; did you"
                        f" forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value found in self.{name_columns_copy}={columns_copy};"
                        f" this will cause a collision. Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))

316 

317 

@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a kind of measurement."""

    # Human-readable description of the measurement
    doc: str
    # Short identifier for the measurement
    name: str

322 

323 

class MeasurementType(Enum):
    """The kinds of measurement that statistics are computed for."""

    # Plain difference between measured and reference values
    DIFF = MeasurementTypeInfo(name="diff", doc="difference (measured - reference)")
    # Difference divided by the measurement error
    CHI = MeasurementTypeInfo(name="chi", doc="scaled difference (measured - reference)/error")

333 

334 

class Statistic(metaclass=ABCMeta):
    """Abstract interface for a statistic computed over a set of values.

    Subclasses must provide `doc`, `name_short` and `value`.
    """

    @abstractmethod
    def doc(self) -> str:
        """Return a description of the statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short name for the statistic.

        The name should be suitable for use in e.g. a table column name.
        """
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """Compute the statistic for a set of input values.

        Parameters
        ----------
        values : `Collection` [`float`]
            The values to compute the statistic for.

        Returns
        -------
        statistic : `float`
            The value of the statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')

363 

364 

class Median(Statistic):
    """Statistic returning the median of the input values."""

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        description = "Median"
        return description

    @classmethod
    def name_short(cls) -> str:
        """Return the short (column suffix) name of this statistic."""
        short = "median"
        return short

    def value(self, values):
        """Return ``numpy.median`` of ``values``."""
        return np.median(values)

377 

378 

class SigmaIQR(Statistic):
    """Statistic returning the interquartile range, re-scaled to be
    sigma-equivalent.
    """

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        description = "Interquartile range divided by ~1.349 (sigma-equivalent)"
        return description

    @classmethod
    def name_short(cls) -> str:
        """Return the short (column suffix) name of this statistic."""
        short = "sig_iqr"
        return short

    def value(self, values):
        """Return ``scipy.stats.iqr`` of ``values`` with ``scale='normal'``."""
        return iqr(values, scale='normal')

391 

392 

class SigmaMAD(Statistic):
    """Statistic returning the median absolute deviation, re-scaled to be
    sigma-equivalent.
    """

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        description = "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"
        return description

    @classmethod
    def name_short(cls) -> str:
        """Return the short (column suffix) name of this statistic."""
        short = "sig_mad"
        return short

    def value(self, values):
        """Return ``astropy.stats.mad_std`` of ``values``."""
        return mad_std(values)

405 

406 

@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised on construction if ``percentile`` is not within [0, 100].
    """
    percentile: float

    def doc(self) -> str:
        """Return a description naming the configured percentile."""
        return f"{self.percentile}th percentile"

    def name_short(self) -> str:
        """Return a short name built from the percentile's digits.

        The percentile is divided by 100, formatted with five decimals and
        stripped of its leading two characters, e.g. 2.275 -> "pctl02275".
        """
        # NOTE: 100 formats as "1.00000", so it yields the same name as 0
        return f"pctl{f'{self.percentile/100:.5f}'[2:]}"

    def value(self, values):
        """Return ``numpy.percentile`` of ``values`` at this percentile."""
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        # A NaN percentile fails this comparison and is rejected too
        if not (0 <= self.percentile <= 100):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')

430 

431 

432def _get_stat_name(*args): 

433 return '_'.join(args) 

434 

435 

436def _get_column_name(band, *args): 

437 return f"{band}_{_get_stat_name(*args)}" 

438 

439 

def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Measure statistics of target-minus-reference differences into a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,)
        Errors (standard deviations) on `values_target`. If None, no chi
        statistics are computed.
    row : `numpy.ndarray`, (1, C)
        A numpy array with pre-assigned column names.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement type.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : `numpy.ndarray`, (1, C)
        The original `row` with statistic values assigned.

    Raises
    ------
    ValueError
        Raised if the inputs are non-empty and their lengths differ.
    """
    n_ref = len(values_ref)
    # Nothing to measure for an empty reference array
    if n_ref == 0:
        return row

    has_errors = errors_target is not None
    n_target = len(values_target)
    n_target_err = len(errors_target) if has_errors else n_ref
    if not ((n_target == n_ref) and (n_target_err == n_ref)):
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')

    diff = values_target - values_ref
    chi = diff/errors_target if has_errors else diff
    # Could make this configurable, but non-finite values/errors are not really usable
    valid = np.isfinite(chi)

    values_type = {}
    if not skip_diff:
        values_type[MeasurementType.DIFF] = diff[valid]
    if has_errors:
        values_type[MeasurementType.CHI] = chi[valid]

    for suffix_type, suffix in suffixes.items():
        values = values_type.get(suffix_type)
        # Leave the row untouched for skipped types or if nothing is valid
        if values is None or len(values) == 0:
            continue
        for stat_name, stat in stats.items():
            row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row

491 

492 

493@dataclass(frozen=True) 

494class SourceTypeInfo: 

495 is_extended: bool | None 

496 label: str 

497 

498 

class SourceType(Enum):
    """The classes of sources that statistics are tabulated for."""

    # All sources, regardless of extendedness
    ALL = SourceTypeInfo(label='all', is_extended=None)
    # Extended sources
    RESOLVED = SourceTypeInfo(label='resolved', is_extended=True)
    # Sources that are not extended
    UNRESOLVED = SourceTypeInfo(label='unresolved', is_extended=False)

503 

504 

class MatchType(Enum):
    """Match categories, whose values are used in count column names."""

    # Every object, matched or not
    ALL = 'all'
    # NOTE(review): presumably matches whose classification agrees with
    # the reference - confirm against the counting code
    MATCH_RIGHT = 'match_right'
    # NOTE(review): presumably matches whose classification disagrees
    MATCH_WRONG = 'match_wrong'

509 

510 

def _get_columns(bands_columns: dict, suffixes: dict, suffixes_flux: dict, suffixes_mag: dict,
                 stats: dict, target: ComparableCatalog, column_dist: str):
    """Build the column names and types for a table of difference statistics.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType` type, for general columns (e.g.
        coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of suffixes for each `Statistic` type.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name.
    n_models : `int`
        The number of models measurements will be made for.

    Raises
    ------
    RuntimeError
        Raised if any band's number of flux or flux error columns differs
        from the number of flux columns in the first band.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # Bin definition columns come first; statistics are appended per band
    columns = {"bin": int, "mag_min": float, "mag_max": float}

    band_names = list(bands_columns)
    n_bands = len(band_names)
    n_models = 0

    # These names do not depend on the band or source type
    items_count = [
        f'n_{itype}_{mtype.value}'
        for itype in ('ref', 'target')
        for mtype in MatchType
    ]
    items_general = (target.column_coord1, target.column_coord2, column_dist)
    names_stat = list(stats)

    for idx, band in enumerate(band_names):
        config_flux = bands_columns[band]
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)
        # The first band defines the expected number of models
        if idx == 0:
            n_models = n_models_flux

        prefixes_suffixes = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        # The color is relative to the previous band; for the first band
        # that is the last one (idx - 1 == -1), used only with > 2 bands
        if (idx > 0) or (n_bands > 2):
            prefixes_suffixes.append((f'color_{band_names[idx - 1]}_m_{band}', suffixes))

        # TODO: Do equivalent validation earlier, in the config
        if not ((n_models_flux == n_models) and (n_models_err == n_models)):
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Totals would be redundant
            if sourcetype != SourceType.ALL:
                columns.update({
                    _get_column_name(band, label, item): int
                    for item in items_count
                })

            columns.update({
                _get_column_name(band, label, item, suffix, stat): float
                for item in items_general
                for suffix in suffixes.values()
                for stat in names_stat
            })

            columns.update({
                _get_column_name(band, label, prefix_item, item, suffix, stat): float
                for item in config_flux.columns_target_flux
                for prefix_item, suffixes_col in prefixes_suffixes
                for suffix in suffixes_col.values()
                for stat in names_stat
            })

    return columns, n_models

590 

591 

592class DiffMatchedTractCatalogTask(pipeBase.PipelineTask): 

593 """Load subsets of matched catalogs and output a merged catalog of matched sources. 

594 """ 

595 ConfigClass = DiffMatchedTractCatalogConfig 

596 _DefaultName = "DiffMatchedTractCatalog" 

597 

598 def runQuantum(self, butlerQC, inputRefs, outputRefs): 

599 inputs = butlerQC.get(inputRefs) 

600 skymap = inputs.pop("skymap") 

601 

602 columns_match_target = ['match_row'] 

603 if 'match_candidate' in inputs['columns_match_target']: 

604 columns_match_target.append('match_candidate') 

605 

606 outputs = self.run( 

607 catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}), 

608 catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}), 

609 catalog_match_ref=inputs['cat_match_ref'].get( 

610 parameters={'columns': ['match_candidate', 'match_row']}, 

611 ), 

612 catalog_match_target=inputs['cat_match_target'].get( 

613 parameters={'columns': columns_match_target}, 

614 ), 

615 wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs, 

616 ) 

617 butlerQC.put(outputs, outputRefs) 

618 

619 def run( 

620 self, 

621 catalog_ref: pd.DataFrame, 

622 catalog_target: pd.DataFrame, 

623 catalog_match_ref: pd.DataFrame, 

624 catalog_match_target: pd.DataFrame, 

625 wcs: afwGeom.SkyWcs = None, 

626 ) -> pipeBase.Struct: 

627 """Load matched reference and target (measured) catalogs, measure summary statistics, and output 

628 a combined matched catalog with columns from both inputs. 

629 

630 Parameters 

631 ---------- 

632 catalog_ref : `pandas.DataFrame` 

633 A reference catalog to diff objects/sources from. 

634 catalog_target : `pandas.DataFrame` 

635 A target catalog to diff reference objects/sources to. 

636 catalog_match_ref : `pandas.DataFrame` 

637 A catalog with match indices of target sources and selection flags 

638 for each reference source. 

639 catalog_match_target : `pandas.DataFrame` 

640 A catalog with selection flags for each target source. 

641 wcs : `lsst.afw.image.SkyWcs` 

642 A coordinate system to convert catalog positions to sky coordinates, 

643 if necessary. 

644 

645 Returns 

646 ------- 

647 retStruct : `lsst.pipe.base.Struct` 

648 A struct with output_ref and output_target attribute containing the 

649 output matched catalogs. 

650 """ 

651 # Would be nice if this could refer directly to ConfigClass 

652 config: DiffMatchedTractCatalogConfig = self.config 

653 

654 select_ref = catalog_match_ref['match_candidate'].values 

655 # Add additional selection criteria for target sources beyond those for matching 

656 # (not recommended, but can be done anyway) 

657 select_target = (catalog_match_target['match_candidate'].values 

658 if 'match_candidate' in catalog_match_target.columns 

659 else np.ones(len(catalog_match_target), dtype=bool)) 

660 for column in config.columns_target_select_true: 

661 select_target &= catalog_target[column].values 

662 for column in config.columns_target_select_false: 

663 select_target &= ~catalog_target[column].values 

664 

665 ref, target = config.coord_format.format_catalogs( 

666 catalog_ref=catalog_ref, catalog_target=catalog_target, 

667 select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy, 

668 ) 

669 cat_ref = ref.catalog 

670 cat_target = target.catalog 

671 n_target = len(cat_target) 

672 

673 if config.include_unmatched: 

674 for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)): 

675 cat_add['match_candidate'] = cat_match['match_candidate'].values 

676 

677 match_row = catalog_match_ref['match_row'].values 

678 matched_ref = match_row >= 0 

679 matched_row = match_row[matched_ref] 

680 matched_target = np.zeros(n_target, dtype=bool) 

681 matched_target[matched_row] = True 

682 

683 # Add/compute distance columns 

684 coord1_target_err, coord2_target_err = config.columns_target_coord_err 

685 column_dist, column_dist_err = 'match_distance', 'match_distanceErr' 

686 dist = np.full(n_target, np.nan) 

687 

688 target_match_c1, target_match_c2 = (coord[matched_row] for coord in (target.coord1, target.coord2)) 

689 target_ref_c1, target_ref_c2 = (coord[matched_ref] for coord in (ref.coord1, ref.coord2)) 

690 

691 dist_err = np.full(n_target, np.nan) 

692 dist[matched_row] = sphdist( 

693 target_match_c1, target_match_c2, target_ref_c1, target_ref_c2 

694 ) if config.coord_format.coords_spherical else np.hypot( 

695 target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2, 

696 ) 

697 # Should probably explicitly add cosine terms if ref has errors too 

698 dist_err[matched_row] = sphdist( 

699 target_match_c1, target_match_c2, 

700 target_match_c1 + cat_target.iloc[matched_row][coord1_target_err].values, 

701 target_match_c2 + cat_target.iloc[matched_row][coord2_target_err].values, 

702 ) if config.coord_format.coords_spherical else np.hypot( 

703 cat_target.iloc[matched_row][coord1_target_err].values, 

704 cat_target.iloc[matched_row][coord2_target_err].values 

705 ) 

706 cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err 

707 

708 # Create a matched table, preserving the target catalog's named index (if it has one) 

709 cat_left = cat_target.iloc[matched_row] 

710 has_index_left = cat_left.index.name is not None 

711 cat_right = cat_ref[matched_ref].reset_index() 

712 cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns] 

713 cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1) 

714 

715 if config.include_unmatched: 

716 # Create an unmatched table with the same schema as the matched one 

717 # ... but only for objects with no matches (for completeness/purity) 

718 # and that were selected for matching (or inclusion via config) 

719 cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False) 

720 cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns) 

721 match_row_target = catalog_match_target['match_row'].values 

722 cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index( 

723 drop=not has_index_left) 

724 # See https://github.com/pandas-dev/pandas/issues/46662 

725 # astropy masked columns would handle this much more gracefully 

726 # Unfortunately, that would require storageClass migration 

727 # So we use pandas "extended" nullable types for now 

728 for cat_i in (cat_left, cat_right): 

729 for colname in cat_i.columns: 

730 column = cat_i[colname] 

731 dtype = str(column.dtype) 

732 if dtype == "bool": 

733 cat_i[colname] = column.astype("boolean") 

734 elif dtype.startswith("int"): 

735 cat_i[colname] = column.astype(f"Int{dtype[3:]}") 

736 elif dtype.startswith("uint"): 

737 cat_i[colname] = column.astype(f"UInt{dtype[3:]}") 

738 cat_unmatched = pd.concat(objs=(cat_left, cat_right)) 

739 

740 for columns_convert_base, prefix in ( 

741 (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref), 

742 (config.columns_target_mag_to_nJy, ""), 

743 ): 

744 if columns_convert_base: 

745 columns_convert = { 

746 f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items() 

747 } if prefix else columns_convert_base 

748 for cat_convert in (cat_matched, cat_unmatched): 

749 cat_convert.rename(columns=columns_convert, inplace=True) 

750 for column_flux in columns_convert.values(): 

751 cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux]) 

752 

753 # TODO: Deprecate all matched difference output in DM-43831 (per RFC-1008) 

754 

755 # Slightly smelly hack for when a column (like distance) is already relative to truth 

756 column_dummy = 'dummy' 

757 cat_ref[column_dummy] = np.zeros_like(ref.coord1) 

758 

759 # Add a boolean column for whether a match is classified correctly 

760 # TODO: remove the assumption of a boolean column 

761 extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted) 

762 

763 extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut 

764 

765 # Define difference/chi columns and statistics thereof 

766 suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'} 

767 # Skip diff for fluxes - covered by mags 

768 suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]} 

769 # Skip chi for magnitudes, which have strange errors 

770 suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]} 

771 stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)} 

772 

773 for percentile in self.config.percentiles: 

774 stat = Percentile(percentile=float(Decimal(percentile))) 

775 stats[stat.name_short()] = stat 

776 

777 # Get dict of column names 

778 columns, n_models = _get_columns( 

779 bands_columns=config.columns_flux, 

780 suffixes=suffixes, 

781 suffixes_flux=suffixes_flux, 

782 suffixes_mag=suffixes_mag, 

783 stats=stats, 

784 target=target, 

785 column_dist=column_dist, 

786 ) 

787 

788 # Setup numpy table 

789 n_bins = config.mag_num_bins 

790 data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()]) 

791 data['bin'] = np.arange(n_bins) 

792 

793 # Setup bins 

794 bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref, 

795 num=n_bins + 1) 

796 data['mag_min'] = bins_mag[:-1] 

797 data['mag_max'] = bins_mag[1:] 

798 bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins)) 

799 

800 # Define temporary columns for intermediate storage 

801 column_mag_temp = 'mag_temp' 

802 column_color_temp = 'color_temp' 

803 column_color_err_temp = 'colorErr_temp' 

804 flux_err_frac_prev = [None]*n_models 

805 mag_prev = [None]*n_models 

806 

807 columns_target = { 

808 target.column_coord1: ( 

809 ref.column_coord1, target.column_coord1, coord1_target_err, False, 

810 ), 

811 target.column_coord2: ( 

812 ref.column_coord2, target.column_coord2, coord2_target_err, False, 

813 ), 

814 column_dist: (column_dummy, column_dist, column_dist_err, False), 

815 } 

816 

817 # Cheat a little and do the first band last so that the color is 

818 # based on the last band 

819 band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()] 

820 n_bands = len(band_fluxes) 

821 if n_bands > 0: 

822 band_fluxes.append(band_fluxes[0]) 

823 flux_err_frac_first = None 

824 mag_first = None 

825 mag_ref_first = None 

826 

827 band_prev = None 

828 for idx_band, (band, config_flux) in enumerate(band_fluxes): 

829 if idx_band == n_bands: 

830 # These were already computed earlier 

831 mag_ref = mag_ref_first 

832 flux_err_frac = flux_err_frac_first 

833 mag_model = mag_first 

834 else: 

835 mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref 

836 flux_err_frac = [None]*n_models 

837 mag_model = [None]*n_models 

838 

839 if idx_band > 0: 

840 cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref 

841 

842 cat_ref[column_mag_temp] = mag_ref 

843 

844 select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi) 

845 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag)] 

846 

847 # Iterate over multiple models, compute their mags and colours (if there's a previous band) 

848 for idx_model in range(n_models): 

849 column_target_flux = config_flux.columns_target_flux[idx_model] 

850 column_target_flux_err = config_flux.columns_target_flux_err[idx_model] 

851 

852 flux_target = cat_target[column_target_flux] 

853 mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target 

854 if config.mag_ceiling_target is not None: 

855 mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target 

856 mag_model[idx_model] = mag_target 

857 

858 # These are needed for computing magnitude/color "errors" (which are a sketchy concept) 

859 flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target 

860 

861 # Stop if idx_band == 0: the rest will be picked up at idx_band == n_bands 

862 if idx_band > 0: 

863 # Keep these mags tabulated for convenience 

864 column_mag_temp_model = f'{column_mag_temp}{idx_model}' 

865 cat_target[column_mag_temp_model] = mag_target 

866 

867 columns_target[f'flux_{column_target_flux}'] = ( 

868 config_flux.column_ref_flux, 

869 column_target_flux, 

870 column_target_flux_err, 

871 True, 

872 ) 

873 # Note: magnitude errors are generally problematic and not worth aggregating 

874 columns_target[f'mag_{column_target_flux}'] = ( 

875 column_mag_temp, column_mag_temp_model, None, False, 

876 ) 

877 

878 # No need for colors if this is the last band and there are only two bands 

879 # (because it would just be the negative of the first color) 

880 skip_color = (idx_band == n_bands) and (n_bands <= 2) 

881 if not skip_color: 

882 column_color_temp_model = f'{column_color_temp}{idx_model}' 

883 column_color_err_temp_model = f'{column_color_err_temp}{idx_model}' 

884 

885 # e.g. if order is ugrizy, first color will be u - g 

886 cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model] 

887 

888 # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors 

889 cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot( 

890 flux_err_frac[idx_model], flux_err_frac_prev[idx_model]) 

891 columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = ( 

892 column_color_temp, 

893 column_color_temp_model, 

894 column_color_err_temp_model, 

895 False, 

896 ) 

897 

898 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag): 

899 row = data[idx_bin] 

900 # Reference sources only need to be counted once 

901 if idx_model == 0: 

902 select_ref_bin = select_ref_bins[idx_bin] 

903 select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi) 

904 

905 for sourcetype in SourceType: 

906 sourcetype_info = sourcetype.value 

907 is_extended = sourcetype_info.is_extended 

908 # Counts filtered by match selection and magnitude bin 

909 select_ref_sub = select_ref_bin.copy() 

910 select_target_sub = select_target_bin.copy() 

911 if is_extended is not None: 

912 is_extended_ref = (extended_ref == is_extended) 

913 select_ref_sub &= is_extended_ref 

914 if idx_model == 0: 

915 n_ref_sub = np.count_nonzero(select_ref_sub) 

916 row[_get_column_name(band, sourcetype_info.label, 'n_ref', 

917 MatchType.ALL.value)] = n_ref_sub 

918 select_target_sub &= (extended_target == is_extended) 

919 n_target_sub = np.count_nonzero(select_target_sub) 

920 row[_get_column_name(band, sourcetype_info.label, 'n_target', 

921 MatchType.ALL.value)] = n_target_sub 

922 

923 # Filter matches by magnitude bin and true class 

924 match_row_bin = match_row.copy() 

925 match_row_bin[~select_ref_sub] = -1 

926 match_good = match_row_bin >= 0 

927 

928 n_match = np.count_nonzero(match_good) 

929 

930 # Same for counts of matched target sources (for e.g. purity) 

931 

932 if n_match > 0: 

933 rows_matched = match_row_bin[match_good] 

934 subset_target = cat_target.iloc[rows_matched] 

935 if (is_extended is not None) and (idx_model == 0): 

936 right_type = extended_target[rows_matched] == is_extended 

937 n_total = len(right_type) 

938 n_right = np.count_nonzero(right_type) 

939 row[_get_column_name(band, sourcetype_info.label, 'n_ref', 

940 MatchType.MATCH_RIGHT.value)] = n_right 

941 row[_get_column_name( 

942 band, 

943 sourcetype_info.label, 

944 'n_ref', 

945 MatchType.MATCH_WRONG.value, 

946 )] = n_total - n_right 

947 

948 # compute stats for this bin, for all columns 

949 for column, (column_ref, column_target, column_err_target, skip_diff) \ 

950 in columns_target.items(): 

951 values_ref = cat_ref[column_ref][match_good].values 

952 errors_target = ( 

953 subset_target[column_err_target].values 

954 if column_err_target is not None 

955 else None 

956 ) 

957 compute_stats( 

958 values_ref, 

959 subset_target[column_target].values, 

960 errors_target, 

961 row, 

962 stats, 

963 suffixes, 

964 prefix=f'{band}_{sourcetype_info.label}_{column}', 

965 skip_diff=skip_diff, 

966 ) 

967 

968 # Count matched target sources with *measured* mags within bin 

969 # Used for e.g. purity calculation 

970 # Should be merged with above code if there's ever a need for 

971 # measuring stats on this source selection 

972 select_target_sub &= matched_target 

973 

974 if is_extended is not None and (np.count_nonzero(select_target_sub) > 0): 

975 n_total = np.count_nonzero(select_target_sub) 

976 right_type = np.zeros(n_target, dtype=bool) 

977 right_type[match_row[matched_ref & is_extended_ref]] = True 

978 right_type &= select_target_sub 

979 n_right = np.count_nonzero(right_type) 

980 row[_get_column_name(band, sourcetype_info.label, 'n_target', 

981 MatchType.MATCH_RIGHT.value)] = n_right 

982 row[_get_column_name(band, sourcetype_info.label, 'n_target', 

983 MatchType.MATCH_WRONG.value)] = n_total - n_right 

984 

985 # delete the flux/mag/color columns since they change with each band 

986 for prefix in ('flux', 'mag'): 

987 del columns_target[f'{prefix}_{column_target_flux}'] 

988 if not skip_color: 

989 del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] 

990 

991 # keep values needed for colors 

992 flux_err_frac_prev = flux_err_frac 

993 mag_prev = mag_model 

994 band_prev = band 

995 if idx_band == 0: 

996 flux_err_frac_first = flux_err_frac 

997 mag_first = mag_model 

998 mag_ref_first = mag_ref 

999 

1000 if config.include_unmatched: 

1001 cat_matched = pd.concat((cat_matched, cat_unmatched)) 

1002 

1003 retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data)) 

1004 return retStruct