Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 29%

368 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-06 12:26 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module
__all__ = [
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    'MatchType',
    'MeasurementType',
    'SourceType',
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]

27 

28import lsst.afw.geom as afwGeom 

29from lsst.meas.astrom.matcher_probabilistic import ( 

30 ComparableCatalog, ConvertCatalogCoordinatesConfig, 

31) 

32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35import lsst.pipe.base.connectionTypes as cT 

36from lsst.skymap import BaseSkyMap 

37 

38from abc import ABCMeta, abstractmethod 

39from astropy.stats import mad_std 

40from dataclasses import dataclass 

41from decimal import Decimal 

42from enum import Enum 

43import numpy as np 

44import pandas as pd 

45from scipy.stats import iqr 

46from typing import Dict, Sequence 

47 

48 

def is_sequence_set(x: Sequence):
    """Return whether every element of ``x`` is distinct.

    Used as a ``listCheck`` for config list fields that must not contain
    duplicate entries.
    """
    unique_items = set(x)
    return len(unique_items) == len(x)

51 

52 

def is_percentile(x: str):
    """Return whether ``x`` is a string representation of a percentile.

    Used as an ``itemCheck`` for config list fields, so invalid input must
    yield False (letting the config report a validation error) rather than
    raise.

    Parameters
    ----------
    x : `str`
        The value to check, e.g. ``'2.275'``.

    Returns
    -------
    valid : `bool`
        True if ``x`` parses as a decimal number with 0 <= x <= 100.
    """
    from decimal import InvalidOperation

    try:
        value = Decimal(x)
    except (InvalidOperation, TypeError, ValueError):
        # Not parseable as a number (e.g. 'abc' or None)
        return False
    # Ordering comparisons with a Decimal NaN raise InvalidOperation
    if value.is_nan():
        return False
    return 0 <= value <= 100

55 

56 

# Default values for the name templates used by DiffMatchedTractCatalogConnections
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)

62 

63 

class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Butler connections for `DiffMatchedTractCatalogTask`.

    Inputs are a reference catalog, a target catalog, a skymap and the match
    catalogs linking the two. Outputs are a catalog of matched rows and a
    table of aggregated statistics.
    """

    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference object catalog to match from",
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target object catalog to match",
    )
    skymap = cT.Input(
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
    )
    cat_match_ref = cT.Input(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference match catalog with indices of target matches",
    )
    cat_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target match catalog with indices of references matches",
    )
    columns_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
        doc="Target match catalog columns",
    )
    cat_matched = cT.Output(
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Catalog with reference and target columns for matched sources only",
    )
    diff_matched = cT.Output(
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Table with aggregated counts, difference and chi statistics",
    )

121 

122 

class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Configuration of the flux columns to compare for a single band."""

    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    # TODO: this should be an ordered set
    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns needed for this band."""
        return [self.column_ref_flux]

    # TODO: this should also be an ordered set
    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns needed for this band.

        Flux columns come first, followed by any error columns that are not
        already present.
        """
        columns = list(self.columns_target_flux)
        for column in self.columns_target_flux_err:
            if column not in columns:
                columns.append(column)
        return columns

150 

151 

class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for `DiffMatchedTractCatalogTask`."""

    column_matched_prefix_ref = pexConfig.Field(
        dtype=str,
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field(
        dtype=str,
        default='is_pointsource',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field(
        dtype=bool,
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field(
        dtype=str,
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns needed by the task.

        Returns
        -------
        columns : `list` [`str`]
            Unique column names, in the order they were first encountered.
        """
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
                       self.column_ref_extended]
        columns_all.extend(self.columns_ref_copy)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_ref)
        # De-duplicate while preserving order. A set (as previously returned)
        # made the requested column order arbitrary and contradicted the
        # annotated return type.
        return list(dict.fromkeys(columns_all))

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns needed by the task.

        Returns
        -------
        columns : `list` [`str`]
            Column names; later entries are skipped if already present.
        """
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
                       self.column_target_extended]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        for column_lists in (
            (
                self.columns_target_coord_err,
                self.columns_target_select_false,
                self.columns_target_select_true,
                self.columns_target_copy,
            ),
            (x.columns_in_target for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field[float](
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

281 

282 

@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a kind of measurement.

    Attributes
    ----------
    doc : `str`
        A human-readable description of the measurement.
    name : `str`
        A short name for the measurement.
    """

    doc: str
    name: str

