Coverage for python/lsst/pipe/tasks/diff_matched_tract_catalog.py: 26%
366 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-23 02:25 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-23 02:25 -0800
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
__all__ = [
    # Configuration and task
    'DiffMatchedTractCatalogConfig',
    'DiffMatchedTractCatalogTask',
    'MatchedCatalogFluxesConfig',
    # Enumerations of match, measurement and source categories
    'MatchType',
    'MeasurementType',
    'SourceType',
    # Summary statistics
    'Statistic',
    'Median',
    'SigmaIQR',
    'SigmaMAD',
    'Percentile',
]
28import lsst.afw.geom as afwGeom
29from lsst.meas.astrom.matcher_probabilistic import (
30 ComparableCatalog, ConvertCatalogCoordinatesConfig,
31)
32from lsst.meas.astrom.match_probabilistic_task import radec_to_xy
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35import lsst.pipe.base.connectionTypes as cT
36from lsst.skymap import BaseSkyMap
38from abc import ABCMeta, abstractmethod
39from astropy.stats import mad_std
40from dataclasses import dataclass
41from decimal import Decimal
42from enum import Enum
43import numpy as np
44import pandas as pd
45from scipy.stats import iqr
46from typing import Dict, Sequence, Set
def is_sequence_set(x: Sequence):
    """Return whether every element of ``x`` is unique.

    Used as a ``listCheck`` for config list fields that must not contain
    duplicate entries.
    """
    unique = set(x)
    return len(unique) == len(x)
def is_percentile(x: str):
    """Return whether ``x`` parses as a number within the closed range [0, 100].

    Used as an ``itemCheck`` for config percentile strings. Values that cannot
    be parsed as a `decimal.Decimal` (or that parse to NaN, whose ordering
    comparison raises) are reported as failing the check by returning `False`,
    rather than by raising a `decimal` exception out of the validator.

    Parameters
    ----------
    x : `str`
        The candidate percentile, e.g. ``'2.275'``.

    Returns
    -------
    valid : `bool`
        `True` if ``0 <= x <= 100``, `False` otherwise.
    """
    from decimal import InvalidOperation

    try:
        return 0 <= Decimal(x) <= 100
    except (InvalidOperation, TypeError, ValueError):
        return False
# Default dataset-name templates for `DiffMatchedTractCatalogConnections`.
DiffMatchedTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
    name_skymap=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
)
class DiffMatchedTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=DiffMatchedTractCatalogBaseTemplates,
):
    """Butler connections for `DiffMatchedTractCatalogTask`.

    Inputs are the per-tract reference and target catalogs, their match
    catalogs (plus the target match catalog's column index) and the skymap.
    Outputs are the merged matched catalog and the binned statistics table.
    """
    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference object catalog to match from",
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target object catalog to match",
    )
    skymap = cT.Input(
        name="{name_skymap}",
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
    )
    cat_match_ref = cT.Input(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Reference match catalog with indices of target matches",
    )
    cat_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target match catalog with indices of references matches",
    )
    # Only the column index is read, to check which columns the target match
    # catalog provides before loading it.
    columns_match_target = cT.Input(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
        storageClass="DataFrameIndex",
        dimensions=("tract", "skymap"),
        doc="Target match catalog columns",
    )
    cat_matched = cT.Output(
        name="matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Catalog with reference and target columns for matched sources only",
    )
    diff_matched = cT.Output(
        name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Table with aggregated counts, difference and chi statistics",
    )
class MatchedCatalogFluxesConfig(pexConfig.Config):
    """Flux column configuration for a single band.

    Each target flux column is paired by position with the flux error column
    at the same index of ``columns_target_flux_err``.
    """
    column_ref_flux = pexConfig.Field(
        doc='Reference catalog flux column name',
        dtype=str,
    )
    columns_target_flux = pexConfig.ListField(
        doc="List of target catalog flux column names",
        dtype=str,
        listCheck=is_sequence_set,
    )
    columns_target_flux_err = pexConfig.ListField(
        doc="List of target catalog flux error column names",
        dtype=str,
        listCheck=is_sequence_set,
    )

    @property
    def columns_in_ref(self) -> Set[str]:
        """Reference catalog columns needed by this band's configuration."""
        return set((self.column_ref_flux,))

    @property
    def columns_in_target(self) -> Set[str]:
        """Target catalog flux and flux error columns needed by this band's configuration."""
        return set(self.columns_target_flux) | set(self.columns_target_flux_err)
class DiffMatchedTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DiffMatchedTractCatalogConnections,
):
    """Configuration for `DiffMatchedTractCatalogTask`."""
    column_matched_prefix_ref = pexConfig.Field(
        dtype=str,
        default='refcat_',
        doc='The prefix for matched columns copied from the reference catalog',
    )
    column_ref_extended = pexConfig.Field(
        dtype=str,
        default='is_pointsource',
        doc='The boolean reference table column specifying if the reference object is extended'
            ' (or compact, if column_ref_extended_inverted is True)',
    )
    column_ref_extended_inverted = pexConfig.Field(
        dtype=bool,
        default=True,
        doc='Whether column_ref_extended specifies if the object is compact, not extended',
    )
    column_target_extended = pexConfig.Field(
        dtype=str,
        default='refExtendedness',
        doc='The target table column estimating the extendedness of the object (0 <= x <= 1)',
    )

    @property
    def columns_in_ref(self) -> Set[str]:
        """All reference catalog columns that the task needs to read."""
        columns_all = [self.coord_format.column_ref_coord1, self.coord_format.column_ref_coord2,
                       self.column_ref_extended]
        columns_all.extend(self.columns_ref_copy)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_ref)
        return set(columns_all)

    @property
    def columns_in_target(self) -> Set[str]:
        """All target catalog columns that the task needs to read."""
        columns_all = [self.coord_format.column_target_coord1, self.coord_format.column_target_coord2,
                       self.column_target_extended]
        if self.coord_format.coords_ref_to_convert is not None:
            columns_all.extend(self.coord_format.coords_ref_to_convert.values())
        for column_list in (
            self.columns_target_coord_err,
            self.columns_target_select_false,
            self.columns_target_select_true,
            self.columns_target_copy,
        ):
            columns_all.extend(column_list)
        for config_flux in self.columns_flux.values():
            columns_all.extend(config_flux.columns_in_target)
        return set(columns_all)

    columns_flux = pexConfig.ConfigDictField(
        keytype=str,
        itemtype=MatchedCatalogFluxesConfig,
        doc="Configs for flux columns for each band",
    )
    # An (immutable, ordered) empty tuple, consistent with the other ListField defaults.
    columns_ref_copy = pexConfig.ListField(
        dtype=str,
        default=(),
        doc='Reference table columns to copy into cat_matched',
    )
    columns_target_coord_err = pexConfig.ListField(
        dtype=str,
        listCheck=lambda x: (len(x) == 2) and (x[0] != x[1]),
        doc='Target table coordinate columns with standard errors (sigma)',
    )
    columns_target_copy = pexConfig.ListField(
        dtype=str,
        default=('patch',),
        doc='Target table columns to copy into cat_matched',
    )
    columns_target_select_true = pexConfig.ListField(
        dtype=str,
        default=('detect_isPrimary',),
        doc='Target table columns to require to be True for selecting sources',
    )
    columns_target_select_false = pexConfig.ListField(
        dtype=str,
        default=('merge_peak_sky',),
        doc='Target table columns to require to be False for selecting sources',
    )
    coord_format = pexConfig.ConfigField(
        dtype=ConvertCatalogCoordinatesConfig,
        doc="Configuration for coordinate conversion",
    )
    extendedness_cut = pexConfig.Field(
        dtype=float,
        default=0.5,
        doc='Minimum extendedness for a measured source to be considered extended',
    )
    mag_num_bins = pexConfig.Field(
        doc='Number of magnitude bins',
        default=15,
        dtype=int,
    )
    mag_brightest_ref = pexConfig.Field(
        dtype=float,
        default=15,
        doc='Brightest magnitude cutoff for binning',
    )
    mag_ceiling_target = pexConfig.Field(
        dtype=float,
        default=None,
        optional=True,
        doc='Ceiling (maximum/faint) magnitude for target sources',
    )
    mag_faintest_ref = pexConfig.Field(
        dtype=float,
        default=30,
        doc='Faintest magnitude cutoff for binning',
    )
    mag_zeropoint_ref = pexConfig.Field(
        dtype=float,
        default=31.4,
        doc='Magnitude zeropoint for reference sources',
    )
    mag_zeropoint_target = pexConfig.Field(
        dtype=float,
        default=31.4,
        doc='Magnitude zeropoint for target sources',
    )
    percentiles = pexConfig.ListField(
        dtype=str,
        # -2, -1, +1, +2 sigma percentiles for normal distribution
        default=('2.275', '15.866', '84.134', '97.725'),
        doc='Percentiles to compute for diff/chi values',
        itemCheck=is_percentile,
        listCheck=is_sequence_set,
    )
