lsst.pipe.tasks g3de15ee5c7+f497bfeb17
diff_matched_tract_catalog.py
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = [
    'DiffMatchedTractCatalogConfig', 'DiffMatchedTractCatalogTask', 'MatchedCatalogFluxesConfig',
    'MatchType', 'MeasurementType', 'SourceType',
    'Statistic', 'Median', 'SigmaIQR', 'SigmaMAD', 'Percentile',
]

import lsst.afw.geom as afwGeom
from lsst.meas.astrom.matcher_probabilistic import (
    ComparableCatalog, ConvertCatalogCoordinatesConfig,
)
from lsst.meas.astrom.match_probabilistic_task import radec_to_xy
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
from lsst.skymap import BaseSkyMap

from abc import ABCMeta, abstractmethod
from astropy.stats import mad_std
import astropy.units as u
from dataclasses import dataclass
from decimal import Decimal
from enum import Enum
import numpy as np
import pandas as pd
from scipy.stats import iqr
from typing import Dict, Sequence


def is_sequence_set(x: Sequence):
    return len(x) == len(set(x))


def is_percentile(x: str):
    return 0 <= Decimal(x) <= 100
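
# A quick sketch of the two check helpers above (hypothetical values, not
# part of the original file):
# >>> is_sequence_set(['u', 'g', 'r'])
# True
# >>> is_sequence_set(['u', 'u'])
# False
# >>> is_percentile('2.275')
# True
# Note that is_percentile raises decimal.InvalidOperation on a non-numeric
# string rather than returning False.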


DiffMatchedTractCatalogBaseTemplates = {
    "name_input_cat_ref": "truth_summary",
    "name_input_cat_target": "objectTable_tract",
    "name_skymap": BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
}


class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    cat_ref = cT.Input(
        doc="Reference object catalog to match from",
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_target = cT.Input(
        doc="Target object catalog to match",
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    skymap = cT.Input(
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    cat_match_ref = cT.Input(
        doc="Reference match catalog with indices of target matches",
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_match_target = cT.Input(
        doc="Target match catalog with indices of reference matches",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    columns_match_target = cT.Input(
        doc="Target match catalog columns",
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
    )
    cat_matched = cT.Output(
        doc="Catalog with reference and target columns for joined sources",
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )
    diff_matched = cT.Output(
        doc="Table with aggregated counts, difference and chi statistics",
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )


class MatchedCatalogFluxesConfig(pexConfig.Config):
    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    # this should be an ordered set
    @property
    def columns_in_ref(self) -> list[str]:
        return [self.column_ref_flux]

    # this should also be an ordered set
    @property
    def columns_in_target(self) -> list[str]:
        columns = [col for col in self.columns_target_flux]
        columns.extend(col for col in self.columns_target_flux_err if col not in columns)
        return columns
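
# A hedged usage sketch (the column names are hypothetical, not defaults):
# >>> config = MatchedCatalogFluxesConfig(
# ...     column_ref_flux='flux_r',
# ...     columns_target_flux=['r_cModelFlux'],
# ...     columns_target_flux_err=['r_cModelFluxErr'],
# ... )
# >>> config.columns_in_target
# ['r_cModelFlux', 'r_cModelFluxErr']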


class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field[str](
        default='is_pointsource',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field[bool](
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field[str](
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc="Whether to include unmatched rows in the matched table",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
                       self.column_ref_extended]
        for column_lists in (
            (
                self.columns_ref_copy,
            ),
            (x.columns_in_ref for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(column_list)

        return list({column: None for column in columns_all}.keys())

    @property
    def columns_in_target(self) -> list[str]:
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
                       self.column_target_extended]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        for column_lists in (
            (
                self.columns_target_coord_err,
                self.columns_target_select_false,
                self.columns_target_select_true,
                self.columns_target_copy,
            ),
            (x.columns_in_target for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field[float](
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

    def validate(self):
        super().validate()

        errors = []

        for columns_mag, columns_in, name_columns_copy in (
            (self.columns_ref_mag_to_nJy, self.columns_in_ref, "columns_ref_copy"),
            (self.columns_target_mag_to_nJy, self.columns_in_target, "columns_target_copy"),
        ):
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.columns_mag_to_nJy not found in {columns_in=}; did you"
                        f" forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value found in self.{name_columns_copy}={columns_copy};"
                        f" this will cause a collision. Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))
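
# A sketch of what validate() above catches, assuming a configured instance
# named config (column names hypothetical): a mag column that is never read
# from the reference table fails, as does a new name that collides with a
# column already being copied.
# >>> config.columns_ref_mag_to_nJy = {'mag_r': 'flux_r'}
# >>> config.validate()  # ValueError unless 'mag_r' is in columns_in_ref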


@dataclass(frozen=True)
class MeasurementTypeInfo:
    doc: str
    name: str


class MeasurementType(Enum):
    DIFF = MeasurementTypeInfo(
        doc="difference (measured - reference)",
        name="diff",
    )
    CHI = MeasurementTypeInfo(
        doc="scaled difference (measured - reference)/error",
        name="chi",
    )


class Statistic(metaclass=ABCMeta):
    """A statistic that can be applied to a set of values.
    """
    @abstractmethod
    def doc(self) -> str:
        """A description of the statistic"""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """A short name for the statistic, e.g. for a table column name"""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """The value of the statistic for a set of input values.

        Parameters
        ----------
        values : `Collection` [`float`]
            A set of values to compute the statistic for.

        Returns
        -------
        statistic : `float`
            The value of the statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')
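
# A new statistic only needs the three methods above; a minimal sketch
# (a hypothetical Mean, not part of the original file):
#
# class Mean(Statistic):
#     @classmethod
#     def doc(cls) -> str:
#         return "Arithmetic mean"
#
#     @classmethod
#     def name_short(cls) -> str:
#         return "mean"
#
#     def value(self, values):
#         return np.mean(values)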


class Median(Statistic):
    """The median of a set of values."""
    @classmethod
    def doc(cls) -> str:
        return "Median"

    @classmethod
    def name_short(cls) -> str:
        return "median"

    def value(self, values):
        return np.median(values)


class SigmaIQR(Statistic):
    """The re-scaled interquartile range (sigma equivalent)."""
    @classmethod
    def doc(cls) -> str:
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        return "sig_iqr"

    def value(self, values):
        return iqr(values, scale='normal')


class SigmaMAD(Statistic):
    """The re-scaled median absolute deviation (sigma equivalent)."""
    @classmethod
    def doc(cls) -> str:
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        return "sig_mad"

    def value(self, values):
        return mad_std(values)
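
# Both robust widths above estimate the standard deviation for normally
# distributed data; a quick check (hypothetical seed):
# >>> rng = np.random.default_rng(1)
# >>> x = rng.normal(size=100_000)
# >>> round(SigmaIQR().value(x), 2), round(SigmaMAD().value(x), 2)
# (1.0, 1.0)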


@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).
    """
    percentile: float

    def doc(self) -> str:
        return f"{self.percentile}th percentile"

    def name_short(self) -> str:
        return f"pctl{f'{self.percentile/100:.5f}'[2:]}"

    def value(self, values):
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')
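
# name_short packs the digits of percentile/100 after the '0.'; e.g.:
# >>> Percentile(percentile=2.275).name_short()
# 'pctl02275'
# >>> Percentile(percentile=97.725).name_short()
# 'pctl97725'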


def _get_stat_name(*args):
    return '_'.join(args)


def _get_column_name(band, *args):
    return f"{band}_{_get_stat_name(*args)}"
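
# e.g. (hypothetical band/label/statistic):
# >>> _get_column_name('r', 'all', 'match_distance', 'diff', 'median')
# 'r_all_match_distance_diff_median'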


def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics on differences and store results in a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,)
        Errors (standard deviations) on `values_target`.
    row : `numpy.ndarray`, (1, C)
        A numpy array with pre-assigned column names.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement type.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : `numpy.ndarray`, (1, C)
        The original `row` with statistic values assigned.
    """
    n_ref = len(values_ref)
    if n_ref > 0:
        n_target = len(values_target)
        n_target_err = len(errors_target) if errors_target is not None else n_ref
        if (n_target != n_ref) or (n_target_err != n_ref):
            raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                             f', errors_target={n_target_err} must match')

        do_chi = errors_target is not None
        diff = values_target - values_ref
        chi = diff/errors_target if do_chi else diff
        # Could make this configurable, but non-finite values/errors are not really usable
        valid = np.isfinite(chi)
        values_type = {} if skip_diff else {MeasurementType.DIFF: diff[valid]}
        if do_chi:
            values_type[MeasurementType.CHI] = chi[valid]

        for suffix_type, suffix in suffixes.items():
            values = values_type.get(suffix_type)
            if values is not None and len(values) > 0:
                for stat_name, stat in stats.items():
                    row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row
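
# A minimal usage sketch of compute_stats (all names hypothetical):
# >>> row = np.zeros(1, dtype=[('r_all_x_diff_median', float),
# ...                          ('r_all_x_chi_median', float)])[0]
# >>> row = compute_stats(
# ...     np.array([1.0, 2.0]), np.array([1.5, 2.5]), np.array([0.5, 0.5]),
# ...     row, {'median': Median()},
# ...     {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'},
# ...     prefix='r_all_x',
# ... )
# >>> float(row['r_all_x_diff_median']), float(row['r_all_x_chi_median'])
# (0.5, 1.0)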


@dataclass(frozen=True)
class SourceTypeInfo:
    is_extended: bool | None
    label: str


class SourceType(Enum):
    ALL = SourceTypeInfo(is_extended=None, label='all')
    RESOLVED = SourceTypeInfo(is_extended=True, label='resolved')
    UNRESOLVED = SourceTypeInfo(is_extended=False, label='unresolved')


class MatchType(Enum):
    ALL = 'all'
    MATCH_RIGHT = 'match_right'
    MATCH_WRONG = 'match_wrong'


def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffixes_mag: Dict,
                 stats: Dict, target: ComparableCatalog, column_dist: str):
    """Get column names for a table of difference statistics.

    Parameters
    ----------
    bands_columns : `Dict` [`str`, `MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType`, for general columns
        (e.g. coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of suffixes for each `Statistic` type.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name.
    n_models : `int`
        The number of models measurements will be made for.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # Initial columns
    columns = {
        "bin": int,
        "mag_min": float,
        "mag_max": float,
    }

    # pre-assign all of the columns with appropriate types
    n_models = 0

    bands = list(bands_columns.keys())
    n_bands = len(bands)

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        columns_suffix = [
            ('flux', suffixes_flux),
            ('mag', suffixes_mag),
        ]
        if idx == 0:
            n_models = len(config_flux.columns_target_flux)
        if (idx > 0) or (n_bands > 2):
            columns_suffix.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)

        # TODO: Do equivalent validation earlier, in the config
        if (n_models_flux != n_models) or (n_models_err != n_models):
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Totals would be redundant
            if sourcetype != SourceType.ALL:
                for item in (f'n_{itype}_{mtype.value}' for itype in ('ref', 'target')
                             for mtype in MatchType):
                    columns[_get_column_name(band, label, item)] = int

            for item in (target.column_coord1, target.column_coord2, column_dist):
                for suffix in suffixes.values():
                    for stat in stats.keys():
                        columns[_get_column_name(band, label, item, suffix, stat)] = float

            for item in config_flux.columns_target_flux:
                for prefix_item, suffixes_col in columns_suffix:
                    for suffix in suffixes_col.values():
                        for stat in stats.keys():
                            columns[_get_column_name(band, label, prefix_item, item, suffix, stat)] = float

    return columns, n_models
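
# The generated names follow '{band}_{label}_{item}[_{suffix}_{stat}]', e.g.
# (hypothetical): 'r_resolved_n_ref_match_right' for counts (int) or
# 'r_all_match_distance_diff_median' for statistics (float).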


class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
    """Load subsets of matched catalogs and output a merged catalog of matched sources.
    """
    ConfigClass = DiffMatchedTractCatalogConfig
    _DefaultName = "DiffMatchedTractCatalog"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        skymap = inputs.pop("skymap")

        columns_match_target = ['match_row']
        if 'match_candidate' in inputs['columns_match_target']:
            columns_match_target.append('match_candidate')

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
            catalog_match_ref=inputs['cat_match_ref'].get(
                parameters={'columns': ['match_candidate', 'match_row']},
            ),
            catalog_match_target=inputs['cat_match_target'].get(
                parameters={'columns': columns_match_target},
            ),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        catalog_match_ref: pd.DataFrame,
        catalog_match_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Load matched reference and target (measured) catalogs, measure summary statistics, and output
        a combined matched catalog with columns from both inputs.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to diff objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to diff reference objects/sources to.
        catalog_match_ref : `pandas.DataFrame`
            A catalog with match indices of target sources and selection flags
            for each reference source.
        catalog_match_target : `pandas.DataFrame`
            A catalog with selection flags for each target source.
        wcs : `lsst.afw.geom.SkyWcs`
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with cat_matched and diff_matched attributes containing the
            matched catalog and the table of difference statistics, respectively.
        """
        # Would be nice if this could refer directly to ConfigClass
        config: DiffMatchedTractCatalogConfig = self.config

        select_ref = catalog_match_ref['match_candidate'].values
        # Add additional selection criteria for target sources beyond those for matching
        # (not recommended, but can be done anyway)
        select_target = (catalog_match_target['match_candidate'].values
                         if 'match_candidate' in catalog_match_target.columns
                         else np.ones(len(catalog_match_target), dtype=bool))
        for column in config.columns_target_select_true:
            select_target &= catalog_target[column].values
        for column in config.columns_target_select_false:
            select_target &= ~catalog_target[column].values

        ref, target = config.coord_format.format_catalogs(
            catalog_ref=catalog_ref, catalog_target=catalog_target,
            select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
            return_converted_columns=config.coord_format.coords_ref_to_convert is not None,
        )
        cat_ref = ref.catalog
        cat_target = target.catalog
        n_target = len(cat_target)

        if config.include_unmatched:
            for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
                cat_add['match_candidate'] = cat_match['match_candidate'].values

        match_row = catalog_match_ref['match_row'].values
        matched_ref = match_row >= 0
        matched_row = match_row[matched_ref]
        matched_target = np.zeros(n_target, dtype=bool)
        matched_target[matched_row] = True

        # Add/compute distance columns
        coord1_target_err, coord2_target_err = config.columns_target_coord_err
        column_dist, column_dist_err = 'match_distance', 'match_distanceErr'
        dist = np.full(n_target, np.nan)

        dist[matched_row] = np.hypot(
            target.coord1[matched_row] - ref.coord1[matched_ref],
            target.coord2[matched_row] - ref.coord2[matched_ref],
        )
        dist_err = np.full(n_target, np.nan)
        dist_err[matched_row] = np.hypot(cat_target.iloc[matched_row][coord1_target_err].values,
                                         cat_target.iloc[matched_row][coord2_target_err].values)
        cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

        # Create a matched table, preserving the target catalog's named index (if it has one)
        cat_left = cat_target.iloc[matched_row]
        has_index_left = cat_left.index.name is not None
        cat_right = cat_ref[matched_ref].reset_index()
        cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
        cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1)

        if config.include_unmatched:
            # Create an unmatched table with the same schema as the matched one
            # ... but only for objects with no matches (for completeness/purity)
            # and that were selected for matching (or inclusion via config)
            cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False)
            cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
            match_row_target = catalog_match_target['match_row'].values
            cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index(
                drop=not has_index_left)
            # See https://github.com/pandas-dev/pandas/issues/46662
            # astropy masked columns would handle this much more gracefully
            # Unfortunately, that would require storageClass migration
            # So we use pandas "extended" nullable types for now
            for cat_i in (cat_left, cat_right):
                for colname in cat_i.columns:
                    column = cat_i[colname]
                    dtype = str(column.dtype)
                    if dtype == "bool":
                        cat_i[colname] = column.astype("boolean")
                    elif dtype.startswith("int"):
                        cat_i[colname] = column.astype(f"Int{dtype[3:]}")
                    elif dtype.startswith("uint"):
                        cat_i[colname] = column.astype(f"UInt{dtype[3:]}")
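            # (e.g. int64 -> Int64 and uint32 -> UInt32; these extension
            # dtypes admit pd.NA for the cells left empty when the matched
            # and unmatched rows are concatenated)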
            cat_unmatched = pd.concat(objs=(cat_left, cat_right))
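
        # For reference (an added note, not in the original file): 31.4 AB mag
        # corresponds to 1 nJy, which is why the mag_zeropoint_ref/_target
        # defaults above are 31.4.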
        for columns_convert_base, prefix in (
            (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
            (config.columns_target_mag_to_nJy, ""),
        ):
            if columns_convert_base:
                columns_convert = {
                    f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
                } if prefix else columns_convert_base
                # cat_unmatched is only defined when unmatched rows were requested
                for cat_convert in ((cat_matched, cat_unmatched) if config.include_unmatched
                                    else (cat_matched,)):
                    cat_convert.rename(columns=columns_convert, inplace=True)
                    for column_flux in columns_convert.values():
                        cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])

        # TODO: Deprecate all matched difference output in DM-43831 (per RFC-1008)

        # Slightly smelly hack for when a column (like distance) is already relative to truth
        column_dummy = 'dummy'
        cat_ref[column_dummy] = np.zeros_like(ref.coord1)

        # Add a boolean column for whether a match is classified correctly
        # TODO: remove the assumption of a boolean column
        extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)

        extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

        # Define difference/chi columns and statistics thereof
        suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
        # Skip diff for fluxes - covered by mags
        suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
        # Skip chi for magnitudes, which have strange errors
        suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
        stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

        for percentile in self.config.percentiles:
            stat = Percentile(percentile=float(Decimal(percentile)))
            stats[stat.name_short()] = stat

        # Get dict of column names
        columns, n_models = _get_columns(
            bands_columns=config.columns_flux,
            suffixes=suffixes,
            suffixes_flux=suffixes_flux,
            suffixes_mag=suffixes_mag,
            stats=stats,
            target=target,
            column_dist=column_dist,
        )

        # Setup numpy table
        n_bins = config.mag_num_bins
        data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
        data['bin'] = np.arange(n_bins)

        # Setup bins
        bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
                               num=n_bins + 1)
        data['mag_min'] = bins_mag[:-1]
        data['mag_max'] = bins_mag[1:]
        bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))
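        # (e.g. with the defaults, 15 bins spanning mags 15 to 30, each bin is
        # one magnitude wide: (15, 16), (16, 17), ..., (29, 30))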

        # Define temporary columns for intermediate storage
        column_mag_temp = 'mag_temp'
        column_color_temp = 'color_temp'
        column_color_err_temp = 'colorErr_temp'
        flux_err_frac_prev = [None]*n_models
        mag_prev = [None]*n_models

        columns_target = {
            target.column_coord1: (
                ref.column_coord1, target.column_coord1, coord1_target_err, False,
            ),
            target.column_coord2: (
                ref.column_coord2, target.column_coord2, coord2_target_err, False,
            ),
            column_dist: (column_dummy, column_dist, column_dist_err, False),
        }

        # Cheat a little and do the first band last so that the color is
        # based on the last band
        band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
        n_bands = len(band_fluxes)
        if n_bands > 0:
            band_fluxes.append(band_fluxes[0])
        flux_err_frac_first = None
        mag_first = None
        mag_ref_first = None

        band_prev = None
        for idx_band, (band, config_flux) in enumerate(band_fluxes):
            if idx_band == n_bands:
                # These were already computed earlier
                mag_ref = mag_ref_first
                flux_err_frac = flux_err_frac_first
                mag_model = mag_first
            else:
                mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
                flux_err_frac = [None]*n_models
                mag_model = [None]*n_models

            if idx_band > 0:
                cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref

            cat_ref[column_mag_temp] = mag_ref

            select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
                               for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag)]

            # Iterate over multiple models, compute their mags and colours (if there's a previous band)
            for idx_model in range(n_models):
                column_target_flux = config_flux.columns_target_flux[idx_model]
                column_target_flux_err = config_flux.columns_target_flux_err[idx_model]

                flux_target = cat_target[column_target_flux]
                mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
                if config.mag_ceiling_target is not None:
                    mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
                mag_model[idx_model] = mag_target

                # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
                flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target

                # Skip the first band on its first pass (idx_band == 0): it is
                # handled by the final wrap-around iteration at idx_band == n_bands
                if idx_band > 0:
                    # Keep these mags tabulated for convenience
                    column_mag_temp_model = f'{column_mag_temp}{idx_model}'
                    cat_target[column_mag_temp_model] = mag_target

                    columns_target[f'flux_{column_target_flux}'] = (
                        config_flux.column_ref_flux,
                        column_target_flux,
                        column_target_flux_err,
                        True,
                    )
                    # Note: magnitude errors are generally problematic and not worth aggregating
                    columns_target[f'mag_{column_target_flux}'] = (
                        column_mag_temp, column_mag_temp_model, None, False,
                    )

                    # No need for colors if this is the last band and there are only two bands
                    # (because it would just be the negative of the first color)
                    skip_color = (idx_band == n_bands) and (n_bands <= 2)
                    if not skip_color:
                        column_color_temp_model = f'{column_color_temp}{idx_model}'
                        column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'

                        # e.g. if order is ugrizy, first color will be u - g
                        cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]

                        # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
                        cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
                            flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
                        columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
                            column_color_temp,
                            column_color_temp_model,
                            column_color_err_temp_model,
                            False,
                        )

                    for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
                        row = data[idx_bin]
                        # Reference sources only need to be counted once
                        if idx_model == 0:
                            select_ref_bin = select_ref_bins[idx_bin]
                        select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)

                        for sourcetype in SourceType:
                            sourcetype_info = sourcetype.value
                            is_extended = sourcetype_info.is_extended
                            # Counts filtered by match selection and magnitude bin
                            select_ref_sub = select_ref_bin.copy()
                            select_target_sub = select_target_bin.copy()
                            if is_extended is not None:
                                is_extended_ref = (extended_ref == is_extended)
                                select_ref_sub &= is_extended_ref
                                if idx_model == 0:
                                    n_ref_sub = np.count_nonzero(select_ref_sub)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.ALL.value)] = n_ref_sub
                                select_target_sub &= (extended_target == is_extended)
                                n_target_sub = np.count_nonzero(select_target_sub)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.ALL.value)] = n_target_sub

                            # Filter matches by magnitude bin and true class
                            match_row_bin = match_row.copy()
                            match_row_bin[~select_ref_sub] = -1
                            match_good = match_row_bin >= 0

                            n_match = np.count_nonzero(match_good)

                            # Same for counts of matched target sources (for e.g. purity)

                            if n_match > 0:
                                rows_matched = match_row_bin[match_good]
                                subset_target = cat_target.iloc[rows_matched]
                                if (is_extended is not None) and (idx_model == 0):
                                    right_type = extended_target[rows_matched] == is_extended
                                    n_total = len(right_type)
                                    n_right = np.count_nonzero(right_type)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.MATCH_RIGHT.value)] = n_right
                                    row[_get_column_name(
                                        band,
                                        sourcetype_info.label,
                                        'n_ref',
                                        MatchType.MATCH_WRONG.value,
                                    )] = n_total - n_right

                                # compute stats for this bin, for all columns
                                for column, (column_ref, column_target, column_err_target, skip_diff) \
                                        in columns_target.items():
                                    values_ref = cat_ref[column_ref][match_good].values
                                    errors_target = (
                                        subset_target[column_err_target].values
                                        if column_err_target is not None
                                        else None
                                    )
                                    row = compute_stats(
                                        values_ref,
                                        subset_target[column_target].values,
                                        errors_target,
                                        row,
                                        stats,
                                        suffixes,
                                        prefix=f'{band}_{sourcetype_info.label}_{column}',
                                        skip_diff=skip_diff,
                                    )

                            # Count matched target sources with *measured* mags within bin
                            # Used for e.g. purity calculation
                            # Should be merged with above code if there's ever a need for
                            # measuring stats on this source selection
                            select_target_sub &= matched_target

                            if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
                                n_total = np.count_nonzero(select_target_sub)
                                right_type = np.zeros(n_target, dtype=bool)
                                right_type[match_row[matched_ref & is_extended_ref]] = True
                                right_type &= select_target_sub
                                n_right = np.count_nonzero(right_type)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_RIGHT.value)] = n_right
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_WRONG.value)] = n_total - n_right

                    # delete the flux/color columns since they change with each band
                    for prefix in ('flux', 'mag'):
                        del columns_target[f'{prefix}_{column_target_flux}']
                    if not skip_color:
                        del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']

            # keep values needed for colors
            flux_err_frac_prev = flux_err_frac
            mag_prev = mag_model
            band_prev = band
            if idx_band == 0:
                flux_err_frac_first = flux_err_frac
                mag_first = mag_model
                mag_ref_first = mag_ref

        if config.include_unmatched:
            cat_matched = pd.concat((cat_matched, cat_unmatched))

        retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
        return retStruct