Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 26%

366 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-24 09:52 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module
__all__ = [
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    'MatchType',
    'MeasurementType',
    'SourceType',
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]

27 

28import lsst.afw.geom as afwGeom 

29from lsst.meas.astrom.matcher_probabilistic import ( 

30 ComparableCatalog, ConvertCatalogCoordinatesConfig, 

31) 

32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35import lsst.pipe.base.connectionTypes as cT 

36from lsst.skymap import BaseSkyMap 

37 

38from abc import ABCMeta, abstractmethod 

39from astropy.stats import mad_std 

40from dataclasses import dataclass 

41from decimal import Decimal 

42from enum import Enum 

43import numpy as np 

44import pandas as pd 

45from scipy.stats import iqr 

46from typing import Dict, Sequence, Set 

47 

48 

def is_sequence_set(x: Sequence):
    """Return whether every element of ``x`` is unique.

    Used as a pex_config ``listCheck``; elements must be hashable.
    """
    n_unique = len(set(x))
    return n_unique == len(x)

51 

52 

def is_percentile(x: str):
    """Return whether ``x`` parses as a decimal number within [0, 100].

    Used as a pex_config ``itemCheck``. Strings that cannot be parsed as a
    number, or that parse to NaN, yield `False` rather than raising, so the
    caller gets a plain failed check instead of an unrelated exception.
    """
    try:
        return 0 <= Decimal(x) <= 100
    except ArithmeticError:
        # decimal.InvalidOperation (a subclass of ArithmeticError) is raised both
        # for unparseable strings and for ordering comparisons involving NaN.
        return False

55 

56 

# Default values for the dataset name templates used by the connections below
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)

62 

63 

class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Butler connections for `DiffMatchedTractCatalogTask`.

    Catalog inputs and outputs are per-tract DataFrames whose dataset type
    names are built from the ``name_input_cat_ref`` and
    ``name_input_cat_target`` templates.
    """

    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference object catalog to match from",
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target object catalog to match",
    )
    skymap = cT.Input(
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
    )
    cat_match_ref = cT.Input(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference match catalog with indices of target matches",
    )
    cat_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target match catalog with indices of references matches",
    )
    # Only the column index is read, to check whether 'match_candidate' exists
    columns_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
        doc="Target match catalog columns",
    )
    cat_matched = cT.Output(
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Catalog with reference and target columns for matched sources only",
    )
    diff_matched = cT.Output(
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Table with aggregated counts, difference and chi statistics",
    )

121 

122 

class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column configuration for a single band.

    The i-th entries of ``columns_target_flux`` and ``columns_target_flux_err``
    are used together as the flux and flux error of one model.
    """

    column_ref_flux = pexConfig.Field(
        doc='Reference catalog flux column name',
        dtype=str,
    )
    columns_target_flux = pexConfig.ListField(
        doc="List of target catalog flux column names",
        dtype=str,
        listCheck=is_sequence_set,
    )
    columns_target_flux_err = pexConfig.ListField(
        doc="List of target catalog flux error column names",
        dtype=str,
        listCheck=is_sequence_set,
    )

    @property
    def columns_in_ref(self) -> Set[str]:
        """Reference catalog columns needed for this band (the flux column)."""
        return set((self.column_ref_flux,))

    @property
    def columns_in_target(self) -> Set[str]:
        """Target catalog columns needed for this band (all fluxes and flux errors)."""
        fluxes = set(self.columns_target_flux)
        flux_errors = set(self.columns_target_flux_err)
        return fluxes | flux_errors

146 

147 

