Coverage for python/lsst/pipe/tasks/match_tract_catalog.py: 69%

58 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:05 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Names exported by ``from ... import *``.
__all__ = [
    "MatchTractCatalogSubConfig",
    "MatchTractCatalogSubTask",
    "MatchTractCatalogConfig",
    "MatchTractCatalogTask",
]

26 

27import lsst.afw.geom as afwGeom 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30import lsst.pipe.base.connectionTypes as cT 

31from lsst.skymap import BaseSkyMap 

32 

33from abc import ABC, abstractmethod 

34 

35import pandas as pd 

36from typing import Tuple, Set 

37 

38 

# Default dataset-name templates substituted into the connection names of
# MatchTractCatalogConnections.
MatchTractCatalogBaseTemplates = dict(
    name_input_cat_ref="truth_summary",
    name_input_cat_target="objectTable_tract",
)

43 

44 

class MatchTractCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates=MatchTractCatalogBaseTemplates,
):
    """Butler connections for matching a reference and a target tract catalog.

    Both catalog inputs use ``deferLoad=True`` so that only the requested
    columns are read when the handles are resolved in ``runQuantum``.
    """
    cat_ref = cT.Input(
        name="{name_input_cat_ref}",
        doc="Reference object catalog to match from",
        dimensions=("tract", "skymap"),
        storageClass="DataFrame",
        deferLoad=True,
    )
    cat_target = cT.Input(
        name="{name_input_cat_target}",
        doc="Target object catalog to match",
        dimensions=("tract", "skymap"),
        storageClass="DataFrame",
        deferLoad=True,
    )
    skymap = cT.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )
    # TODO: Change outputs to ArrowAstropy in DM-44159
    cat_output_ref = cT.Output(
        name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
        doc="Reference matched catalog with indices of target matches",
        dimensions=("tract", "skymap"),
        storageClass="DataFrame",
    )
    cat_output_target = cT.Output(
        name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
        doc="Target matched catalog with indices of reference matches",
        dimensions=("tract", "skymap"),
        storageClass="DataFrame",
    )

83 

84 

class MatchTractCatalogSubConfig(pexConfig.Config):
    """Base config for a MatchTractCatalogSubTask.

    Concrete subclasses override the properties below, whose values may be
    derived from several config settings at once.
    """
    # NOTE(review): pexConfig.Config is presumably not built on ABCMeta, so
    # @abstractmethod is not enforced here; reading a property that a
    # subclass did not override raises NotImplementedError instead.

    @property
    @abstractmethod
    def columns_in_ref(self) -> Set[str]:
        """Names of the reference catalog columns required for matching."""
        raise NotImplementedError

    @property
    @abstractmethod
    def columns_in_target(self) -> Set[str]:
        """Names of the target catalog columns required for matching."""
        raise NotImplementedError

98 

99 

class MatchTractCatalogSubTask(pipeBase.Task, ABC):
    """An abstract interface for subtasks of MatchTractCatalogTask to match
    two tract object catalogs.

    Parameters
    ----------
    **kwargs
        Additional arguments to be passed to the `lsst.pipe.base.Task`
        constructor.
    """
    ConfigClass = MatchTractCatalogSubConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Match sources in a reference tract catalog with a target catalog.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to match objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to match reference objects/sources to.
        wcs : `lsst.afw.geom.SkyWcs`, optional
            A coordinate system to convert catalog positions to sky
            coordinates, if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with the attributes read by `MatchTractCatalogTask.run`:

            ``cat_output_ref``
                The reference matched catalog (`pandas.DataFrame`), with
                indices of target matches.
            ``cat_output_target``
                The target matched catalog (`pandas.DataFrame`), with
                indices of reference matches.
            ``exceptions``
                Problems encountered during matching. `MatchTractCatalogTask`
                logs this as a warning when it is non-empty.
        """
        raise NotImplementedError()

140 

141 

class MatchTractCatalogConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=MatchTractCatalogConnections,
):
    """Configure a MatchTractCatalogTask, including a configurable matching subtask.
    """
    match_tract_catalog = pexConfig.ConfigurableField(
        target=MatchTractCatalogSubTask,
        doc="Task to match sources in a reference tract catalog with a target catalog",
    )

    def get_columns_in(self) -> Tuple[Set[str], Set[str]]:
        """Get the set of input columns required for matching.

        Returns
        -------
        columns_ref : `set` [`str`]
            The set of required reference catalog column names.
        columns_target : `set` [`str`]
            The set of required target catalog column names.

        Raises
        ------
        RuntimeError
            Raised if the ``match_tract_catalog`` subtask config lacks
            ``columns_in_ref``/``columns_in_target``, or does not implement
            them (i.e. they raise `NotImplementedError`).
        """
        try:
            columns_ref = self.match_tract_catalog.columns_in_ref
            columns_target = self.match_tract_catalog.columns_in_target
        except (AttributeError, NotImplementedError) as err:
            # The base MatchTractCatalogSubConfig properties raise
            # NotImplementedError; report that the same way as a missing
            # attribute, and keep the original exception as the cause.
            raise RuntimeError(f'{__class__}.match_tract_catalog must have columns_in_ref and'
                               f' columns_in_target attributes: {err}') from err
        return set(columns_ref), set(columns_target)

170 

171 

class MatchTractCatalogTask(pipeBase.PipelineTask):
    """Match sources in a reference tract catalog with those in a target catalog.
    """
    ConfigClass = MatchTractCatalogConfig
    _DefaultName = "MatchTractCatalog"

    def __init__(self, initInputs=None, **kwargs):
        # initInputs defaults to None to match the optional keyword accepted
        # by lsst.pipe.base.PipelineTask; this task defines no init-inputs.
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("match_tract_catalog")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        # Only the columns the matching subtask declares are read from the
        # deferred-load catalog handles.
        columns_ref, columns_target = self.config.get_columns_in()
        skymap = inputs.pop("skymap")
        tract = butlerQC.quantum.dataId["tract"]

        outputs = self.run(
            catalog_ref=inputs['cat_ref'].get(parameters={'columns': columns_ref}),
            catalog_target=inputs['cat_target'].get(parameters={'columns': columns_target}),
            wcs=skymap[tract].wcs,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catalog_ref: pd.DataFrame,
        catalog_target: pd.DataFrame,
        wcs: afwGeom.SkyWcs = None,
    ) -> pipeBase.Struct:
        """Match sources in a reference tract catalog with a target catalog.

        Parameters
        ----------
        catalog_ref : `pandas.DataFrame`
            A reference catalog to match objects/sources from.
        catalog_target : `pandas.DataFrame`
            A target catalog to match reference objects/sources to.
        wcs : `lsst.afw.geom.SkyWcs`, optional
            A coordinate system to convert catalog positions to sky coordinates,
            if necessary.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with cat_output_ref and cat_output_target attributes
            containing the output matched catalogs.
        """
        output = self.match_tract_catalog.run(catalog_ref, catalog_target, wcs=wcs)
        # The subtask interface does not strictly require an ``exceptions``
        # attribute, so tolerate its absence rather than raising.
        exceptions = getattr(output, "exceptions", None)
        if exceptions:
            # Logger.warn is a deprecated alias of Logger.warning.
            self.log.warning('Exceptions: %s', exceptions)
        retStruct = pipeBase.Struct(cat_output_ref=output.cat_output_ref,
                                    cat_output_target=output.cat_output_target)
        return retStruct