Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 26%
409 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 03:09 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-10 03:09 -0700
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Names exported by ``from ... import *``; one entry per line.
__all__ = [
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    'MatchType',
    'MeasurementType',
    'SourceType',
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]
28import lsst.afw.geom as afwGeom
29from lsst.meas.astrom.matcher_probabilistic import (
30 ComparableCatalog, ConvertCatalogCoordinatesConfig,
31)
32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35import lsst.pipe.base.connectionTypes as cT
36from lsst.skymap import BaseSkyMap
38from abc import ABCMeta, abstractmethod
39from astropy.stats import mad_std
40import astropy.units as u
41from dataclasses import dataclass
42from decimal import Decimal
43from enum import Enum
44import numpy as np
45import pandas as pd
46from scipy.stats import iqr
47from typing import Dict, Sequence
def is_sequence_set(x: Sequence):
    """Check whether a sequence has no repeated elements.

    Parameters
    ----------
    x : `Sequence`
        The sequence to check; its elements must be hashable.

    Returns
    -------
    is_set : `bool`
        True if every element of ``x`` occurs exactly once.
    """
    n_unique = len(set(x))
    return n_unique == len(x)
def is_percentile(x: str):
    """Check whether a string represents a valid percentile.

    Parameters
    ----------
    x : `str`
        A decimal string, e.g. ``'15.866'``.

    Returns
    -------
    valid : `bool`
        True if ``x`` parses as a decimal number in the closed interval
        [0, 100]; False otherwise, including when ``x`` is not a number.
    """
    # Imported locally because the module only imports Decimal.
    from decimal import InvalidOperation

    try:
        return 0 <= Decimal(x) <= 100
    except InvalidOperation:
        # Raised for unparseable strings and for ordering comparisons on NaN.
        # A validator should reject such input rather than crash.
        return False
# Default values for the dataset-name templates used by the connections.
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)
class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Butler connections for `DiffMatchedTractCatalogTask`.

    Inputs are the reference and target catalogs, their match catalogs and
    the skymap; outputs are the joined catalog and the table of aggregated
    difference statistics.
    """

    # --- Inputs ---
    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        doc="Reference object catalog to match from",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        doc="Target object catalog to match",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    skymap = cT.Input(
        name="{name_skymap}",
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    cat_match_ref = cT.Input(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        doc="Reference match catalog with indices of target matches",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    cat_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        doc="Target match catalog with indices of references matches",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )
    # Only the column index of the target match catalog, so that the task can
    # check which columns exist before loading them.
    columns_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        doc="Target match catalog columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
    )

    # --- Outputs ---
    cat_matched = cT.Output(
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        doc="Catalog with reference and target columns for joined sources",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )
    diff_matched = cT.Output(
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        doc="Table with aggregated counts, difference and chi statistics",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )
class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column names for one band of a matched reference/target pair."""

    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    # this should be an orderedset
    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns required by this config."""
        return [self.column_ref_flux]

    # this should also be an orderedset
    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns required by this config.

        Flux columns come first, followed by any flux error columns not
        already listed, preserving order.
        """
        merged = list(self.columns_target_flux)
        for name in self.columns_target_flux_err:
            if name not in merged:
                merged.append(name)
        return merged
