Coverage for python/lsst/pipe/tasks/match_tract_catalog.py: 69%
58 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-12 01:56 -0700
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public API of this module: the matching subtask interface (config + task)
# and the PipelineTask (config + task) that drives it.
__all__ = [
    "MatchTractCatalogSubConfig",
    "MatchTractCatalogSubTask",
    "MatchTractCatalogConfig",
    "MatchTractCatalogTask",
]
from abc import ABC, abstractmethod
from typing import Optional, Set, Tuple

import pandas as pd

import lsst.afw.geom as afwGeom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
from lsst.skymap import BaseSkyMap
39MatchTractCatalogBaseTemplates = {
40 "name_input_cat_ref": "truth_summary",
41 "name_input_cat_target": "objectTable_tract",
42}
class MatchTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=MatchTractCatalogBaseTemplates,
):
    """Butler connections for MatchTractCatalogTask.

    Inputs are a reference and a target tract-level DataFrame catalog (both
    loaded lazily via ``deferLoad=True``) plus the skymap. Outputs are the
    matched reference and target DataFrame catalogs. Dataset names are built
    from the ``name_input_cat_ref`` and ``name_input_cat_target`` templates.
    """
    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        # Deferred so that only the required columns need be read.
        deferLoad=True,
        doc="Reference object catalog to match from",
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="Target object catalog to match",
    )
    skymap = cT.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
    )
    # TODO: Change outputs to ArrowAstropy in DM-44159
    cat_output_ref = cT.Output(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Reference matched catalog with indices of target matches",
    )
    cat_output_target = cT.Output(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        doc="Target matched catalog with indices of reference matches",
    )
class MatchTractCatalogSubConfig(pexConfig.Config):
    """Base config for MatchTractCatalogSubTask implementations.

    Subclasses provide the ``columns_in_ref`` and ``columns_in_target``
    properties, which derive (possibly from several config settings) the
    column names the matcher needs from the reference and target catalogs.
    """

    # NOTE(review): @abstractmethod is only enforced by ABCMeta-based classes;
    # whether pexConfig.Config uses such a metaclass is not visible here, so
    # in practice these may just raise NotImplementedError when accessed.

    @property
    @abstractmethod
    def columns_in_ref(self) -> Set[str]:
        """Column names required from the reference catalog."""
        raise NotImplementedError

    @property
    @abstractmethod
    def columns_in_target(self) -> Set[str]:
        """Column names required from the target catalog."""
        raise NotImplementedError
class MatchTractCatalogSubTask(pipeBase.Task, ABC):
    """An abstract interface for subtasks of MatchTractCatalogTask to match
    two tract object catalogs.

    Parameters
    ----------
    **kwargs
        Additional arguments to be passed to the `lsst.pipe.base.Task`
        constructor.
    """
    ConfigClass = MatchTractCatalogSubConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        wcs: Optional[afwGeom.SkyWcs] = None,
    ) -> pipeBase.Struct:
        """Match sources in a reference tract catalog with a target catalog.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to match objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to match reference objects/sources to.
        wcs : `lsst.afw.geom.SkyWcs`, optional
            A coordinate system to convert catalog positions to sky
            coordinates, if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with ``cat_output_ref`` and ``cat_output_target``
            attributes containing the output matched catalogs. It should
            also carry an ``exceptions`` attribute, which
            MatchTractCatalogTask.run logs as a warning when it is truthy
            (presumably a collection of errors caught while matching).
        """
        raise NotImplementedError()
class MatchTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=MatchTractCatalogConnections,
):
    """Configure a MatchTractCatalogTask, including a configurable matching subtask.
    """
    match_tract_catalog = pexConfig.ConfigurableField(
        target=MatchTractCatalogSubTask,
        doc="Task to match sources in a reference tract catalog with a target catalog",
    )

    def get_columns_in(self) -> Tuple[Set[str], Set[str]]:
        """Get the set of input columns required for matching.

        Returns
        -------
        columns_ref : `set` [`str`]
            The set of required reference catalog column names.
        columns_target : `set` [`str`]
            The set of required target catalog column names.

        Raises
        ------
        RuntimeError
            Raised if the ``match_tract_catalog`` subtask config does not
            provide (or does not implement) ``columns_in_ref`` and
            ``columns_in_target``.
        """
        try:
            sub_config = self.match_tract_catalog
            columns_ref = sub_config.columns_in_ref
            columns_target = sub_config.columns_in_target
        # NotImplementedError is raised by the abstract base properties of
        # MatchTractCatalogSubConfig; AttributeError by configs lacking them.
        except (AttributeError, NotImplementedError) as err:
            # type(self) names the actual (possibly derived) config class, and
            # the exception type is included because NotImplementedError()
            # has an empty message.
            raise RuntimeError(
                f"{type(self).__name__}.match_tract_catalog must have columns_in_ref and"
                f" columns_in_target attributes: {type(err).__name__}: {err}"
            ) from None
        return set(columns_ref), set(columns_target)
class MatchTractCatalogTask(pipeBase.PipelineTask):
    """Match sources in a reference tract catalog with those in a target catalog.
    """
    ConfigClass = MatchTractCatalogConfig
    _DefaultName = "MatchTractCatalog"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("match_tract_catalog")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read the inputs for one tract quantum, match them and write the outputs.

        Only the columns reported by ``config.get_columns_in()`` are requested
        from the (deferred-load) reference and target catalogs. The tract's
        WCS is taken from the skymap.
        """
        inputs = butlerQC.get(inputRefs)
        columns_ref, columns_target = self.config.get_columns_in()
        skymap = inputs.pop("skymap")

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': columns_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': columns_target}),
            wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        wcs: Optional[afwGeom.SkyWcs] = None,
    ) -> pipeBase.Struct:
        """Match sources in a reference tract catalog with a target catalog.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to match objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to match reference objects/sources to.
        wcs : `lsst.afw.geom.SkyWcs`, optional
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with ``cat_output_ref`` and ``cat_output_target``
            attributes containing the output matched catalogs.
        """
        output = self.match_tract_catalog.run(catalog_ref, catalog_target, wcs=wcs)
        # Subtasks are not required to report exceptions; tolerate its absence.
        exceptions = getattr(output, "exceptions", None)
        if exceptions:
            # ``warn`` is a deprecated alias of ``warning``.
            self.log.warning('Exceptions: %s', exceptions)
        retStruct = pipeBase.Struct(
            cat_output_ref=output.cat_output_ref,
            cat_output_target=output.cat_output_target,
        )
        return retStruct