class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for `DiffMatchedTractCatalogTask`."""

    column_matched_prefix_ref = pexConfig.Field(
        dtype=str,
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field(
        dtype=str,
        default='is_pointsource',
        doc='The boolean reference table column specifying if the object is extended',
    )
    column_ref_extended_inverted = pexConfig.Field(
        dtype=bool,
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field(
        dtype=str,
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )

    @property
    def columns_in_ref(self) -> Set[str]:
        """The set of reference catalog columns that must be loaded.

        This is the reference coordinate columns, the extendedness column, the
        columns to copy, and the reference flux column of every band.
        """
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
                       self.column_ref_extended]
        columns_all.extend(self.columns_ref_copy)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_ref)
        return set(columns_all)

    @property
    def columns_in_target(self) -> Set[str]:
        """The set of target catalog columns that must be loaded.

        This is the target coordinate columns, the extendedness column, any
        columns of ``coord_format.coords_ref_to_convert``, the coordinate error,
        selection and copy columns, and every band's flux and flux error
        columns.
        """
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
                       self.column_target_extended]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(self.coord_format.coords_ref_to_convert.values())
        for column_list in (
            self.columns_target_coord_err,
            self.columns_target_select_false,
            self.columns_target_select_true,
            self.columns_target_copy,
        ):
            columns_all.extend(column_list)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_target)
        return set(columns_all)

    columns_flux = pexConfig.ConfigDictField(
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        doc="Configs for flux columns for each band",
    )
    columns_ref_copy = pexConfig.ListField(
        dtype=str,
        # A tuple, consistent with the other ListField defaults in this config
        default=(),
        doc='Reference table columns to copy into cat_matched',
    )
    columns_target_coord_err = pexConfig.ListField(
        dtype=str,
        # Exactly two distinct columns: the errors on coordinate 1 and coordinate 2
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
        doc='Target table coordinate columns with standard errors (sigma)',
    )
    columns_target_copy = pexConfig.ListField(
        dtype=str,
        default=('patch',),
        doc='Target table columns to copy into cat_matched',
    )
    columns_target_select_true = pexConfig.ListField(
        dtype=str,
        default=('detect_isPrimary',),
        doc='Target table columns to require to be True for selecting sources',
    )
    columns_target_select_false = pexConfig.ListField(
        dtype=str,
        default=('merge_peak_sky',),
        doc='Target table columns to require to be False for selecting sources',
    )
    coord_format = pexConfig.ConfigField(
        dtype=ConvertCatalogCoordinatesConfig,
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field(
        dtype=float,
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field(
        doc='Number of magnitude bins',
        default=15,
        dtype=int,
    )
    mag_brightest_ref = pexConfig.Field(
        dtype=float,
        default=15,
        doc='Brightest magnitude cutoff for binning',
    )
    mag_ceiling_target = pexConfig.Field(
        dtype=float,
        default=None,
        optional=True,
        doc='Ceiling (maximum/faint) magnitude for target sources',
    )
    mag_faintest_ref = pexConfig.Field(
        dtype=float,
        default=30,
        doc='Faintest magnitude cutoff for binning',
    )
    mag_zeropoint_ref = pexConfig.Field(
        dtype=float,
        default=31.4,
        doc='Magnitude zeropoint for reference sources',
    )
    mag_zeropoint_target = pexConfig.Field(
        dtype=float,
        default=31.4,
        doc='Magnitude zeropoint for target sources',
    )
    percentiles = pexConfig.ListField(
        dtype=str,
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        doc='Percentiles to compute for diff/chi values',
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

285 

286 

@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a kind of measured-vs-reference comparison."""
    # Human-readable description of the comparison
    doc: str
    # Short name of the comparison (e.g. "diff")
    name: str

291 

292 

class MeasurementType(Enum):
    """Kinds of comparison between measured and reference values."""
    DIFF = MeasurementTypeInfo(
        name="diff",
        doc="difference (measured - reference)",
    )
    CHI = MeasurementTypeInfo(
        name="chi",
        doc="scaled difference (measured - reference)/error",
    )

302 

303 