class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for `DiffMatchedTractCatalogTask`."""

    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field[str](
        default='is_pointsource',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field[bool](
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field[str](
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc="Whether to include unmatched rows in the matched table",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns that must be loaded, without duplicates.

        The order is: coordinate columns, the extendedness column, columns to
        copy, then the per-band reference flux columns.
        """
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
                       self.column_ref_extended]
        for column_lists in (
            (
                self.columns_ref_copy,
            ),
            (x.columns_in_ref for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(column_list)

        # A dict preserves insertion order while dropping duplicates
        return list({column: None for column in columns_all}.keys())

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns that must be loaded, without duplicates.

        The order is: coordinate columns, the extendedness column, converted
        coordinate columns (if any), coordinate errors, selection flags,
        columns to copy, then the per-band flux and flux error columns.
        """
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
                       self.column_target_extended]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(col for col in self.coord_format.coords_ref_to_convert.values()
                               if col not in columns_all)
        for column_lists in (
            (
                self.columns_target_coord_err,
                self.columns_target_select_false,
                self.columns_target_select_true,
                self.columns_target_copy,
            ),
            (x.columns_in_target for x in self.columns_flux.values()),
        ):
            for column_list in column_lists:
                columns_all.extend(col for col in column_list if col not in columns_all)
        return columns_all

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    # No default: this field must be set to exactly two distinct column names
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    # The type is given by the subscript; a separate dtype argument is redundant
    extendedness_cut = pexConfig.Field[float](
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

    def validate(self):
        """Validate the config, including the mag-to-nJy column mappings.

        Raises
        ------
        ValueError
            Raised if a magnitude column to convert is not among the columns
            that will be loaded, or if a converted column's new name collides
            with a column that is copied. All problems found are reported in
            a single message.
        """
        super().validate()

        errors = []

        for name_columns_mag, columns_in, name_columns_copy in (
            ("columns_ref_mag_to_nJy", self.columns_in_ref, "columns_ref_copy"),
            ("columns_target_mag_to_nJy", self.columns_in_target, "columns_target_copy"),
        ):
            columns_mag = getattr(self, name_columns_mag)
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.{name_columns_mag} not found in {columns_in=}; did you"
                        f" forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value found in self.{name_columns_copy}={columns_copy};"
                        f" this will cause a collision. Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))
@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of one kind of measurement."""

    # Human-readable description of the measurement
    doc: str
    # Short name of the measurement
    name: str
class MeasurementType(Enum):
    """The kinds of measurement that statistics are computed for."""

    # Plain difference between measured and reference values
    DIFF = MeasurementTypeInfo(
        name="diff",
        doc="difference (measured - reference)",
    )
    # Difference divided by the measurement error
    CHI = MeasurementTypeInfo(
        name="chi",
        doc="scaled difference (measured - reference)/error",
    )
class Statistic(metaclass=ABCMeta):
    """Abstract interface for a summary statistic of a set of values.
    """

    @abstractmethod
    def doc(self) -> str:
        """Return a description of the statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short name for the statistic, e.g. for a table column name."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """Compute the statistic for a set of input values.

        Parameters
        ----------
        values : `Collection` [`float`]
            The values to compute the statistic for.

        Returns
        -------
        statistic : `float`
            The value of the statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')
class Median(Statistic):
    """Statistic returning the median of the input values."""

    @classmethod
    def doc(cls) -> str:
        return "Median"

    @classmethod
    def name_short(cls) -> str:
        return "median"

    def value(self, values):
        # Delegates to numpy
        result = np.median(values)
        return result
class SigmaIQR(Statistic):
    """Statistic returning the interquartile range re-scaled to a sigma equivalent."""

    @classmethod
    def doc(cls) -> str:
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        return "sig_iqr"

    def value(self, values):
        # scale='normal' asks scipy to apply the normal-distribution re-scaling
        result = iqr(values, scale='normal')
        return result
class SigmaMAD(Statistic):
    """Statistic returning the median absolute deviation re-scaled to a sigma equivalent."""

    @classmethod
    def doc(cls) -> str:
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        return "sig_mad"

    def value(self, values):
        # Delegates to astropy
        result = mad_std(values)
        return result
@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised on construction if ``percentile`` is not within [0, 100].
    """
    percentile: float

    def doc(self) -> str:
        # Previously returned the SigmaMAD description (copy-paste error)
        return f"{self.percentile}th percentile"

    def name_short(self) -> str:
        # Format the fraction to five decimals and drop the leading "0.",
        # e.g. 15.866 -> "0.15866" -> "pctl15866".
        # NOTE(review): percentile=100 formats as "1.00000" and so yields
        # "pctl00000", the same name as percentile=0.
        return f"pctl{f'{self.percentile/100:.5f}'[2:]}"

    def value(self, values):
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')
def _get_stat_name(*args):
    """Join name components with underscores into a single statistic name."""
    separator = '_'
    return separator.join(args)
def _get_column_name(band, *args):
    """Build a column name from a band followed by underscore-joined components."""
    suffix = '_'.join(args)
    return f"{band}_{suffix}"
def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Measure statistics of differences and write them into a table row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,)
        Errors (standard deviations) on `values_target`. If None, no chi
        statistics are computed.
    row : `numpy.ndarray`, (1, C)
        A numpy array with pre-assigned column names.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement type.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : `numpy.ndarray`, (1, C)
        The original `row` with statistic values assigned.

    Raises
    ------
    ValueError
        Raised if the reference is non-empty and the input lengths differ.
    """
    n_ref = len(values_ref)
    # With no reference values there is nothing to measure
    if n_ref == 0:
        return row

    has_errors = errors_target is not None
    n_target = len(values_target)
    n_target_err = len(errors_target) if has_errors else n_ref
    if not ((n_target == n_ref) and (n_target_err == n_ref)):
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')

    diff = values_target - values_ref
    chi = diff/errors_target if has_errors else diff
    # Could make this configurable, but non-finite values/errors are not really usable
    valid = np.isfinite(chi)

    values_type = {}
    if not skip_diff:
        values_type[MeasurementType.DIFF] = diff[valid]
    if has_errors:
        values_type[MeasurementType.CHI] = chi[valid]

    for suffix_type, suffix in suffixes.items():
        values = values_type.get(suffix_type)
        # Skip measurement types that were not computed or have no valid values
        if values is None or len(values) == 0:
            continue
        for stat_name, stat in stats.items():
            row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row
492@dataclass(frozen=True)
493class SourceTypeInfo:
494 is_extended: bool | None
495 label: str
class SourceType(Enum):
    """Source classifications that statistics are aggregated over."""

    # No selection on extendedness
    ALL = SourceTypeInfo(label='all', is_extended=None)
    # Extended sources only
    RESOLVED = SourceTypeInfo(label='resolved', is_extended=True)
    # Non-extended sources only
    UNRESOLVED = SourceTypeInfo(label='unresolved', is_extended=False)
class MatchType(Enum):
    """Categories of match counted in the output table.

    The values are used as components of output column names.
    """

    ALL = 'all'
    MATCH_RIGHT = 'match_right'
    MATCH_WRONG = 'match_wrong'
def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffixes_mag: Dict,
                 stats: Dict, target: ComparableCatalog, column_dist: str):
    """Build the column names and types for a table of difference statistics.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType` type, for general columns (e.g.
        coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of suffixes for each `Statistic` type.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name.
    n_models : `int`
        The number of models measurements will be made for.

    Raises
    ------
    RuntimeError
        Raised if a band's number of flux or flux error columns differs from
        the number of flux columns in the first band.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # Bookkeeping columns present in every row
    columns = {"bin": int, "mag_min": float, "mag_max": float}

    # The number of models is taken from the first band
    n_models = 0

    bands = list(bands_columns.keys())
    n_bands = len(bands)

    # Count column stems: one per (catalog, match type) combination
    count_items = tuple(
        f'n_{itype}_{mtype.value}' for itype in ('ref', 'target') for mtype in MatchType
    )

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)
        if idx == 0:
            n_models = n_models_flux

        # Per-model quantities and the measurement suffixes that apply to each
        per_model = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        if (idx > 0) or (n_bands > 2):
            per_model.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))

        # TODO: Do equivalent validation earlier, in the config
        if not ((n_models_flux == n_models) and (n_models_err == n_models)):
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Totals would be redundant
            if sourcetype != SourceType.ALL:
                columns.update({_get_column_name(band, label, item): int for item in count_items})

            # Coordinate and distance statistics
            columns.update({
                _get_column_name(band, label, item, suffix, stat): float
                for item in (target.column_coord1, target.column_coord2, column_dist)
                for suffix in suffixes.values()
                for stat in stats.keys()
            })

            # Flux, magnitude and (where applicable) color statistics per model
            columns.update({
                _get_column_name(band, label, prefix_item, item, suffix, stat): float
                for item in config_flux.columns_target_flux
                for prefix_item, suffixes_col in per_model
                for suffix in suffixes_col.values()
                for stat in stats.keys()
            })

    return columns, n_models
class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
    """Load subsets of matched catalogs and output a merged catalog of matched sources.
    """
    ConfigClass = DiffMatchedTractCatalogConfig
    _DefaultName = "DiffMatchedTractCatalog"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load only the required columns of each input, call `run` and
        store its outputs.
        """
        inputs = butlerQC.get(inputRefs)
        skymap = inputs.pop("skymap")

        # 'match_candidate' is only loaded from the target match catalog if it exists
        columns_match_target = ['match_row']
        if 'match_candidate' in inputs['columns_match_target']:
            columns_match_target.append('match_candidate')

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
            catalog_match_ref=inputs['cat_match_ref'].get(
                parameters={'columns': ['match_candidate', 'match_row']},
            ),
            catalog_match_target=inputs['cat_match_target'].get(
                parameters={'columns': columns_match_target},
            ),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        catalog_match_ref: pd.DataFrame,
        catalog_match_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Load matched reference and target (measured) catalogs, measure summary statistics, and output
        a combined matched catalog with columns from both inputs.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to diff objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to diff reference objects/sources to.
        catalog_match_ref : `pandas.DataFrame`
            A catalog with match indices of target sources and selection flags
            for each reference source.
        catalog_match_target : `pandas.DataFrame`
            A catalog with selection flags for each target source.
        wcs : `lsst.afw.geom.SkyWcs`
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a ``cat_matched`` attribute containing the joined
            catalog (plus unmatched rows if ``config.include_unmatched``) and a
            ``diff_matched`` attribute containing the table of statistics,
            with one row per magnitude bin.
        """
        # Would be nice if this could refer directly to ConfigClass
        config: DiffMatchedTractCatalogConfig = self.config

        select_ref = catalog_match_ref['match_candidate'].values
        # Add additional selection criteria for target sources beyond those for matching
        # (not recommended, but can be done anyway)
        # NOTE(review): when 'match_candidate' is present, the in-place &= below may
        # modify the array backing catalog_match_target - confirm this is intended.
        select_target = (catalog_match_target['match_candidate'].values
                         if 'match_candidate' in catalog_match_target.columns
                         else np.ones(len(catalog_match_target), dtype=bool))
        for column in config.columns_target_select_true:
            select_target &= catalog_target[column].values
        for column in config.columns_target_select_false:
            select_target &= ~catalog_target[column].values

        ref, target = config.coord_format.format_catalogs(
            catalog_ref=catalog_ref, catalog_target=catalog_target,
            select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
            return_converted_columns=config.coord_format.coords_ref_to_convert is not None,
        )
        cat_ref = ref.catalog
        cat_target = target.catalog
        n_target = len(cat_target)

        if config.include_unmatched:
            for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
                cat_add['match_candidate'] = cat_match['match_candidate'].values

        # match_row holds, for each reference row, the matched target row (negative if unmatched)
        match_row = catalog_match_ref['match_row'].values
        matched_ref = match_row >= 0
        matched_row = match_row[matched_ref]
        matched_target = np.zeros(n_target, dtype=bool)
        matched_target[matched_row] = True

        # Add/compute distance columns
        coord1_target_err, coord2_target_err = config.columns_target_coord_err
        column_dist, column_dist_err = 'match_distance', 'match_distanceErr'
        dist = np.full(n_target, np.nan)

        dist[matched_row] = np.hypot(
            target.coord1[matched_row] - ref.coord1[matched_ref],
            target.coord2[matched_row] - ref.coord2[matched_ref],
        )
        dist_err = np.full(n_target, np.nan)
        dist_err[matched_row] = np.hypot(cat_target.iloc[matched_row][coord1_target_err].values,
                                         cat_target.iloc[matched_row][coord2_target_err].values)
        cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

        # Create a matched table, preserving the target catalog's named index (if it has one)
        cat_left = cat_target.iloc[matched_row]
        has_index_left = cat_left.index.name is not None
        cat_right = cat_ref[matched_ref].reset_index()
        cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
        cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1)

        if config.include_unmatched:
            # Create an unmatched table with the same schema as the matched one
            # ... but only for objects with no matches (for completeness/purity)
            # and that were selected for matching (or inclusion via config)
            cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False)
            cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
            match_row_target = catalog_match_target['match_row'].values
            cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index(
                drop=not has_index_left)
            # See https://github.com/pandas-dev/pandas/issues/46662
            # astropy masked columns would handle this much more gracefully
            # Unfortunately, that would require storageClass migration
            # So we use pandas "extended" nullable types for now
            for cat_i in (cat_left, cat_right):
                for colname in cat_i.columns:
                    column = cat_i[colname]
                    dtype = str(column.dtype)
                    if dtype == "bool":
                        cat_i[colname] = column.astype("boolean")
                    elif dtype.startswith("int"):
                        cat_i[colname] = column.astype(f"Int{dtype[3:]}")
                    elif dtype.startswith("uint"):
                        # Strip all four characters of "uint"; e.g. "uint64" -> "UInt64"
                        cat_i[colname] = column.astype(f"UInt{dtype[4:]}")
            cat_unmatched = pd.concat(objs=(cat_left, cat_right))

        # The unmatched table only exists if it was requested
        cats_convert = [cat_matched]
        if config.include_unmatched:
            cats_convert.append(cat_unmatched)

        # Rename configured AB magnitude columns and convert their values to nJy fluxes
        for columns_convert_base, prefix in (
            (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
            (config.columns_target_mag_to_nJy, ""),
        ):
            if columns_convert_base:
                columns_convert = {
                    f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
                } if prefix else columns_convert_base
                for cat_convert in cats_convert:
                    cat_convert.rename(columns=columns_convert, inplace=True)
                    for column_flux in columns_convert.values():
                        cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])

        # TODO: Deprecate all matched difference output in DM-43831 (per RFC-1008)

        # Slightly smelly hack for when a column (like distance) is already relative to truth
        column_dummy = 'dummy'
        cat_ref[column_dummy] = np.zeros_like(ref.coord1)

        # Add a boolean column for whether a match is classified correctly
        # TODO: remove the assumption of a boolean column
        extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)

        extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

        # Define difference/chi columns and statistics thereof
        suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
        # Skip diff for fluxes - covered by mags
        suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
        # Skip chi for magnitudes, which have strange errors
        suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
        stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

        for percentile in self.config.percentiles:
            stat = Percentile(percentile=float(Decimal(percentile)))
            stats[stat.name_short()] = stat

        # Get dict of column names
        columns, n_models = _get_columns(
            bands_columns=config.columns_flux,
            suffixes=suffixes,
            suffixes_flux=suffixes_flux,
            suffixes_mag=suffixes_mag,
            stats=stats,
            target=target,
            column_dist=column_dist,
        )

        # Setup numpy table
        n_bins = config.mag_num_bins
        data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
        data['bin'] = np.arange(n_bins)

        # Setup bins
        bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
                               num=n_bins + 1)
        data['mag_min'] = bins_mag[:-1]
        data['mag_max'] = bins_mag[1:]
        bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))

        # Define temporary columns for intermediate storage
        column_mag_temp = 'mag_temp'
        column_color_temp = 'color_temp'
        column_color_err_temp = 'colorErr_temp'
        flux_err_frac_prev = [None]*n_models
        mag_prev = [None]*n_models

        # Maps a statistic column stem to
        # (reference column, target column, target error column, skip_diff)
        columns_target = {
            target.column_coord1: (
                ref.column_coord1, target.column_coord1, coord1_target_err, False,
            ),
            target.column_coord2: (
                ref.column_coord2, target.column_coord2, coord2_target_err, False,
            ),
            column_dist: (column_dummy, column_dist, column_dist_err, False),
        }

        # Cheat a little and do the first band last so that the color is
        # based on the last band
        band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
        n_bands = len(band_fluxes)
        if n_bands > 0:
            band_fluxes.append(band_fluxes[0])
        flux_err_frac_first = None
        mag_first = None
        mag_ref_first = None

        band_prev = None
        for idx_band, (band, config_flux) in enumerate(band_fluxes):
            if idx_band == n_bands:
                # These were already computed earlier
                mag_ref = mag_ref_first
                flux_err_frac = flux_err_frac_first
                mag_model = mag_first
            else:
                mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
                flux_err_frac = [None]*n_models
                mag_model = [None]*n_models

            # The temporary magnitude column still holds the previous band's
            # magnitudes here, so this is the (previous - current) color
            if idx_band > 0:
                cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref

            cat_ref[column_mag_temp] = mag_ref

            select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
                               for mag_lo, mag_hi in bins_mag]

            # Iterate over multiple models, compute their mags and colours (if there's a previous band)
            for idx_model in range(n_models):
                column_target_flux = config_flux.columns_target_flux[idx_model]
                column_target_flux_err = config_flux.columns_target_flux_err[idx_model]

                flux_target = cat_target[column_target_flux]
                mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
                if config.mag_ceiling_target is not None:
                    mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
                mag_model[idx_model] = mag_target

                # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
                flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target

                # Stop if idx_band == 0: The rest will be picked up at idx_band == n_bands
                if idx_band > 0:
                    # Keep these mags tabulated for convenience
                    column_mag_temp_model = f'{column_mag_temp}{idx_model}'
                    cat_target[column_mag_temp_model] = mag_target

                    columns_target[f'flux_{column_target_flux}'] = (
                        config_flux.column_ref_flux,
                        column_target_flux,
                        column_target_flux_err,
                        True,
                    )
                    # Note: magnitude errors are generally problematic and not worth aggregating
                    columns_target[f'mag_{column_target_flux}'] = (
                        column_mag_temp, column_mag_temp_model, None, False,
                    )

                    # No need for colors if this is the last band and there are only two bands
                    # (because it would just be the negative of the first color)
                    skip_color = (idx_band == n_bands) and (n_bands <= 2)
                    if not skip_color:
                        column_color_temp_model = f'{column_color_temp}{idx_model}'
                        column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'

                        # e.g. if order is ugrizy, first color will be u - g
                        cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]

                        # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
                        cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
                            flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
                        columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
                            column_color_temp,
                            column_color_temp_model,
                            column_color_err_temp_model,
                            False,
                        )

                    for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
                        row = data[idx_bin]
                        # The reference selection depends only on the band and bin. It must
                        # be looked up for every model; otherwise models after the first
                        # would reuse the selection left over from the last bin.
                        select_ref_bin = select_ref_bins[idx_bin]
                        select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)

                        for sourcetype in SourceType:
                            sourcetype_info = sourcetype.value
                            is_extended = sourcetype_info.is_extended
                            # Counts filtered by match selection and magnitude bin
                            select_ref_sub = select_ref_bin.copy()
                            select_target_sub = select_target_bin.copy()
                            if is_extended is not None:
                                is_extended_ref = (extended_ref == is_extended)
                                select_ref_sub &= is_extended_ref
                                # Reference sources only need to be counted once
                                if idx_model == 0:
                                    n_ref_sub = np.count_nonzero(select_ref_sub)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.ALL.value)] = n_ref_sub
                                select_target_sub &= (extended_target == is_extended)
                                n_target_sub = np.count_nonzero(select_target_sub)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.ALL.value)] = n_target_sub

                            # Filter matches by magnitude bin and true class
                            match_row_bin = match_row.copy()
                            match_row_bin[~select_ref_sub] = -1
                            match_good = match_row_bin >= 0

                            n_match = np.count_nonzero(match_good)

                            # Same for counts of matched target sources (for e.g. purity)

                            if n_match > 0:
                                rows_matched = match_row_bin[match_good]
                                subset_target = cat_target.iloc[rows_matched]
                                if (is_extended is not None) and (idx_model == 0):
                                    right_type = extended_target[rows_matched] == is_extended
                                    n_total = len(right_type)
                                    n_right = np.count_nonzero(right_type)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.MATCH_RIGHT.value)] = n_right
                                    row[_get_column_name(
                                        band,
                                        sourcetype_info.label,
                                        'n_ref',
                                        MatchType.MATCH_WRONG.value,
                                    )] = n_total - n_right

                                # compute stats for this bin, for all columns
                                for column, (column_ref, column_target, column_err_target, skip_diff) \
                                        in columns_target.items():
                                    values_ref = cat_ref[column_ref][match_good].values
                                    errors_target = (
                                        subset_target[column_err_target].values
                                        if column_err_target is not None
                                        else None
                                    )
                                    compute_stats(
                                        values_ref,
                                        subset_target[column_target].values,
                                        errors_target,
                                        row,
                                        stats,
                                        suffixes,
                                        prefix=f'{band}_{sourcetype_info.label}_{column}',
                                        skip_diff=skip_diff,
                                    )

                            # Count matched target sources with *measured* mags within bin
                            # Used for e.g. purity calculation
                            # Should be merged with above code if there's ever a need for
                            # measuring stats on this source selection
                            select_target_sub &= matched_target

                            if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
                                n_total = np.count_nonzero(select_target_sub)
                                right_type = np.zeros(n_target, dtype=bool)
                                right_type[match_row[matched_ref & is_extended_ref]] = True
                                right_type &= select_target_sub
                                n_right = np.count_nonzero(right_type)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_RIGHT.value)] = n_right
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_WRONG.value)] = n_total - n_right

                    # delete the flux/color columns since they change with each band
                    for prefix in ('flux', 'mag'):
                        del columns_target[f'{prefix}_{column_target_flux}']
                    if not skip_color:
                        del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']

            # keep values needed for colors
            flux_err_frac_prev = flux_err_frac
            mag_prev = mag_model
            band_prev = band
            if idx_band == 0:
                flux_err_frac_first = flux_err_frac
                mag_first = mag_model
                mag_ref_first = mag_ref

        if config.include_unmatched:
            cat_matched = pd.concat((cat_matched, cat_unmatched))

        retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
        return retStruct