@dataclass(frozen=True)
class MeasurementTypeInfo:
    """Immutable description of a kind of measurement comparison."""
    # Human-readable description of the comparison
    doc: str
    # Short identifier, e.g. used as a column-name suffix
    name: str
class MeasurementType(Enum):
    """Kinds of measured-versus-reference comparisons."""
    DIFF = MeasurementTypeInfo(
        name="diff",
        doc="difference (measured - reference)",
    )
    CHI = MeasurementTypeInfo(
        name="chi",
        doc="scaled difference (measured - reference)/error",
    )
class Statistic(metaclass=ABCMeta):
    """Abstract interface for a statistic computed over a set of values.

    Subclasses must provide a description (`doc`), a short name suitable
    for table columns (`name_short`) and the computation itself (`value`).
    """

    @abstractmethod
    def doc(self) -> str:
        """Return a human-readable description of the statistic."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def name_short(self) -> str:
        """Return a short name for the statistic, e.g. for a table column name."""
        raise NotImplementedError('Subclasses must implement this method')

    @abstractmethod
    def value(self, values):
        """Compute the statistic for a set of input values.

        Parameters
        ----------
        values : `Collection` [`float`]
            A set of values to compute the statistic for.

        Returns
        -------
        statistic : `float`
            The value of the statistic.
        """
        raise NotImplementedError('Subclasses must implement this method')
class Median(Statistic):
    """The median of a set of values."""

    @classmethod
    def doc(cls) -> str:
        return "Median"

    @classmethod
    def name_short(cls) -> str:
        return "median"

    def value(self, values):
        """Return the median of ``values`` computed with `numpy.median`."""
        median = np.median(values)
        return median
class SigmaIQR(Statistic):
    """The re-scaled interquartile range (sigma equivalent)."""

    @classmethod
    def doc(cls) -> str:
        return "Interquartile range divided by ~1.349 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        return "sig_iqr"

    def value(self, values):
        """Return the IQR of ``values`` scaled to a normal-distribution sigma."""
        sigma = iqr(values, scale='normal')
        return sigma
class SigmaMAD(Statistic):
    """The re-scaled median absolute deviation (sigma equivalent)."""

    @classmethod
    def doc(cls) -> str:
        return "Median absolute deviation multiplied by ~1.4826 (sigma-equivalent)"

    @classmethod
    def name_short(cls) -> str:
        return "sig_mad"

    def value(self, values):
        """Return the sigma-equivalent MAD of ``values`` (astropy ``mad_std``)."""
        sigma = mad_std(values)
        return sigma
@dataclass(frozen=True)
class Percentile(Statistic):
    """An arbitrary percentile.

    Parameters
    ----------
    percentile : `float`
        A valid percentile (0 <= p <= 100).

    Raises
    ------
    ValueError
        Raised if ``percentile`` is not within [0, 100] (including NaN).
    """
    percentile: float

    def doc(self) -> str:
        # Describe this percentile; the original text was copied from SigmaMAD.
        return f"Percentile (p={self.percentile})"

    def name_short(self) -> str:
        """Return e.g. 'pctl02275' for p=2.275 (the fraction p/100 to 5 decimals)."""
        fraction = f'{self.percentile/100:.5f}'
        # Strip the leading '0.' for p < 100. The 100th percentile formats as
        # '1.00000' and keeps its leading digit ('pctl100000'), so it does not
        # collide with the 0th percentile ('pctl00000').
        if fraction.startswith('0.'):
            digits = fraction[2:]
        else:
            digits = fraction.replace('.', '')
        return f"pctl{digits}"

    def value(self, values):
        """Return the configured percentile of ``values`` via `numpy.percentile`."""
        return np.percentile(values, self.percentile)

    def __post_init__(self):
        # Comparisons with NaN are False, so NaN is rejected here too
        if not ((self.percentile >= 0) and (self.percentile <= 100)):
            raise ValueError(f'percentile={self.percentile} not >=0 and <= 100')
401def _get_stat_name(*args):
402 return '_'.join(args)
405def _get_column_name(band, *args):
406 return f"{band}_{_get_stat_name(*args)}"
def compute_stats(values_ref, values_target, errors_target, row, stats, suffixes, prefix, skip_diff=False):
    """Compute statistics on differences and store results in a row.

    Parameters
    ----------
    values_ref : `numpy.ndarray`, (N,)
        Reference values.
    values_target : `numpy.ndarray`, (N,)
        Measured values.
    errors_target : `numpy.ndarray`, (N,) or `None`
        Errors (standard deviations) on `values_target`. If `None`, chi
        statistics are not computed.
    row : `numpy.void`
        A row of a structured numpy array (or any object supporting item
        assignment by column name). Values are assigned in place.
    stats : `Dict` [`str`, `Statistic`]
        A dict of `Statistic` values to measure, keyed by their column suffix.
    suffixes : `Dict` [`MeasurementType`, `str`]
        A dict of measurement type column suffixes, keyed by the measurement
        type. Only measurement types present here are stored.
    prefix : `str`
        A prefix for all column names (e.g. band).
    skip_diff : `bool`
        Whether to skip computing statistics on differences. Note that
        differences will still be computed for chi statistics.

    Returns
    -------
    row_with_stats : `numpy.void`
        The original `row` with statistic values assigned.

    Raises
    ------
    ValueError
        Raised if the lengths of the input arrays do not match.
    """
    n_ref = len(values_ref)
    n_target = len(values_target)
    n_target_err = len(errors_target) if errors_target is not None else n_ref
    # Validate even when values_ref is empty, so mismatched inputs cannot
    # silently broadcast to an empty result.
    if (n_target != n_ref) or (n_target_err != n_ref):
        raise ValueError(f'lengths of values_ref={n_ref}, values_target={n_target}'
                         f', error_target={n_target_err} must match')
    if n_ref == 0:
        # Nothing to measure; leave the row untouched
        return row

    do_chi = errors_target is not None
    diff = values_target - values_ref
    chi = diff/errors_target if do_chi else diff
    # Could make this configurable, but non-finite values/errors are not really usable
    valid = np.isfinite(chi)
    values_type = {} if skip_diff else {MeasurementType.DIFF: diff[valid]}
    if do_chi:
        values_type[MeasurementType.CHI] = chi[valid]

    for suffix_type, suffix in suffixes.items():
        values = values_type.get(suffix_type)
        if values is not None and len(values) > 0:
            for stat_name, stat in stats.items():
                row[_get_stat_name(prefix, suffix, stat_name)] = stat.value(values)
    return row
462@dataclass(frozen=True)
463class SourceTypeInfo:
464 is_extended: bool | None
465 label: str
class SourceType(Enum):
    """Classes of sources that counts and statistics are split into.

    ``ALL`` applies no extendedness selection; ``RESOLVED`` and ``UNRESOLVED``
    select extended and non-extended sources, respectively.
    """
    ALL = SourceTypeInfo(label='all', is_extended=None)
    RESOLVED = SourceTypeInfo(label='resolved', is_extended=True)
    UNRESOLVED = SourceTypeInfo(label='unresolved', is_extended=False)
class MatchType(Enum):
    """Categories of matches used to label count columns."""
    # Every source, regardless of match classification
    ALL = 'all'
    # Matches whose extendedness class agrees between reference and target
    MATCH_RIGHT = 'match_right'
    # Matches whose extendedness class disagrees
    MATCH_WRONG = 'match_wrong'
def _get_columns(bands_columns: Dict, suffixes: Dict, suffixes_flux: Dict, suffixes_mag: Dict,
                 stats: Dict, target: ComparableCatalog, column_dist: str):
    """Get column names for a table of difference statistics.

    Parameters
    ----------
    bands_columns : `Dict` [`str`,`MatchedCatalogFluxesConfig`]
        Dict keyed by band of flux column configuration.
    suffixes, suffixes_flux, suffixes_mag : `Dict` [`MeasurementType`, `str`]
        Dict of suffixes for each `MeasurementType` type, for general columns (e.g.
        coordinates), fluxes and magnitudes, respectively.
    stats : `Dict` [`str`, `Statistic`]
        Dict of `Statistic` instances keyed by their column suffix.
    target : `ComparableCatalog`
        A target catalog with coordinate column names.
    column_dist : `str`
        The name of the distance column.

    Returns
    -------
    columns : `Dict` [`str`, `type`]
        Dictionary of column types keyed by name, in table column order.
    n_models : `int`
        The number of models measurements will be made for.

    Raises
    ------
    RuntimeError
        Raised if any band's flux or flux error column count differs from
        the first band's number of flux columns.

    Notes
    -----
    Presently, models must be identical for each band.
    """
    # Columns shared by every band, describing the magnitude bins
    columns = {"bin": int, "mag_min": float, "mag_max": float}

    bands = list(bands_columns.keys())
    n_bands = len(bands)
    n_models = 0

    for idx, (band, config_flux) in enumerate(bands_columns.items()):
        if idx == 0:
            # The first band sets the number of models every band must have
            n_models = len(config_flux.columns_target_flux)
        n_models_flux = len(config_flux.columns_target_flux)
        n_models_err = len(config_flux.columns_target_flux_err)
        # TODO: Do equivalent validation earlier, in the config
        if n_models_flux != n_models or n_models_err != n_models:
            raise RuntimeError(f'{config_flux} len(columns_target_flux)={n_models_flux} and'
                               f' len(columns_target_flux_err)={n_models_err} must equal {n_models}')

        prefixes_flux = [('flux', suffixes_flux), ('mag', suffixes_mag)]
        # Colors are relative to the previous band. The first band wraps
        # around to the last one, but only when there are more than two bands.
        if idx > 0 or n_bands > 2:
            prefixes_flux.append((f'color_{bands[idx - 1]}_m_{band}', suffixes))

        for sourcetype in SourceType:
            label = sourcetype.value.label
            # Totals would be redundant, so only split source types get counts
            if sourcetype != SourceType.ALL:
                columns.update({
                    _get_column_name(band, label, f'n_{itype}_{mtype.value}'): int
                    for itype in ('ref', 'target')
                    for mtype in MatchType
                })
            columns.update({
                _get_column_name(band, label, item, suffix, stat): float
                for item in (target.column_coord1, target.column_coord2, column_dist)
                for suffix in suffixes.values()
                for stat in stats
            })
            columns.update({
                _get_column_name(band, label, prefix_item, item, suffix, stat): float
                for item in config_flux.columns_target_flux
                for prefix_item, suffixes_col in prefixes_flux
                for suffix in suffixes_col.values()
                for stat in stats
            })

    return columns, n_models
class DiffMatchedTractCatalogTask(pipeBase.PipelineTask):
    """Load subsets of matched catalogs and output a merged catalog of matched sources.
    """
    ConfigClass = DiffMatchedTractCatalogConfig
    _DefaultName = "DiffMatchedTractCatalog"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        skymap = inputs.pop("skymap")

        # The target match catalog may not have a match_candidate column
        columns_match_target = ['match_row']
        if 'match_candidate' in inputs['columns_match_target']:
            columns_match_target.append('match_candidate')

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
            catalog_match_ref=inputs['cat_match_ref'].get(
                parameters={'columns': ['match_candidate', 'match_row']},
            ),
            catalog_match_target=inputs['cat_match_target'].get(
                parameters={'columns': columns_match_target},
            ),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        catalog_match_ref: pd.DataFrame,
        catalog_match_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Load matched reference and target (measured) catalogs, measure summary statistics, and output
        a combined matched catalog with columns from both inputs.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to diff objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to diff reference objects/sources to.
        catalog_match_ref : `pandas.DataFrame`
            A catalog with match indices of target sources and selection flags
            for each reference source.
        catalog_match_target : `pandas.DataFrame`
            A catalog with selection flags for each target source.
        wcs : `lsst.afw.geom.SkyWcs`, optional
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with attributes:

            ``cat_matched``
                A `pandas.DataFrame` of target columns followed by reference
                columns (prefixed by ``config.column_matched_prefix_ref``),
                for matched sources only.
            ``diff_matched``
                A `pandas.DataFrame` with one row per reference magnitude bin,
                containing counts and difference/chi statistics.
        """
        config = self.config

        select_ref = catalog_match_ref['match_candidate'].values
        # Add additional selection criteria for target sources beyond those for matching
        # (not recommended, but can be done anyway).
        # Copy so that the in-place &= below cannot write into catalog_match_target.
        select_target = (catalog_match_target['match_candidate'].to_numpy(copy=True)
                         if 'match_candidate' in catalog_match_target.columns
                         else np.ones(len(catalog_match_target), dtype=bool))
        for column in config.columns_target_select_true:
            select_target &= catalog_target[column].values
        for column in config.columns_target_select_false:
            select_target &= ~catalog_target[column].values

        ref, target = config.coord_format.format_catalogs(
            catalog_ref=catalog_ref, catalog_target=catalog_target,
            select_ref=None, select_target=select_target, wcs=wcs, radec_to_xy_func=radec_to_xy,
            return_converted_columns=config.coord_format.coords_ref_to_convert is not None,
        )
        cat_ref = ref.catalog
        cat_target = target.catalog
        n_target = len(cat_target)

        # match_row is the target row index for each reference source, or negative if unmatched
        match_row = catalog_match_ref['match_row'].values
        matched_ref = match_row >= 0
        matched_row = match_row[matched_ref]
        matched_target = np.zeros(n_target, dtype=bool)
        matched_target[matched_row] = True

        # Create a matched table, preserving the target catalog's named index (if it has one).
        # Reference columns are prefixed before concatenation instead of overwriting the
        # (immutable) column Index's values afterwards.
        cat_left = cat_target.iloc[matched_row]
        has_index_left = cat_left.index.name is not None
        cat_right = cat_ref[matched_ref].reset_index().add_prefix(config.column_matched_prefix_ref)
        cat_matched = pd.concat(objs=(cat_left.reset_index(drop=True), cat_right), axis=1)
        if has_index_left:
            cat_matched.index = cat_left.index

        # Add/compute distance columns; unmatched target sources get infinite distances
        coord1_target_err, coord2_target_err = config.columns_target_coord_err
        column_dist, column_dist_err = 'distance', 'distanceErr'
        dist = np.full(n_target, np.inf)

        dist[matched_row] = np.hypot(
            target.coord1[matched_row] - ref.coord1[matched_ref],
            target.coord2[matched_row] - ref.coord2[matched_ref],
        )
        dist_err = np.full(n_target, np.inf)
        dist_err[matched_row] = np.hypot(cat_target.iloc[matched_row][coord1_target_err].values,
                                         cat_target.iloc[matched_row][coord2_target_err].values)
        cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err

        # Slightly smelly hack for when a column (like distance) is already relative to truth
        column_dummy = 'dummy'
        cat_ref[column_dummy] = np.zeros_like(ref.coord1)

        # Add a boolean column for whether a match is classified correctly
        extended_ref = cat_ref[config.column_ref_extended]
        if config.column_ref_extended_inverted:
            extended_ref = 1 - extended_ref

        extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut

        # Define difference/chi columns and statistics thereof
        suffixes = {mtype: mtype.value.name for mtype in MeasurementType}
        # Skip diff for fluxes - covered by mags
        suffixes_flux = {MeasurementType.CHI: suffixes[MeasurementType.CHI]}
        # Skip chi for magnitudes, which have strange errors
        suffixes_mag = {MeasurementType.DIFF: suffixes[MeasurementType.DIFF]}
        stats = {stat.name_short(): stat() for stat in (Median, SigmaIQR, SigmaMAD)}

        for percentile in self.config.percentiles:
            stat = Percentile(percentile=float(Decimal(percentile)))
            stats[stat.name_short()] = stat

        # Get dict of column names
        columns, n_models = _get_columns(
            bands_columns=config.columns_flux,
            suffixes=suffixes,
            suffixes_flux=suffixes_flux,
            suffixes_mag=suffixes_mag,
            stats=stats,
            target=target,
            column_dist=column_dist,
        )

        # Setup numpy table
        n_bins = config.mag_num_bins
        data = np.zeros((n_bins,), dtype=[(key, value) for key, value in columns.items()])
        data['bin'] = np.arange(n_bins)

        # Setup bins
        bins_mag = np.linspace(start=config.mag_brightest_ref, stop=config.mag_faintest_ref,
                               num=n_bins + 1)
        data['mag_min'] = bins_mag[:-1]
        data['mag_max'] = bins_mag[1:]
        bins_mag = tuple(zip(bins_mag[:-1], bins_mag[1:]))

        # Define temporary columns for intermediate storage
        column_mag_temp = 'mag_temp'
        column_color_temp = 'color_temp'
        column_color_err_temp = 'colorErr_temp'
        flux_err_frac_prev = [None]*n_models
        mag_prev = [None]*n_models

        # column name -> (reference column, target column, target error column or None, skip_diff)
        columns_target = {
            target.column_coord1: (
                ref.column_coord1, target.column_coord1, coord1_target_err, False,
            ),
            target.column_coord2: (
                ref.column_coord2, target.column_coord2, coord2_target_err, False,
            ),
            column_dist: (column_dummy, column_dist, column_dist_err, False),
        }

        # Cheat a little and do the first band last so that the color is
        # based on the last band
        band_fluxes = [(band, config_flux) for (band, config_flux) in config.columns_flux.items()]
        n_bands = len(band_fluxes)
        band_fluxes.append(band_fluxes[0])
        flux_err_frac_first = None
        mag_first = None
        mag_ref_first = None

        band_prev = None
        for idx_band, (band, config_flux) in enumerate(band_fluxes):
            if idx_band == n_bands:
                # These were already computed earlier
                mag_ref = mag_ref_first
                flux_err_frac = flux_err_frac_first
                mag_model = mag_first
            else:
                mag_ref = -2.5*np.log10(cat_ref[config_flux.column_ref_flux]) + config.mag_zeropoint_ref
                flux_err_frac = [None]*n_models
                mag_model = [None]*n_models

            if idx_band > 0:
                cat_ref[column_color_temp] = cat_ref[column_mag_temp] - mag_ref

            cat_ref[column_mag_temp] = mag_ref

            select_ref_bins = [select_ref & (mag_ref > mag_lo) & (mag_ref < mag_hi)
                               for mag_lo, mag_hi in bins_mag]

            # Iterate over multiple models, compute their mags and colours (if there's a previous band)
            for idx_model in range(n_models):
                column_target_flux = config_flux.columns_target_flux[idx_model]
                column_target_flux_err = config_flux.columns_target_flux_err[idx_model]

                flux_target = cat_target[column_target_flux]
                mag_target = -2.5*np.log10(flux_target) + config.mag_zeropoint_target
                if config.mag_ceiling_target is not None:
                    mag_target[mag_target > config.mag_ceiling_target] = config.mag_ceiling_target
                mag_model[idx_model] = mag_target

                # These are needed for computing magnitude/color "errors" (which are a sketchy concept)
                flux_err_frac[idx_model] = cat_target[column_target_flux_err]/flux_target

                # Stop if idx == 0: The rest will be picked up at idx == n_bands
                if idx_band > 0:
                    # Keep these mags tabulated for convenience
                    column_mag_temp_model = f'{column_mag_temp}{idx_model}'
                    cat_target[column_mag_temp_model] = mag_target

                    columns_target[f'flux_{column_target_flux}'] = (
                        config_flux.column_ref_flux,
                        column_target_flux,
                        column_target_flux_err,
                        True,
                    )
                    # Note: magnitude errors are generally problematic and not worth aggregating
                    columns_target[f'mag_{column_target_flux}'] = (
                        column_mag_temp, column_mag_temp_model, None, False,
                    )

                    # No need for colors if this is the last band and there are only two bands
                    # (because it would just be the negative of the first color)
                    skip_color = (idx_band == n_bands) and (n_bands <= 2)
                    if not skip_color:
                        column_color_temp_model = f'{column_color_temp}{idx_model}'
                        column_color_err_temp_model = f'{column_color_err_temp}{idx_model}'

                        # e.g. if order is ugrizy, first color will be u - g
                        cat_target[column_color_temp_model] = mag_prev[idx_model] - mag_model[idx_model]

                        # Sum (in quadrature, and admittedly sketchy for faint fluxes) magnitude errors
                        cat_target[column_color_err_temp_model] = 2.5/np.log(10)*np.hypot(
                            flux_err_frac[idx_model], flux_err_frac_prev[idx_model])
                        columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}'] = (
                            column_color_temp,
                            column_color_temp_model,
                            column_color_err_temp_model,
                            False,
                        )

                    for idx_bin, (mag_lo, mag_hi) in enumerate(bins_mag):
                        row = data[idx_bin]
                        # The reference selection must be refreshed for every model:
                        # it is used below to filter matches for each model's statistics.
                        # Reference *counts* are still only written once (idx_model == 0).
                        select_ref_bin = select_ref_bins[idx_bin]
                        select_target_bin = select_target & (mag_target > mag_lo) & (mag_target < mag_hi)

                        for sourcetype in SourceType:
                            sourcetype_info = sourcetype.value
                            is_extended = sourcetype_info.is_extended
                            # Counts filtered by match selection and magnitude bin
                            select_ref_sub = select_ref_bin.copy()
                            select_target_sub = select_target_bin.copy()
                            if is_extended is not None:
                                is_extended_ref = (extended_ref == is_extended)
                                select_ref_sub &= is_extended_ref
                                if idx_model == 0:
                                    n_ref_sub = np.count_nonzero(select_ref_sub)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.ALL.value)] = n_ref_sub
                                select_target_sub &= (extended_target == is_extended)
                                n_target_sub = np.count_nonzero(select_target_sub)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.ALL.value)] = n_target_sub

                            # Filter matches by magnitude bin and true class
                            match_row_bin = match_row.copy()
                            match_row_bin[~select_ref_sub] = -1
                            match_good = match_row_bin >= 0

                            n_match = np.count_nonzero(match_good)

                            # Same for counts of matched target sources (for e.g. purity)
                            if n_match > 0:
                                rows_matched = match_row_bin[match_good]
                                subset_target = cat_target.iloc[rows_matched]
                                if (is_extended is not None) and (idx_model == 0):
                                    right_type = extended_target[rows_matched] == is_extended
                                    n_total = len(right_type)
                                    n_right = np.count_nonzero(right_type)
                                    row[_get_column_name(band, sourcetype_info.label, 'n_ref',
                                                         MatchType.MATCH_RIGHT.value)] = n_right
                                    row[_get_column_name(
                                        band, sourcetype_info.label, 'n_ref', MatchType.MATCH_WRONG.value,
                                    )] = n_total - n_right

                                # compute stats for this bin, for all columns
                                for column, (column_ref, column_target, column_err_target, skip_diff) \
                                        in columns_target.items():
                                    values_ref = cat_ref[column_ref][match_good].values
                                    errors_target = (
                                        subset_target[column_err_target].values
                                        if column_err_target is not None
                                        else None
                                    )
                                    compute_stats(
                                        values_ref,
                                        subset_target[column_target].values,
                                        errors_target,
                                        row,
                                        stats,
                                        suffixes,
                                        prefix=f'{band}_{sourcetype_info.label}_{column}',
                                        skip_diff=skip_diff,
                                    )

                            # Count matched target sources with *measured* mags within bin
                            # Used for e.g. purity calculation
                            # Should be merged with above code if there's ever a need for
                            # measuring stats on this source selection
                            select_target_sub &= matched_target

                            if is_extended is not None and (np.count_nonzero(select_target_sub) > 0):
                                n_total = np.count_nonzero(select_target_sub)
                                right_type = np.zeros(n_target, dtype=bool)
                                right_type[match_row[matched_ref & is_extended_ref]] = True
                                right_type &= select_target_sub
                                n_right = np.count_nonzero(right_type)
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_RIGHT.value)] = n_right
                                row[_get_column_name(band, sourcetype_info.label, 'n_target',
                                                     MatchType.MATCH_WRONG.value)] = n_total - n_right

                    # delete the flux/color columns since they change with each band
                    for prefix in ('flux', 'mag'):
                        del columns_target[f'{prefix}_{column_target_flux}']
                    if not skip_color:
                        del columns_target[f'color_{band_prev}_m_{band}_{column_target_flux}']

            # keep values needed for colors
            flux_err_frac_prev = flux_err_frac
            mag_prev = mag_model
            band_prev = band
            if idx_band == 0:
                flux_err_frac_first = flux_err_frac
                mag_first = mag_model
                mag_ref_first = mag_ref

        retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
        return retStruct