class Statistic(metaclass=ABCMeta):
    """Abstract base for a summary statistic computed over a set of values.

    Concrete statistics provide the computation (`value`), a human-readable
    description (`doc`) and a short identifier for table columns
    (`name_short`).
    """

    @abstractmethod
    def value(self, values):
        """Compute the statistic.

        Parameters
        ----------
        values : `Collection` [`float`]
            The values to summarize.

        Returns
        -------
        statistic : `float`
            The computed statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def doc(self) -> str:
        """Return a human-readable description of the statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short identifier for the statistic, e.g. for a table column name."""
        raise NotImplementedError('Subclasses must implement this method')

332 

333 

class Median(Statistic):
    """Median of a collection of values."""

    def value(self, values):
        """Return the median of ``values`` as computed by `numpy.median`."""
        return np.median(values)

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        return "Median"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name identifier of this statistic."""
        return "median"

346 

347 

class SigmaIQR(Statistic):
    """Interquartile range rescaled to a normal-distribution sigma."""

    def value(self, values):
        """Return `scipy.stats.iqr` of ``values`` with ``scale='normal'``."""
        return iqr(values, scale='normal')

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name identifier of this statistic."""
        return "sig_iqr"

360 

361 

class SigmaMAD(Statistic):
    """Median absolute deviation rescaled to a normal-distribution sigma."""

    def value(self, values):
        """Return astropy's `mad_std` of ``values``."""
        return mad_std(values)

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name identifier of this statistic."""
        return "sig_mad"

374 

375 

@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised if ``percentile`` is not within [0, 100] (including NaN).
    """
    percentile: float

    def doc(self) -> str:
        # Previously returned the SigmaMAD description (copy-paste error)
        return f"{self.percentile}th percentile"

    def name_short(self) -> str:
        # Format as a fraction with 5 decimals and drop the leading "0."
        # (e.g. 2.275 -> "pctl02275"). A fraction that rounds to 1 (the 100th
        # percentile) keeps its leading digit so that it cannot collide with
        # the 0th percentile ("pctl00000").
        fraction = f'{self.percentile/100:.5f}'
        digits = fraction[2:] if fraction.startswith('0.') else fraction.replace('.', '')
        return f"pctl{digits}"

    def value(self, values):
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        # Written as a negated conjunction so that NaN is also rejected
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')

399 

400 

401def _get_stat_name(*args): 

402 return '_'.join(args) 

403 

404 

def _get_column_name(band, *args):
    """Return a column name: ``band`` followed by the underscore-joined ``args``."""
    suffix = _get_stat_name(*args)
    return f"{band}_" + suffix

407 

408 

