Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 26%
412 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:06 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 11:06 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
__all__ = [
    # Task and configuration classes
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    # Enumerations
    'MatchType',
    'MeasurementType',
    'SourceType',
    # Summary statistics
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]
28import lsst.afw.geom as afwGeom
29from lsst.meas.astrom.matcher_probabilistic import (
30 ComparableCatalog, ConvertCatalogCoordinatesConfig,
31)
32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35import lsst.pipe.base.connectionTypes as cT
36from lsst.skymap import BaseSkyMap
38from abc import ABCMeta, abstractmethod
39from astropy.stats import mad_std
40import astropy.units as u
41from dataclasses import dataclass
42from decimal import Decimal
43from enum import Enum
44import numpy as np
45import pandas as pd
46from scipy.stats import iqr
47from smatch.matcher import sphdist
48from typing import Sequence
def is_sequence_set(x: Sequence):
    """Return whether ``x`` has no duplicate elements.

    Intended as a pex_config ``listCheck``; elements must be hashable.
    """
    n_items = len(x)
    unique = set(x)
    return n_items == len(unique)
def is_percentile(x: str) -> bool:
    """Return whether ``x`` is a string representing a percentile in [0, 100].

    Used as a pex_config ``itemCheck``, so invalid input must return False
    (yielding a config validation error) rather than raise.

    Parameters
    ----------
    x : `str`
        The candidate percentile, e.g. ``"2.275"``.

    Returns
    -------
    valid : `bool`
        True if ``x`` parses as a finite decimal number between 0 and 100
        inclusive, False otherwise (including unparseable strings, None,
        NaN and infinities).
    """
    from decimal import InvalidOperation

    try:
        value = Decimal(x)
    except (InvalidOperation, TypeError, ValueError):
        return False
    # Ordering comparisons against a Decimal NaN signal InvalidOperation,
    # so reject non-finite values before comparing.
    return value.is_finite() and 0 <= value <= 100