287 

288 

class MeasurementType(Enum):
    """Kinds of comparison between measured and reference values."""

    DIFF = MeasurementTypeInfo(
        name="diff",
        doc="difference (measured - reference)",
    )
    CHI = MeasurementTypeInfo(
        name="chi",
        doc="scaled difference (measured - reference)/error",
    )

297 ) 

298 

299 

class Statistic(metaclass=ABCMeta):
    """Abstract base for a statistic that summarizes a set of values.

    Concrete subclasses must implement `doc`, `name_short` and `value`.
    """

    @abstractmethod
    def doc(self) -> str:
        """Return a human-readable description of the statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short name for the statistic, e.g. for a table column name."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """Compute the statistic for a set of input values.

        Parameters
        ----------
        values : `Collection` [`float`]
            The values to summarize.

        Returns
        -------
        statistic : `float`
            The value of the statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')

328 

329 

class Median(Statistic):
    """The median of a set of values."""

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        return "Median"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name suffix for this statistic."""
        return "median"

    def value(self, values):
        """Return the median of ``values``, as computed by `numpy.median`."""
        result = np.median(values)
        return result

342 

343 

class SigmaIQR(Statistic):
    """The interquartile range re-scaled to a normal sigma equivalent."""

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name suffix for this statistic."""
        return "sig_iqr"

    def value(self, values):
        """Return the IQR of ``values``, scaled with ``scale='normal'``."""
        result = iqr(values, scale='normal')
        return result

356 

357 

class SigmaMAD(Statistic):
    """The median absolute deviation re-scaled to a normal sigma equivalent."""

    @classmethod
    def doc(cls) -> str:
        """Return the description of this statistic."""
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        """Return the column-name suffix for this statistic."""
        return "sig_mad"

    def value(self, values):
        """Return the sigma-equivalent MAD of ``values`` via astropy's `mad_std`."""
        result = mad_std(values)
        return result

370 

371 

@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised if ``percentile`` is not within [0, 100] (including NaN).
    """
    percentile: float

    def doc(self) -> str:
        """Return the description of this statistic."""
        # Previously a copy of SigmaMAD's description, which was wrong here
        return f"{self.percentile}th percentile"

    def name_short(self) -> str:
        """Return the column-name suffix for this statistic.

        The name is ``pctl`` followed by the digits of ``percentile/100``
        formatted to five decimal places with the leading ``0.`` dropped,
        e.g. 2.275 -> ``pctl02275``. Values that round to 1 keep their leading
        digit (100 -> ``pctl100000``) so that they do not collide with 0
        (``pctl00000``).
        """
        # abs() only matters for -0.0, which would otherwise format as '-0.00000'
        integer, fraction = f'{abs(self.percentile)/100:.5f}'.split('.')
        return f"pctl{fraction}" if integer == '0' else f"pctl{integer}{fraction}"

    def value(self, values):
        """Return the configured percentile of ``values`` via `numpy.percentile`."""
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        # Written so that NaN also fails the check
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')

395 

396 

397def _get_stat_name(*args): 

398 return '_'.join(args) 

399 

400 

401def _get_column_name(band, *args): 

402 return f"{band}_{_get_stat_name(*args)}" 

403 

404 

def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics on differences and store results in a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,) or None
        Errors (standard deviations) on `values_target`. If None, no chi
        statistics are computed.
    row : `numpy.void`
        A single record of a numpy structured array, with pre-assigned
        column names; assigned to in place.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement type.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : `numpy.void`
        The original `row` with statistic values assigned.

    Raises
    ------
    ValueError
        Raised if the input lengths differ (only checked if `values_ref`
        is non-empty).
    """
    n_ref = len(values_ref)
    if n_ref == 0:
        return row

    n_target = len(values_target)
    has_errors = errors_target is not None
    n_target_err = len(errors_target) if has_errors else n_ref
    if n_target != n_ref or n_target_err != n_ref:
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')

    diff = values_target - values_ref
    chi = diff/errors_target if has_errors else diff
    # Could make this configurable, but non-finite values/errors are not really usable
    finite = np.isfinite(chi)

    measured = {}
    if not skip_diff:
        measured[MeasurementType.DIFF] = diff[finite]
    if has_errors:
        measured[MeasurementType.CHI] = chi[finite]

    for measurement_type, suffix in suffixes.items():
        values = measured.get(measurement_type)
        if values is None or len(values) == 0:
            continue
        for stat_name, stat in stats.items():
            row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row