def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics on differences and store results in a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,) or `None`
        Errors (standard deviations) on ``values_target``. If `None`, no chi
        statistics are computed.
    row : `numpy.void` or other record supporting assignment by column name
        A single record of a numpy structured array (or anything else that
        supports ``row[name] = value``). Results are written in place to the
        columns named ``{prefix}_{suffix}_{stat_name}``.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement type.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : same type as ``row``
        The same ``row`` object, with statistic values assigned.

    Raises
    ------
    ValueError
        Raised if ``values_ref`` is non-empty and ``values_target`` (or
        ``errors_target``, if given) has a different length.
    """
    n_ref = len(values_ref)
    # Nothing to measure without reference values; leave the row untouched
    if n_ref == 0:
        return row

    n_target = len(values_target)
    n_target_err = len(errors_target) if errors_target is not None else n_ref
    if (n_target != n_ref) or (n_target_err != n_ref):
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', errors_target={n_target_err} must match')

    do_chi = errors_target is not None
    diff = values_target - values_ref
    chi = diff/errors_target if do_chi else diff
    # Could make this configurable, but non-finite values/errors are not really usable
    valid = np.isfinite(chi)
    values_type = {} if skip_diff else {MeasurementType.DIFF: diff[valid]}
    if do_chi:
        values_type[MeasurementType.CHI] = chi[valid]

    for suffix_type, suffix in suffixes.items():
        values = values_type.get(suffix_type)
        # Measurement types that were skipped, or with no valid values, are left unset
        if values is not None and len(values) > 0:
            for stat_name, stat in stats.items():
                row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row

460 

461 

@dataclass(frozen=True)
class SourceTypeInfo:
    """Immutable description of a class of sources to aggregate statistics over."""
    # True/False selects only extended/non-extended sources; None selects all sources
    is_extended: bool | None
    # Label used in output column names
    label: str

466 

467 

class SourceType(Enum):
    """Classes of sources (by extendedness) that statistics are aggregated over."""
    ALL = SourceTypeInfo(label='all', is_extended=None)
    RESOLVED = SourceTypeInfo(label='resolved', is_extended=True)
    UNRESOLVED = SourceTypeInfo(label='unresolved', is_extended=False)

472 

473 

class MatchType(Enum):
    """Categories of source counts.

    The values are used as suffixes of count column names
    (``n_ref_<value>`` and ``n_target_<value>``). Member order determines
    the order in which count columns are created.
    """
    # All selected sources in a bin, matched or not
    ALL = 'all'
    # Matched sources whose measured class agrees with the expected class
    MATCH_RIGHT = 'match_right'
    # Matched sources whose measured class does not agree
    MATCH_WRONG = 'match_wrong'

478 

479 

def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffixes_mag: Dict,
                 stats: Dict, target: ComparableCatalog, column_dist: str):
    """Build the ordered column names and types of the difference-statistics table.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType` type, for general columns (e.g.
        coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of suffixes for each `Statistic` type.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name, in table order.
    n_models : `int`
        The number of models measurements will be made for.

    Raises
    ------
    RuntimeError
        Raised if any band has a different number of flux or flux error
        columns than the first band's number of flux columns.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # The table always starts with the magnitude bin definition
    columns = {"bin": int, "mag_min": float, "mag_max": float}

    bands = list(bands_columns.keys())
    n_bands = len(bands)
    stat_names = list(stats.keys())
    n_models = 0

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        if idx == 0:
            # The first band defines how many models every band must have
            n_models = len(config_flux.columns_target_flux)

        prefixes_with_suffixes = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        # Each band's color is relative to the previous band; with more than two
        # bands, the first band is also paired with the last (bands[-1])
        if idx > 0 or n_bands > 2:
            prefixes_with_suffixes.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))

        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)

        # TODO: Do equivalent validation earlier, in the config
        if not (n_models_flux == n_models == n_models_err):
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Counts for all sources would be redundant with the per-class counts
            if sourcetype != SourceType.ALL:
                columns.update(
                    (_get_column_name(band, label, f'n_{itype}_{mtype.value}'), int)
                    for itype in ('ref', 'target')
                    for mtype in MatchType
                )

            columns.update(
                (_get_column_name(band, label, item, suffix, stat), float)
                for item in (target.column_coord1, target.column_coord2, column_dist)
                for suffix in suffixes.values()
                for stat in stat_names
            )

            columns.update(
                (_get_column_name(band, label, prefix_item, item, suffix, stat), float)
                for item in config_flux.columns_target_flux
                for prefix_item, suffixes_col in prefixes_with_suffixes
                for suffix in suffixes_col.values()
                for stat in stat_names
            )

    return columns, n_models

559 

560 

class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
    """Load subsets of matched catalogs and output a merged catalog of matched sources.

    Also tabulates, in bins of magnitude, source counts and difference/chi
    statistics of coordinates, distances, fluxes, magnitudes and colors.
    """
    ConfigClass = DiffMatchedTractCatalogConfig
    _DefaultName = "DiffMatchedTractCatalog"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read the inputs for one quantum, call `run`, and write its outputs.

        Only the columns required by the config are read from the deferred-load
        input catalogs. The target match catalog's ``match_candidate`` column is
        read only if it exists.
        """
        inputs = butlerQC.get(inputRefs)
        skymap = inputs.pop("skymap")

        columns_match_target = ['match_row']
        if 'match_candidate' in inputs['columns_match_target']:
            columns_match_target.append('match_candidate')

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
            catalog_match_ref=inputs['cat_match_ref'].get(
                parameters={'columns': ['match_candidate', 'match_row']},
            ),
            catalog_match_target=inputs['cat_match_target'].get(
                parameters={'columns': columns_match_target},
            ),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        catalog_match_ref: pd.DataFrame,
        catalog_match_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Load matched reference and target (measured) catalogs, measure summary statistics, and output
        a combined matched catalog with columns from both inputs.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to diff objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to diff reference objects/sources to.
        catalog_match_ref : `pandas.DataFrame`
            A catalog with match indices of target sources (``match_row``) and
            selection flags (``match_candidate``) for each reference source.
        catalog_match_target : `pandas.DataFrame`
            A catalog with selection flags for each target source. Its
            ``match_candidate`` column is optional and is not modified.
        wcs : `lsst.afw.geom.SkyWcs`
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``cat_matched``
                `pandas.DataFrame` of matched sources, containing the target
                columns followed by the reference columns, the latter prefixed
                with ``config.column_matched_prefix_ref``.
            ``diff_matched``
                `pandas.DataFrame` with one row per magnitude bin of counts and
                difference/chi statistics.

        Raises
        ------
        RuntimeError
            Raised if the bands in ``config.columns_flux`` do not all have the
            same number of flux and flux error columns.
        """
        config = self.config

        select_ref = catalog_match_ref['match_candidate'].values
        # Add additional selection criteria for target sources beyond those for matching
        # (not recommended, but can be done anyway).
        # Copy the match_candidate values: the in-place &= below would otherwise
        # write through to the caller's catalog_match_target.
        select_target = (catalog_match_target['match_candidate'].values.copy()
                         if 'match_candidate' in catalog_match_target.columns
                         else np.ones(len(catalog_match_target), dtype=bool))
        for column in config.columns_target_select_true:
            select_target &= catalog_target[column].values
        for column in config.columns_target_select_false:
            select_target &= ~catalog_target[column].values

        ref, target = config.coord_format.format_catalogs(
            catalog_ref=catalog_ref, catalog_target=catalog_target,
            select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
            return_converted_columns=config.coord_format.coords_ref_to_convert is not None,
        )
        cat_ref = ref.catalog
        cat_target = target.catalog
        n_target = len(cat_target)

        # For each reference source, the row of its matched target (negative if unmatched)
        match_row = catalog_match_ref['match_row'].values
        matched_ref = match_row >= 0
        matched_row = match_row[matched_ref]
        matched_target = np.zeros(n_target, dtype=bool)
        matched_target[matched_row] = True

        # Create a matched table, preserving the target catalog's named index (if it has one)
        cat_left = cat_target.iloc[matched_row]
        has_index_left = cat_left.index.name is not None
        cat_right = cat_ref[matched_ref].reset_index()
        # Prefix the reference columns with the configured prefix. Rename before
        # concatenating rather than writing into the (immutable) column Index of the result.
        cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
        cat_matched = pd.concat(objs=(cat_left.reset_index(drop=True), cat_right), axis=1)
        if has_index_left:
            cat_matched.index = cat_left.index

        # Add/compute distance columns
        coord1_target_err, coord2_target_err = config.columns_target_coord_err
        column_dist, column_dist_err = 'distance', 'distanceErr'
        dist = np.full(n_target, np.inf)

        dist[matched_row] = np.hypot(
            target.coord1[matched_row] - ref.coord1[matched_ref],
            target.coord2[matched_row] - ref.coord2[matched_ref],
        )
        dist_err = np.full(n_target, np.inf)
        # cat_left holds exactly the matched target rows (cat_target.iloc[matched_row])
        dist_err[matched_row] = np.hypot(cat_left[coord1_target_err].values,
                                         cat_left[coord2_target_err].values)
        cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

        # Slightly smelly hack for when a column (like distance) is already relative to truth
        column_dummy = 'dummy'
        cat_ref[column_dummy] = np.zeros_like(ref.coord1)

        # Reference extendedness, inverted if the configured column flags compact sources
        extended_ref = cat_ref[config.column_ref_extended]
        if config.column_ref_extended_inverted:
            extended_ref = 1 - extended_ref

        extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

        # Define difference/chi columns and statistics thereof
        suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
        # Skip diff for fluxes - covered by mags
        suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
        # Skip chi for magnitudes, which have strange errors
        suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
        stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

        for percentile in self.config.percentiles:
            stat = Percentile(percentile=float(Decimal(percentile)))
            stats[stat.name_short()] = stat

        # Get dict of column names
        columns, n_models = _get_columns(
            bands_columns=config.columns_flux,
            suffixes=suffixes,
            suffixes_flux=suffixes_flux,
            suffixes_mag=suffixes_mag,
            stats=stats,
            target=target,
            column_dist=column_dist,
        )

        # Setup numpy table
        n_bins = config.mag_num_bins
        data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
        data['bin'] = np.arange(n_bins)

        # Setup bins
        bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
                               num=n_bins + 1)
        data['mag_min'] = bins_mag[:-1]
        data['mag_max'] = bins_mag[1:]
        bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))

        # Define temporary columns for intermediate storage
        column_mag_temp = 'mag_temp'
        column_color_temp = 'color_temp'
        column_color_err_temp = 'colorErr_temp'
        flux_err_frac_prev = [None]*n_models
        mag_prev = [None]*n_models

        # Columns to compute stats for, each mapped to
        # (reference column, target column, target error column or None, skip_diff)
        columns_target = {
            target.column_coord1: (
                ref.column_coord1, target.column_coord1, coord1_target_err, False,
            ),
            target.column_coord2: (
                ref.column_coord2, target.column_coord2, coord2_target_err, False,
            ),
            column_dist: (column_dummy, column_dist, column_dist_err, False),
        }

        # Cheat a little and do the first band last so that the color is
        # based on the last band
        band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
        n_bands = len(band_fluxes)
        band_fluxes.append(band_fluxes[0])
        flux_err_frac_first = None
        mag_first = None
        mag_ref_first = None

        band_prev = None
        for idx_band, (band, config_flux) in enumerate(band_fluxes):
            if idx_band == n_bands:
                # These were already computed earlier
                mag_ref = mag_ref_first
                flux_err_frac = flux_err_frac_first
                mag_model = mag_first
            else:
                mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
                flux_err_frac = [None]*n_models
                mag_model = [None]*n_models

            if idx_band > 0:
                # column_mag_temp still holds the previous band's reference mags
                cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref

            cat_ref[column_mag_temp] = mag_ref

            select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
                               for mag_lo, mag_hi in bins_mag]

            # Iterate over multiple models, compute their mags and colours (if there's a previous band)
            for idx_model in range(n_models):
                column_target_flux = config_flux.columns_target_flux[idx_model]
                column_target_flux_err = config_flux.columns_target_flux_err[idx_model]

                flux_target = cat_target[column_target_flux]
                mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
                if config.mag_ceiling_target is not None:
                    mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
                mag_model[idx_model] = mag_target

                # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
                flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target

                # Stop if idx_band == 0: the rest will be picked up at idx_band == n_bands
                if idx_band > 0:
                    # Keep these mags tabulated for convenience
                    column_mag_temp_model = f'{column_mag_temp}{idx_model}'
                    cat_target[column_mag_temp_model] = mag_target

                    columns_target[f'flux_{column_target_flux}'] = (
                        config_flux.column_ref_flux,
                        column_target_flux,
                        column_target_flux_err,
                        True,
                    )
                    # Note: magnitude errors are generally problematic and not worth aggregating
                    columns_target[f'mag_{column_target_flux}'] = (
                        column_mag_temp, column_mag_temp_model, None, False,
                    )

                    # No need for colors if this is the last band and there are only two bands
                    # (because it would just be the negative of the first color)
                    skip_color = (idx_band == n_bands) and (n_bands <= 2)
                    if not skip_color:
                        column_color_temp_model = f'{column_color_temp}{idx_model}'
                        column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'

                        # e.g. if order is ugrizy, first color will be u - g
                        cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]

                        # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
                        cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
                            flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
                        columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
                            column_color_temp,
                            column_color_temp_model,
                            column_color_err_temp_model,
                            False,
                        )

                    for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
                        row = data[idx_bin]
                        # The reference selection must be refreshed for every model: the
                        # statistics below are recomputed per model, and reusing a stale
                        # selection would measure them on the wrong bin. Only the reference
                        # *counts* are restricted to the first model.
                        select_ref_bin = select_ref_bins[idx_bin]
                        select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)

                        for sourcetype in SourceType:
                            sourcetype_info = sourcetype.value
                            is_extended = sourcetype_info.is_extended
                            # Counts filtered by match selection and magnitude bin
                            select_ref_sub = select_ref_bin.copy()
                            select_target_sub = select_target_bin.copy()
                            if is_extended is not None:
                                is_extended_ref = (extended_ref == is_extended)
                                select_ref_sub &= is_extended_ref
                                # Reference sources only need to be counted once
                                if idx_model == 0:
                                    n_ref_sub = np.count_nonzero(select_ref_sub)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.ALL.value)] = n_ref_sub
                                select_target_sub &= (extended_target == is_extended)
                                n_target_sub = np.count_nonzero(select_target_sub)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.ALL.value)] = n_target_sub

                            # Filter matches by magnitude bin and true class
                            match_row_bin = match_row.copy()
                            match_row_bin[~select_ref_sub] = -1
                            match_good = match_row_bin >= 0

                            n_match = np.count_nonzero(match_good)

                            if n_match > 0:
                                rows_matched = match_row_bin[match_good]
                                subset_target = cat_target.iloc[rows_matched]
                                if (is_extended is not None) and (idx_model == 0):
                                    right_type = extended_target[rows_matched] == is_extended
                                    n_total = len(right_type)
                                    n_right = np.count_nonzero(right_type)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.MATCH_RIGHT.value)] = n_right
                                    row[_get_column_name(
                                        band, sourcetype_info.label, 'n_ref', MatchType.MATCH_WRONG.value,
                                    )] = n_total - n_right

                                # compute stats for this bin, for all columns
                                for column, (column_ref, column_target, column_err_target, skip_diff) in (
                                    columns_target.items()
                                ):
                                    values_ref = cat_ref[column_ref][match_good].values
                                    errors_target = (
                                        subset_target[column_err_target].values
                                        if column_err_target is not None
                                        else None
                                    )
                                    compute_stats(
                                        values_ref,
                                        subset_target[column_target].values,
                                        errors_target,
                                        row,
                                        stats,
                                        suffixes,
                                        prefix=f'{band}_{sourcetype_info.label}_{column}',
                                        skip_diff=skip_diff,
                                    )

                            # Count matched target sources with *measured* mags within bin
                            # Used for e.g. purity calculation
                            # Should be merged with above code if there's ever a need for
                            # measuring stats on this source selection
                            select_target_sub &= matched_target

                            if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
                                n_total = np.count_nonzero(select_target_sub)
                                right_type = np.zeros(n_target, dtype=bool)
                                right_type[match_row[matched_ref & is_extended_ref]] = True
                                right_type &= select_target_sub
                                n_right = np.count_nonzero(right_type)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_RIGHT.value)] = n_right
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_WRONG.value)] = n_total - n_right

                    # delete the flux/color columns since they change with each band
                    for prefix in ('flux', 'mag'):
                        del columns_target[f'{prefix}_{column_target_flux}']
                    if not skip_color:
                        del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']

            # keep values needed for colors
            flux_err_frac_prev = flux_err_frac
            mag_prev = mag_model
            band_prev = band
            if idx_band == 0:
                flux_err_frac_first = flux_err_frac
                mag_first = mag_model
                mag_ref_first = mag_ref

        retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
        return retStruct