# Default dataset type name templates for DiffMatchedTractCatalogConnections
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)
class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Butler connections for `DiffMatchedTractCatalogTask`.

    Catalog inputs and outputs are per-tract DataFrames whose dataset type
    names are built from the ``name_input_cat_ref`` and
    ``name_input_cat_target`` templates; the skymap is also read.
    """
    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference object catalog to match from",
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target object catalog to match",
    )
    skymap = cT.Input(
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
    )
    cat_match_ref = cT.Input(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference match catalog with indices of target matches",
    )
    cat_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target match catalog with indices of references matches",
    )
    # Only the column index is read, to check which columns the match catalog has
    columns_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
        doc="Target match catalog columns",
    )
    cat_matched = cT.Output(
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Catalog with reference and target columns for joined sources",
    )
    diff_matched = cT.Output(
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Table with aggregated counts, difference and chi statistics",
    )
class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column names for a single band of a matched catalog.

    Names one reference flux column and lists of target flux and flux error
    columns. The code that builds difference statistics requires the two
    target lists to have the same length.
    """
    column_ref_flux = pexConfig.Field(
        dtype=str,
        doc='Reference catalog flux column name',
    )
    columns_target_flux = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux column names",
    )
    columns_target_flux_err = pexConfig.ListField(
        dtype=str,
        listCheck=is_sequence_set,
        doc="List of target catalog flux error column names",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns needed for this band (ideally an ordered set)."""
        return [self.column_ref_flux]

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns needed for this band (ideally an ordered set).

        Flux columns come first, followed by each error column that is not
        already listed, in configuration order.
        """
        columns = list(self.columns_target_flux)
        for column_err in self.columns_target_flux_err:
            if column_err not in columns:
                columns.append(column_err)
        return columns
class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for `DiffMatchedTractCatalogTask`.

    This config specifies:

    - which reference and target columns to read and copy into the matched
      output;
    - which columns to convert from AB magnitudes to nJy fluxes;
    - how target sources are selected;
    - the magnitude binning and statistics used for the aggregated
      difference table.
    """
    column_matched_prefix_ref = pexConfig.Field[str](
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field[str](
        default='is_pointsource',
        doc='The boolean reference table column specifying if the target is extended',
    )
    column_ref_extended_inverted = pexConfig.Field[bool](
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field[str](
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )
    include_unmatched = pexConfig.Field[bool](
        default=False,
        doc="Whether to include unmatched rows in the matched table",
    )

    @property
    def columns_in_ref(self) -> list[str]:
        """Reference catalog columns to read, without duplicates.

        Order is by first occurrence:

        1. the coordinates;
        2. the extendedness column;
        3. `columns_ref_copy`;
        4. the reference flux column of each band.
        """
        columns_all = [
            self.coord_format.column_ref_coord1,
            self.coord_format.column_ref_coord2,
            self.column_ref_extended,
        ]
        columns_all.extend(self.columns_ref_copy)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_ref)
        return list(dict.fromkeys(columns_all))

    @property
    def columns_in_target(self) -> list[str]:
        """Target catalog columns to read, without duplicates.

        Order is by first occurrence:

        1. the coordinates;
        2. the extendedness column;
        3. any columns listed in ``coord_format.coords_ref_to_convert``;
        4. the coordinate error, selection and copy columns;
        5. each band's flux and flux error columns.
        """
        columns_all = [
            self.coord_format.column_target_coord1,
            self.coord_format.column_target_coord2,
            self.column_target_extended,
        ]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(self.coord_format.coords_ref_to_convert.values())
        for column_list in (
            self.columns_target_coord_err,
            self.columns_target_select_false,
            self.columns_target_select_true,
            self.columns_target_copy,
        ):
            columns_all.extend(column_list)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_target)
        # Deduplicate exactly like columns_in_ref, so that no column is requested twice
        return list(dict.fromkeys(columns_all))

    columns_flux = pexConfig.ConfigDictField(
        doc="Configs for flux columns for each band",
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        default={},
    )
    columns_ref_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Reference table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_ref_copy = pexConfig.ListField[str](
        doc='Reference table columns to copy into cat_matched',
        default=[],
        listCheck=is_sequence_set,
    )
    columns_target_coord_err = pexConfig.ListField[str](
        doc='Target table coordinate columns with standard errors (sigma)',
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
    )
    columns_target_copy = pexConfig.ListField[str](
        doc='Target table columns to copy into cat_matched',
        default=('patch',),
        listCheck=is_sequence_set,
    )
    columns_target_mag_to_nJy = pexConfig.DictField[str, str](
        doc='Target table AB mag columns to convert to nJy flux columns with new names',
        default={},
    )
    columns_target_select_true = pexConfig.ListField[str](
        doc='Target table columns to require to be True for selecting sources',
        default=('detect_isPrimary',),
        listCheck=is_sequence_set,
    )
    columns_target_select_false = pexConfig.ListField[str](
        doc='Target table columns to require to be False for selecting sources',
        default=('merge_peak_sky',),
        listCheck=is_sequence_set,
    )
    coord_format = pexConfig.ConfigField[ConvertCatalogCoordinatesConfig](
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field[float](
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field[int](
        doc='Number of magnitude bins',
        default=15,
    )
    mag_brightest_ref = pexConfig.Field[float](
        doc='Brightest magnitude cutoff for binning',
        default=15,
    )
    mag_ceiling_target = pexConfig.Field[float](
        doc='Ceiling (maximum/faint) magnitude for target sources',
        default=None,
        optional=True,
    )
    mag_faintest_ref = pexConfig.Field[float](
        doc='Faintest magnitude cutoff for binning',
        default=30,
    )
    mag_zeropoint_ref = pexConfig.Field[float](
        doc='Magnitude zeropoint for reference sources',
        default=31.4,
    )
    mag_zeropoint_target = pexConfig.Field[float](
        doc='Magnitude zeropoint for target sources',
        default=31.4,
    )
    percentiles = pexConfig.ListField[str](
        doc='Percentiles to compute for diff/chi values',
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )

    def validate(self):
        """Validate the config, including the magnitude-to-flux conversions.

        Raises
        ------
        ValueError
            Raised in either of these cases:

            - a key of ``columns_ref_mag_to_nJy`` or
              ``columns_target_mag_to_nJy`` is not a column read from the
              corresponding catalog;
            - a value of either mapping is also listed in
              ``columns_ref_copy`` or ``columns_target_copy``.

            All such problems are reported together, one per line.
        """
        super().validate()

        errors = []

        for name_columns_mag, name_columns_copy, columns_in in (
            ("columns_ref_mag_to_nJy", "columns_ref_copy", self.columns_in_ref),
            ("columns_target_mag_to_nJy", "columns_target_copy", self.columns_in_target),
        ):
            columns_mag = getattr(self, name_columns_mag)
            columns_copy = getattr(self, name_columns_copy)
            for column_old, column_new in columns_mag.items():
                if column_old not in columns_in:
                    errors.append(
                        f"{column_old=} key in self.{name_columns_mag} not found in {columns_in=};"
                        f" did you forget to add it to self.{name_columns_copy}={columns_copy}?"
                    )
                # The converted column replaces the original under the new name,
                # so it must not clash with a column that is copied unchanged
                if column_new in columns_copy:
                    errors.append(
                        f"{column_new=} value in self.{name_columns_mag} found in"
                        f" self.{name_columns_copy}={columns_copy}; this will cause a collision."
                        f" Please choose a different name."
                    )
        if errors:
            raise ValueError("\n".join(errors))
@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a kind of measurement comparison."""
    # Human-readable description of the measurement
    doc: str
    # Short identifier for the measurement
    name: str
class MeasurementType(Enum):
    """Ways of comparing a measured (target) value to a reference value."""
    DIFF = MeasurementTypeInfo(name="diff", doc="difference (measured - reference)")
    CHI = MeasurementTypeInfo(name="chi", doc="scaled difference (measured - reference)/error")
class Statistic(metaclass=ABCMeta):
    """Abstract interface for a summary statistic over a collection of values.

    Concrete statistics provide the following:

    - a description (`doc`);
    - a short identifier for table column names (`name_short`);
    - the computation itself (`value`).
    """

    @abstractmethod
    def doc(self) -> str:
        """Return a human-readable description of the statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short name for the statistic, e.g. for a table column."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """Compute the statistic.

        Parameters
        ----------
        values : `Collection` [`float`]
            The values to summarize.

        Returns
        -------
        statistic : `float`
            The computed statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')
class Median(Statistic):
    """The median (50th percentile) of a set of values."""

    def value(self, values):
        """Return ``numpy.median`` of the values."""
        return np.median(values)

    @classmethod
    def name_short(cls) -> str:
        return "median"

    @classmethod
    def doc(cls) -> str:
        return "Median"
class SigmaIQR(Statistic):
    """The interquartile range, rescaled to a normal-equivalent sigma."""

    def value(self, values):
        """Return the IQR scaled by scipy's ``scale='normal'`` factor."""
        return iqr(values, scale='normal')

    @classmethod
    def name_short(cls) -> str:
        return "sig_iqr"

    @classmethod
    def doc(cls) -> str:
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"
class SigmaMAD(Statistic):
    """The median absolute deviation, rescaled to a normal-equivalent sigma."""

    def value(self, values):
        """Return astropy's ``mad_std`` of the values."""
        return mad_std(values)

    @classmethod
    def name_short(cls) -> str:
        return "sig_mad"

    @classmethod
    def doc(cls) -> str:
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"
@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile of a set of values.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised if ``percentile`` is not between 0 and 100 inclusive
        (including NaN).
    """
    percentile: float

    def doc(self) -> str:
        # Previously returned the SigmaMAD description (copy-paste error)
        return f"Percentile (p={self.percentile})"

    def name_short(self) -> str:
        # Encode p/100 to five decimals without the leading "0.", e.g.
        # 2.275 -> "pctl02275" and 84.134 -> "pctl84134". If the value rounds
        # to 1.00000, keep the integer digit so that p=100 ("pctl100000")
        # does not collide with p=0 ("pctl00000"). abs() maps -0.0, which
        # passes the range check, onto 0.0.
        whole, fraction = f"{abs(self.percentile) / 100:.5f}".split(".")
        return f"pctl{fraction}" if whole == "0" else f"pctl{whole}{fraction}"

    def value(self, values):
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')
432def _get_stat_name(*args):
433 return '_'.join(args)
436def _get_column_name(band, *args):
437 return f"{band}_{_get_stat_name(*args)}"
def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics of target - reference differences and store them in ``row``.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,) or None
        Errors (standard deviations) on `values_target`. If None, no chi
        statistics are computed.
    row : `numpy.ndarray`, (1, C)
        A numpy array with pre-assigned column names.
    stats : `Dict` [`str`, `Statistic`]
        The statistics to compute, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        Column suffixes keyed by measurement type; only these types are stored.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip statistics on differences. Differences are still
        computed as the numerator of chi.

    Returns
    -------
    row_with_stats : `numpy.ndarray`, (1, C)
        The same ``row`` object, with statistic values assigned to columns
        named ``{prefix}_{suffix}_{stat_name}``.

    Raises
    ------
    ValueError
        Raised if the input lengths differ (checked only for non-empty
        ``values_ref``).
    """
    n_ref = len(values_ref)
    if n_ref == 0:
        return row

    n_target = len(values_target)
    n_target_err = n_ref if errors_target is None else len(errors_target)
    if n_target != n_ref or n_target_err != n_ref:
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')

    has_errors = errors_target is not None
    diff = values_target - values_ref
    chi = (diff / errors_target) if has_errors else diff
    # Non-finite values/errors are not usable (this could be made configurable).
    # The mask comes from chi, so diff statistics also exclude bad errors.
    finite = np.isfinite(chi)

    values_by_type = {}
    if not skip_diff:
        values_by_type[MeasurementType.DIFF] = diff[finite]
    if has_errors:
        values_by_type[MeasurementType.CHI] = chi[finite]

    for measurement_type, suffix in suffixes.items():
        subset = values_by_type.get(measurement_type)
        if subset is None or len(subset) == 0:
            continue
        for stat_name, stat in stats.items():
            row['_'.join((prefix, suffix, stat_name))] = stat.value(subset)
    return row
493@dataclass(frozen=True)
494class SourceTypeInfo:
495 is_extended: bool | None
496 label: str
class SourceType(Enum):
    """Source classes for which counts and statistics are tabulated."""
    ALL = SourceTypeInfo(label='all', is_extended=None)
    RESOLVED = SourceTypeInfo(label='resolved', is_extended=True)
    UNRESOLVED = SourceTypeInfo(label='unresolved', is_extended=False)
class MatchType(Enum):
    """Categories of matches; each value is used in count column names."""
    # Every match, regardless of classification
    ALL = 'all'
    # Presumably matches whose target class agrees with the reference (verify against run)
    MATCH_RIGHT = 'match_right'
    # Presumably matches whose target class disagrees with the reference (verify against run)
    MATCH_WRONG = 'match_wrong'
def _get_columns(bands_columns: dict, suffixes: dict, suffixes_flux: dict, suffixes_mag: dict,
                 stats: dict, target: ComparableCatalog, column_dist: str):
    """Build the column names and types of the difference statistics table.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Flux column configuration keyed by band.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Suffixes for each `MeasurementType`, for general columns (e.g.
        coordinates and colors), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Statistics keyed by their column suffix.
    target : `ComparableCatalog`
        A target catalog providing the coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Column types keyed by name, in table order.
    n_models : `int`
        The number of flux models (target flux columns per band).

    Raises
    ------
    RuntimeError
        Raised if any band's number of flux or flux error columns differs from
        the first band's number of flux columns.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    columns = {"bin": int, "mag_min": float, "mag_max": float}

    bands = list(bands_columns)
    n_bands = len(bands)
    n_models = 0

    for idx_band, (band, config_flux) in enumerate(bands_columns.items()):
        # Per-model column groups: (name prefix, measurement type suffixes)
        groups_model = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        if idx_band == 0:
            # The first band defines how many models every band must have
            n_models = len(config_flux.columns_target_flux)
        # Colors pair each band with the previous one; the first band is paired
        # with the last (index -1) only when there are more than two bands
        if idx_band > 0 or n_bands > 2:
            groups_model.append((f'color_{bands[idx_band - 1]}_m_{band}', suffixes))

        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)
        # TODO: Do equivalent validation earlier, in the config
        if n_models_flux != n_models or n_models_err != n_models:
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Counts for ALL would be redundant with the per-class counts
            if sourcetype != SourceType.ALL:
                for itype in ('ref', 'target'):
                    for mtype in MatchType:
                        columns[_get_column_name(band, label, f'n_{itype}_{mtype.value}')] = int

            for item in (target.column_coord1, target.column_coord2, column_dist):
                for suffix in suffixes.values():
                    for stat in stats:
                        columns[_get_column_name(band, label, item, suffix, stat)] = float

            for column_flux in config_flux.columns_target_flux:
                for prefix_item, suffixes_group in groups_model:
                    for suffix in suffixes_group.values():
                        for stat in stats:
                            name = _get_column_name(band, label, prefix_item, column_flux, suffix, stat)
                            columns[name] = float

    return columns, n_models
592class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
593 """Load subsets of matched catalogs and output a merged catalog of matched sources.
594 """
595 ConfigClass = DiffMatchedTractCatalogConfig
596 _DefaultName = "DiffMatchedTractCatalog"
598 def runQuantum(self, butlerQC, inputRefs, outputRefs):
599 inputs = butlerQC.get(inputRefs)
600 skymap = inputs.pop("skymap")
602 columns_match_target = ['match_row']
603 if 'match_candidate' in inputs['columns_match_target']:
604 columns_match_target.append('match_candidate')
606 outputs = self.run(
607 catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
608 catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
609 catalog_match_ref=inputs['cat_match_ref'].get(
610 parameters={'columns': ['match_candidate', 'match_row']},
611 ),
612 catalog_match_target=inputs['cat_match_target'].get(
613 parameters={'columns': columns_match_target},
614 ),
615 wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
616 )
617 butlerQC.put(outputs, outputRefs)
619 def run(
620 self,
621 catalog_ref: pd.DataFrame,
622 catalog_target: pd.DataFrame,
623 catalog_match_ref: pd.DataFrame,
624 catalog_match_target: pd.DataFrame,
625 wcs: afwGeom.SkyWcs = None,
626 ) -> pipeBase.Struct:
627 """Load matched reference and target (measured) catalogs, measure summary statistics, and output
628 a combined matched catalog with columns from both inputs.
630 Parameters
631 ----------
632 catalog_ref : `pandas.DataFrame`
633 A reference catalog to diff objects/sources from.
634 catalog_target : `pandas.DataFrame`
635 A target catalog to diff reference objects/sources to.
636 catalog_match_ref : `pandas.DataFrame`
637 A catalog with match indices of target sources and selection flags
638 for each reference source.
639 catalog_match_target : `pandas.DataFrame`
640 A catalog with selection flags for each target source.
641 wcs : `lsst.afw.image.SkyWcs`
642 A coordinate system to convert catalog positions to sky coordinates,
643 if necessary.
645 Returns
646 -------
647 retStruct : `lsst.pipe.base.Struct`
648 A struct with output_ref and output_target attribute containing the
649 output matched catalogs.
650 """
651 # Would be nice if this could refer directly to ConfigClass
652 config: DiffMatchedTractCatalogConfig = self.config
654 select_ref = catalog_match_ref['match_candidate'].values
655 # Add additional selection criteria for target sources beyond those for matching
656 # (not recommended, but can be done anyway)
657 select_target = (catalog_match_target['match_candidate'].values
658 if 'match_candidate' in catalog_match_target.columns
659 else np.ones(len(catalog_match_target), dtype=bool))
660 for column in config.columns_target_select_true:
661 select_target &= catalog_target[column].values
662 for column in config.columns_target_select_false:
663 select_target &= ~catalog_target[column].values
665 ref, target = config.coord_format.format_catalogs(
666 catalog_ref=catalog_ref, catalog_target=catalog_target,
667 select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
668 )
669 cat_ref = ref.catalog
670 cat_target = target.catalog
671 n_target = len(cat_target)
673 if config.include_unmatched:
674 for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
675 cat_add['match_candidate'] = cat_match['match_candidate'].values
677 match_row = catalog_match_ref['match_row'].values
678 matched_ref = match_row >= 0
679 matched_row = match_row[matched_ref]
680 matched_target = np.zeros(n_target, dtype=bool)
681 matched_target[matched_row] = True
683 # Add/compute distance columns
684 coord1_target_err, coord2_target_err = config.columns_target_coord_err
685 column_dist, column_dist_err = 'match_distance', 'match_distanceErr'
686 dist = np.full(n_target, np.nan)
688 target_match_c1, target_match_c2 = (coord[matched_row] for coord in (target.coord1, target.coord2))
689 target_ref_c1, target_ref_c2 = (coord[matched_ref] for coord in (ref.coord1, ref.coord2))
691 dist_err = np.full(n_target, np.nan)
692 dist[matched_row] = sphdist(
693 target_match_c1, target_match_c2, target_ref_c1, target_ref_c2
694 ) if config.coord_format.coords_spherical else np.hypot(
695 target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2,
696 )
697 # Should probably explicitly add cosine terms if ref has errors too
698 dist_err[matched_row] = sphdist(
699 target_match_c1, target_match_c2,
700 target_match_c1 + cat_target.iloc[matched_row][coord1_target_err].values,
701 target_match_c2 + cat_target.iloc[matched_row][coord2_target_err].values,
702 ) if config.coord_format.coords_spherical else np.hypot(
703 cat_target.iloc[matched_row][coord1_target_err].values,
704 cat_target.iloc[matched_row][coord2_target_err].values
705 )
706 cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err
708 # Create a matched table, preserving the target catalog's named index (if it has one)
709 cat_left = cat_target.iloc[matched_row]
710 has_index_left = cat_left.index.name is not None
711 cat_right = cat_ref[matched_ref].reset_index()
712 cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
713 cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1)
715 if config.include_unmatched:
716 # Create an unmatched table with the same schema as the matched one
717 # ... but only for objects with no matches (for completeness/purity)
718 # and that were selected for matching (or inclusion via config)
719 cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False)
720 cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns)
721 match_row_target = catalog_match_target['match_row'].values
722 cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index(
723 drop=not has_index_left)
724 # See https://github.com/pandas-dev/pandas/issues/46662
725 # astropy masked columns would handle this much more gracefully
726 # Unfortunately, that would require storageClass migration
727 # So we use pandas "extended" nullable types for now
728 for cat_i in (cat_left, cat_right):
729 for colname in cat_i.columns:
730 column = cat_i[colname]
731 dtype = str(column.dtype)
732 if dtype == "bool":
733 cat_i[colname] = column.astype("boolean")
734 elif dtype.startswith("int"):
735 cat_i[colname] = column.astype(f"Int{dtype[3:]}")
736 elif dtype.startswith("uint"):
737 cat_i[colname] = column.astype(f"UInt{dtype[3:]}")
738 cat_unmatched = pd.concat(objs=(cat_left, cat_right))
740 for columns_convert_base, prefix in (
741 (config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
742 (config.columns_target_mag_to_nJy, ""),
743 ):
744 if columns_convert_base:
745 columns_convert = {
746 f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
747 } if prefix else columns_convert_base
748 for cat_convert in (cat_matched, cat_unmatched):
749 cat_convert.rename(columns=columns_convert, inplace=True)
750 for column_flux in columns_convert.values():
751 cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])
753 # TODO: Deprecate all matched difference output in DM-43831 (per RFC-1008)
755 # Slightly smelly hack for when a column (like distance) is already relative to truth
756 column_dummy = 'dummy'
757 cat_ref[column_dummy] = np.zeros_like(ref.coord1)
759 # Add a boolean column for whether a match is classified correctly
760 # TODO: remove the assumption of a boolean column
761 extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)
763 extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut
765 # Define difference/chi columns and statistics thereof
766 suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
767 # Skip diff for fluxes - covered by mags
768 suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
769 # Skip chi for magnitudes, which have strange errors
770 suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
771 stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}
773 for percentile in self.config.percentiles:
774 stat = Percentile(percentile=float(Decimal(percentile)))
775 stats[stat.name_short()] = stat
777 # Get dict of column names
778 columns, n_models = _get_columns(
779 bands_columns=config.columns_flux,
780 suffixes=suffixes,
781 suffixes_flux=suffixes_flux,
782 suffixes_mag=suffixes_mag,
783 stats=stats,
784 target=target,
785 column_dist=column_dist,
786 )
788 # Setup numpy table
789 n_bins = config.mag_num_bins
790 data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
791 data['bin'] = np.arange(n_bins)
793 # Setup bins
794 bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
795 num=n_bins + 1)
796 data['mag_min'] = bins_mag[:-1]
797 data['mag_max'] = bins_mag[1:]
798 bins_mag = tuple((bins_mag[idx], bins_mag[idx + 1]) for idx in range(n_bins))
800 # Define temporary columns for intermediate storage
801 column_mag_temp = 'mag_temp'
802 column_color_temp = 'color_temp'
803 column_color_err_temp = 'colorErr_temp'
804 flux_err_frac_prev = [None]*n_models
805 mag_prev = [None]*n_models
807 columns_target = {
808 target.column_coord1: (
809 ref.column_coord1, target.column_coord1, coord1_target_err, False,
810 ),
811 target.column_coord2: (
812 ref.column_coord2, target.column_coord2, coord2_target_err, False,
813 ),
814 column_dist: (column_dummy, column_dist, column_dist_err, False),
815 }
817 # Cheat a little and do the first band last so that the color is
818 # based on the last band
819 band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
820 n_bands = len(band_fluxes)
821 if n_bands > 0:
822 band_fluxes.append(band_fluxes[0])
823 flux_err_frac_first = None
824 mag_first = None
825 mag_ref_first = None
827 band_prev = None
828 for idx_band, (band, config_flux) in enumerate(band_fluxes):
829 if idx_band == n_bands:
830 # These were already computed earlier
831 mag_ref = mag_ref_first
832 flux_err_frac = flux_err_frac_first
833 mag_model = mag_first
834 else:
835 mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
836 flux_err_frac = [None]*n_models
837 mag_model = [None]*n_models
839 if idx_band > 0:
840 cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref
842 cat_ref[column_mag_temp] = mag_ref
844 select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
845 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag)]
847 # Iterate over multiple models, compute their mags and colours (if there's a previous band)
848 for idx_model in range(n_models):
849 column_target_flux = config_flux.columns_target_flux[idx_model]
850 column_target_flux_err = config_flux.columns_target_flux_err[idx_model]
852 flux_target = cat_target[column_target_flux]
853 mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
854 if config.mag_ceiling_target is not None:
855 mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
856 mag_model[idx_model] = mag_target
858 # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
859 flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target
861 # Stop if idx == 0: The rest will be picked up at idx == n_bins
862 if idx_band > 0:
863 # Keep these mags tabulated for convenience
864 column_mag_temp_model = f'{column_mag_temp}{idx_model}'
865 cat_target[column_mag_temp_model] = mag_target
867 columns_target[f'flux_{column_target_flux}'] = (
868 config_flux.column_ref_flux,
869 column_target_flux,
870 column_target_flux_err,
871 True,
872 )
873 # Note: magnitude errors are generally problematic and not worth aggregating
874 columns_target[f'mag_{column_target_flux}'] = (
875 column_mag_temp, column_mag_temp_model, None, False,
876 )
878 # No need for colors if this is the last band and there are only two bands
879 # (because it would just be the negative of the first color)
880 skip_color = (idx_band == n_bands) and (n_bands <= 2)
881 if not skip_color:
882 column_color_temp_model = f'{column_color_temp}{idx_model}'
883 column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'
885 # e.g. if order is ugrizy, first color will be u - g
886 cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]
888 # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
889 cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
890 flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
891 columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
892 column_color_temp,
893 column_color_temp_model,
894 column_color_err_temp_model,
895 False,
896 )
898 for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
899 row = data[idx_bin]
900 # Reference sources only need to be counted once
901 if idx_model == 0:
902 select_ref_bin = select_ref_bins[idx_bin]
903 select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)
905 for sourcetype in SourceType:
906 sourcetype_info = sourcetype.value
907 is_extended = sourcetype_info.is_extended
908 # Counts filtered by match selection and magnitude bin
909 select_ref_sub = select_ref_bin.copy()
910 select_target_sub = select_target_bin.copy()
911 if is_extended is not None:
912 is_extended_ref = (extended_ref == is_extended)
913 select_ref_sub &= is_extended_ref
914 if idx_model == 0:
915 n_ref_sub = np.count_nonzero(select_ref_sub)
916 row[_get_column_name(band, sourcetype_info.label, 'n_ref',
917 MatchType.ALL.value)] = n_ref_sub
918 select_target_sub &= (extended_target == is_extended)
919 n_target_sub = np.count_nonzero(select_target_sub)
920 row[_get_column_name(band, sourcetype_info.label, 'n_target',
921 MatchType.ALL.value)] = n_target_sub
923 # Filter matches by magnitude bin and true class
924 match_row_bin = match_row.copy()
925 match_row_bin[~select_ref_sub] = -1
926 match_good = match_row_bin >= 0
928 n_match = np.count_nonzero(match_good)
930 # Same for counts of matched target sources (for e.g. purity)
932 if n_match > 0:
933 rows_matched = match_row_bin[match_good]
934 subset_target = cat_target.iloc[rows_matched]
935 if (is_extended is not None) and (idx_model == 0):
936 right_type = extended_target[rows_matched] == is_extended
937 n_total = len(right_type)
938 n_right = np.count_nonzero(right_type)
939 row[_get_column_name(band, sourcetype_info.label, 'n_ref',
940 MatchType.MATCH_RIGHT.value)] = n_right
941 row[_get_column_name(
942 band,
943 sourcetype_info.label,
944 'n_ref',
945 MatchType.MATCH_WRONG.value,
946 )] = n_total - n_right
948 # compute stats for this bin, for all columns
949 for column, (column_ref, column_target, column_err_target, skip_diff) \
950 in columns_target.items():
951 values_ref = cat_ref[column_ref][match_good].values
952 errors_target = (
953 subset_target[column_err_target].values
954 if column_err_target is not None
955 else None
956 )
957 compute_stats(
958 values_ref,
959 subset_target[column_target].values,
960 errors_target,
961 row,
962 stats,
963 suffixes,
964 prefix=f'{band}_{sourcetype_info.label}_{column}',
965 skip_diff=skip_diff,
966 )
968 # Count matched target sources with *measured* mags within bin
969 # Used for e.g. purity calculation
970 # Should be merged with above code if there's ever a need for
971 # measuring stats on this source selection
972 select_target_sub &= matched_target
974 if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
975 n_total = np.count_nonzero(select_target_sub)
976 right_type = np.zeros(n_target, dtype=bool)
977 right_type[match_row[matched_ref & is_extended_ref]] = True
978 right_type &= select_target_sub
979 n_right = np.count_nonzero(right_type)
980 row[_get_column_name(band, sourcetype_info.label, 'n_target',
981 MatchType.MATCH_RIGHT.value)] = n_right
982 row[_get_column_name(band, sourcetype_info.label, 'n_target',
983 MatchType.MATCH_WRONG.value)] = n_total - n_right
985 # delete the flux/color columns since they change with each band
986 for prefix in ('flux', 'mag'):
987 del columns_target[f'{prefix}_{column_target_flux}']
988 if not skip_color:
989 del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']
991 # keep values needed for colors
992 flux_err_frac_prev = flux_err_frac
993 mag_prev = mag_model
994 band_prev = band
995 if idx_band == 0:
996 flux_err_frac_first = flux_err_frac
997 mag_first = mag_model
998 mag_ref_first = mag_ref
1000 if config.include_unmatched:
1001 cat_matched = pd.concat((cat_matched, cat_unmatched))
1003 retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
1004 return retStruct