456 

457 

458@dataclass(frozen=True) 

459class SourceTypeInfo: 

460 is_extended: bool | None 

461 label: str 

462 

463 

class SourceType(Enum):
    """Source classes that statistics are computed for."""

    # No extendedness selection
    ALL = SourceTypeInfo(label='all', is_extended=None)
    # Extended sources only
    RESOLVED = SourceTypeInfo(label='resolved', is_extended=True)
    # Compact (non-extended) sources only
    UNRESOLVED = SourceTypeInfo(label='unresolved', is_extended=False)

468 

469 

class MatchType(Enum):
    """Categories of counted sources, used as column-name components."""

    # All sources, regardless of match status
    ALL = 'all'
    # Matched, and the counterpart's source class agrees
    MATCH_RIGHT = 'match_right'
    # Matched, but the counterpart's source class disagrees
    MATCH_WRONG = 'match_wrong'

474 

475 

def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffixes_mag: Dict,
                 stats: Dict, target: ComparableCatalog, column_dist: str):
    """Get column names for a table of difference statistics.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType` type, for general columns (e.g.
        coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of suffixes for each `Statistic` type.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name, in table column order.
    n_models : `int`
        The number of models measurements will be made for.

    Raises
    ------
    RuntimeError
        Raised if any band's number of flux or flux error columns differs
        from the first band's number of flux columns.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # Insertion order matters: it becomes the field order of the output table
    columns = {"bin": int, "mag_min": float, "mag_max": float}

    bands = list(bands_columns.keys())
    n_bands = len(bands)
    n_models = 0

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)
        if idx == 0:
            # The first band sets the number of models every band must have
            n_models = n_models_flux

        # TODO: Do equivalent validation earlier, in the config
        if n_models_flux != n_models or n_models_err != n_models:
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        prefixes_model = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        # Colors pair a band with the preceding one. The first band pairs with
        # the last (bands[-1]) only when there are more than two bands; with
        # two, that color would just be the negative of the other one.
        if idx > 0 or n_bands > 2:
            prefixes_model.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Totals would be redundant
            if sourcetype is not SourceType.ALL:
                for itype in ('ref', 'target'):
                    for mtype in MatchType:
                        columns[_get_column_name(band, label, f'n_{itype}_{mtype.value}')] = int

            for item in (target.column_coord1, target.column_coord2, column_dist):
                for suffix in suffixes.values():
                    for stat in stats:
                        columns[_get_column_name(band, label, item, suffix, stat)] = float

            for item in config_flux.columns_target_flux:
                for prefix_item, suffixes_col in prefixes_model:
                    for suffix in suffixes_col.values():
                        for stat in stats:
                            columns[_get_column_name(band, label, prefix_item, item, suffix, stat)] = float

    return columns, n_models

555 

556 

class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
    """Load subsets of matched catalogs and output a merged catalog of matched sources.

    A table of source counts and of difference/chi statistics in bins of
    magnitude, for each configured band, is also produced.
    """
    ConfigClass = DiffMatchedTractCatalogConfig
    _DefaultName = "DiffMatchedTractCatalog"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        skymap = inputs.pop("skymap")

        # The target match catalog may lack a match_candidate column; only
        # request it if present
        columns_match_target = ['match_row']
        if 'match_candidate' in inputs['columns_match_target']:
            columns_match_target.append('match_candidate')

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
            catalog_match_ref=inputs['cat_match_ref'].get(
                parameters={'columns': ['match_candidate', 'match_row']},
            ),
            catalog_match_target=inputs['cat_match_target'].get(
                parameters={'columns': columns_match_target},
            ),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        catalog_match_ref: pd.DataFrame,
        catalog_match_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Load matched reference and target (measured) catalogs, measure summary statistics, and output
        a combined matched catalog with columns from both inputs.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to diff objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to diff reference objects/sources to.
        catalog_match_ref : `pandas.DataFrame`
            A catalog with match indices of target sources and selection flags
            for each reference source.
        catalog_match_target : `pandas.DataFrame`
            A catalog with selection flags for each target source.
        wcs : `lsst.afw.geom.SkyWcs`
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``cat_matched``
                Matched target rows with the matched reference columns
                appended, prefixed by ``config.column_matched_prefix_ref``
                (`pandas.DataFrame`).
            ``diff_matched``
                Counts and difference/chi statistics per magnitude bin
                (`pandas.DataFrame`).
        """
        config = self.config

        select_ref = catalog_match_ref['match_candidate'].values
        # Add additional selection criteria for target sources beyond those for matching
        # (not recommended, but can be done anyway).
        # Copy, since the in-place &= below would otherwise modify the caller's
        # catalog_match_target column through the .values view.
        select_target = (catalog_match_target['match_candidate'].values.copy()
                         if 'match_candidate' in catalog_match_target.columns
                         else np.ones(len(catalog_match_target), dtype=bool))
        for column in config.columns_target_select_true:
            select_target &= catalog_target[column].values
        for column in config.columns_target_select_false:
            select_target &= ~catalog_target[column].values

        ref, target = config.coord_format.format_catalogs(
            catalog_ref=catalog_ref, catalog_target=catalog_target,
            select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
            return_converted_columns=config.coord_format.coords_ref_to_convert is not None,
        )
        cat_ref = ref.catalog
        cat_target = target.catalog
        n_target = len(cat_target)

        # match_row is the target row index for each reference source (negative if unmatched)
        match_row = catalog_match_ref['match_row'].values
        matched_ref = match_row >= 0
        matched_row = match_row[matched_ref]
        matched_target = np.zeros(n_target, dtype=bool)
        matched_target[matched_row] = True

        # Create a matched table, preserving the target catalog's named index (if it has one)
        cat_left = cat_target.iloc[matched_row]
        has_index_left = cat_left.index.name is not None
        cat_right = cat_ref[matched_ref].reset_index()
        # Apply the configured prefix before concatenating, instead of
        # overwriting the (immutable) column Index's values afterwards
        cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
        cat_matched = pd.concat(objs=(cat_left.reset_index(drop=True), cat_right), axis=1, sort=False)
        if has_index_left:
            cat_matched.index = cat_left.index

        # Add/compute distance columns; unmatched targets get infinite values.
        # np.inf, since the np.Inf alias was removed in NumPy 2.0.
        coord1_target_err, coord2_target_err = config.columns_target_coord_err
        column_dist, column_dist_err = 'distance', 'distanceErr'
        dist = np.full(n_target, np.inf)

        dist[matched_row] = np.hypot(
            target.coord1[matched_row] - ref.coord1[matched_ref],
            target.coord2[matched_row] - ref.coord2[matched_ref],
        )
        dist_err = np.full(n_target, np.inf)
        dist_err[matched_row] = np.hypot(cat_target.iloc[matched_row][coord1_target_err].values,
                                         cat_target.iloc[matched_row][coord2_target_err].values)
        cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

        # Slightly smelly hack for when a column (like distance) is already relative to truth
        column_dummy = 'dummy'
        cat_ref[column_dummy] = np.zeros_like(ref.coord1)

        # Classify reference and target sources as extended or not
        extended_ref = cat_ref[config.column_ref_extended]
        if config.column_ref_extended_inverted:
            extended_ref = 1 - extended_ref

        extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

        # Define difference/chi columns and statistics thereof
        suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
        # Skip diff for fluxes - covered by mags
        suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
        # Skip chi for magnitudes, which have strange errors
        suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
        stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

        for percentile in self.config.percentiles:
            stat = Percentile(percentile=float(Decimal(percentile)))
            stats[stat.name_short()] = stat

        # Get dict of column names
        columns, n_models = _get_columns(
            bands_columns=config.columns_flux,
            suffixes=suffixes,
            suffixes_flux=suffixes_flux,
            suffixes_mag=suffixes_mag,
            stats=stats,
            target=target,
            column_dist=column_dist,
        )

        # Setup numpy table
        n_bins = config.mag_num_bins
        data = np.zeros((n_bins,), dtype=list(columns.items()))
        data['bin'] = np.arange(n_bins)

        # Setup bins
        bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
                               num=n_bins + 1)
        data['mag_min'] = bins_mag[:-1]
        data['mag_max'] = bins_mag[1:]
        bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))

        # Define temporary columns for intermediate storage
        column_mag_temp = 'mag_temp'
        column_color_temp = 'color_temp'
        column_color_err_temp = 'colorErr_temp'
        flux_err_frac_prev = [None]*n_models
        mag_prev = [None]*n_models

        # Values are (reference column, target column, target error column or None, skip_diff)
        columns_target = {
            target.column_coord1: (
                ref.column_coord1, target.column_coord1, coord1_target_err, False,
            ),
            target.column_coord2: (
                ref.column_coord2, target.column_coord2, coord2_target_err, False,
            ),
            column_dist: (column_dummy, column_dist, column_dist_err, False),
        }

        # Cheat a little and do the first band last so that the color is
        # based on the last band
        band_fluxes = list(config.columns_flux.items())
        n_bands = len(band_fluxes)
        if n_bands > 0:
            band_fluxes.append(band_fluxes[0])
        flux_err_frac_first = None
        mag_first = None
        mag_ref_first = None

        band_prev = None
        for idx_band, (band, config_flux) in enumerate(band_fluxes):
            if idx_band == n_bands:
                # These were already computed earlier
                mag_ref = mag_ref_first
                flux_err_frac = flux_err_frac_first
                mag_model = mag_first
            else:
                mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
                flux_err_frac = [None]*n_models
                mag_model = [None]*n_models

            if idx_band > 0:
                cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref

            cat_ref[column_mag_temp] = mag_ref

            select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
                               for mag_lo, mag_hi in bins_mag]

            # Iterate over multiple models, compute their mags and colours (if there's a previous band)
            for idx_model in range(n_models):
                column_target_flux = config_flux.columns_target_flux[idx_model]
                column_target_flux_err = config_flux.columns_target_flux_err[idx_model]

                flux_target = cat_target[column_target_flux]
                mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
                if config.mag_ceiling_target is not None:
                    mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
                mag_model[idx_model] = mag_target

                # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
                flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target

                # Stop if idx_band == 0: The rest will be picked up at idx_band == n_bands
                if idx_band > 0:
                    # Keep these mags tabulated for convenience
                    column_mag_temp_model = f'{column_mag_temp}{idx_model}'
                    cat_target[column_mag_temp_model] = mag_target

                    columns_target[f'flux_{column_target_flux}'] = (
                        config_flux.column_ref_flux,
                        column_target_flux,
                        column_target_flux_err,
                        True,
                    )
                    # Note: magnitude errors are generally problematic and not worth aggregating
                    columns_target[f'mag_{column_target_flux}'] = (
                        column_mag_temp, column_mag_temp_model, None, False,
                    )

                    # No need for colors if this is the last band and there are only two bands
                    # (because it would just be the negative of the first color)
                    skip_color = (idx_band == n_bands) and (n_bands <= 2)
                    if not skip_color:
                        column_color_temp_model = f'{column_color_temp}{idx_model}'
                        column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'

                        # e.g. if order is ugrizy, first color will be u - g
                        cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]

                        # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
                        cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
                            flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
                        columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
                            column_color_temp,
                            column_color_temp_model,
                            column_color_err_temp_model,
                            False,
                        )

                    for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
                        # A single record of a structured array is a view, so
                        # assigning to row fills in data
                        row = data[idx_bin]
                        # The reference selection is needed for every model; only
                        # the reference *counts* below are restricted to the first
                        # model. Assigning this only for idx_model == 0 made later
                        # models reuse the last bin's selection for every bin.
                        select_ref_bin = select_ref_bins[idx_bin]
                        select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)

                        for sourcetype in SourceType:
                            sourcetype_info = sourcetype.value
                            is_extended = sourcetype_info.is_extended
                            # Counts filtered by match selection and magnitude bin
                            select_ref_sub = select_ref_bin.copy()
                            select_target_sub = select_target_bin.copy()
                            if is_extended is not None:
                                is_extended_ref = (extended_ref == is_extended)
                                select_ref_sub &= is_extended_ref
                                # Reference sources only need to be counted once
                                if idx_model == 0:
                                    n_ref_sub = np.count_nonzero(select_ref_sub)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.ALL.value)] = n_ref_sub
                                select_target_sub &= (extended_target == is_extended)
                                # NOTE(review): target counts are not per-model, so each
                                # model overwrites them and the last model's counts remain
                                n_target_sub = np.count_nonzero(select_target_sub)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.ALL.value)] = n_target_sub

                            # Filter matches by magnitude bin and true class
                            match_row_bin = match_row.copy()
                            match_row_bin[~select_ref_sub] = -1
                            match_good = match_row_bin >= 0

                            n_match = np.count_nonzero(match_good)

                            if n_match > 0:
                                rows_matched = match_row_bin[match_good]
                                subset_target = cat_target.iloc[rows_matched]
                                if (is_extended is not None) and (idx_model == 0):
                                    right_type = extended_target[rows_matched] == is_extended
                                    n_total = len(right_type)
                                    n_right = np.count_nonzero(right_type)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.MATCH_RIGHT.value)] = n_right
                                    row[_get_column_name(
                                        band, sourcetype_info.label, 'n_ref', MatchType.MATCH_WRONG.value,
                                    )] = n_total - n_right

                                # compute stats for this bin, for all columns
                                for column, (column_ref, column_target, column_err_target, skip_diff) \
                                        in columns_target.items():
                                    values_ref = cat_ref[column_ref][match_good].values
                                    errors_target = (
                                        subset_target[column_err_target].values
                                        if column_err_target is not None
                                        else None
                                    )
                                    compute_stats(
                                        values_ref,
                                        subset_target[column_target].values,
                                        errors_target,
                                        row,
                                        stats,
                                        suffixes,
                                        prefix=f'{band}_{sourcetype_info.label}_{column}',
                                        skip_diff=skip_diff,
                                    )

                            # Count matched target sources with *measured* mags within bin
                            # Used for e.g. purity calculation
                            # Should be merged with above code if there's ever a need for
                            # measuring stats on this source selection
                            select_target_sub &= matched_target

                            if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
                                n_total = np.count_nonzero(select_target_sub)
                                # Targets matched to a reference source of the same class
                                right_type = np.zeros(n_target, dtype=bool)
                                right_type[match_row[matched_ref & is_extended_ref]] = True
                                right_type &= select_target_sub
                                n_right = np.count_nonzero(right_type)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_RIGHT.value)] = n_right
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_WRONG.value)] = n_total - n_right

                    # delete the flux/color columns since they change with each band
                    for prefix in ('flux', 'mag'):
                        del columns_target[f'{prefix}_{column_target_flux}']
                    if not skip_color:
                        del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']

            # keep values needed for colors
            flux_err_frac_prev = flux_err_frac
            mag_prev = mag_model
            band_prev = band
            if idx_band == 0:
                flux_err_frac_first = flux_err_frac
                mag_first = mag_model
                mag_ref_first = mag_ref

        retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
        return retStruct