Coverage for python/lsst/pipe/tasks/propagateSourceFlags.py: 15%

89 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-17 10:14 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["PropagateSourceFlagsConfig", "PropagateSourceFlagsTask"] 

23 

24import numpy as np 

25 

26from smatch.matcher import Matcher 

27 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30 

31 

class PropagateSourceFlagsConfig(pexConfig.Config):
    """Settings controlling which source flags are propagated to coadd objects.

    Two independent flag dictionaries are configured, ``source_flags`` and
    ``finalized_source_flags``.  Each maps a flag name to the threshold of
    relative occurrence above which the coadd object gets the flag set.
    The two dictionaries must not share any flag name (see `validate`).
    """
    source_flags = pexConfig.DictField(
        doc=("Source flags to propagate, with the threshold of relative occurrence "
             "(valid range: [0-1]). Coadd object will have flag set if fraction "
             "of input visits in which it is flagged is greater than the threshold."),
        keytype=str,
        itemtype=float,
        default={
            "calib_astrometry_used": 0.2,
            "calib_photometry_used": 0.2,
            "calib_photometry_reserved": 0.2
        },
    )
    finalized_source_flags = pexConfig.DictField(
        doc=("Finalized source flags to propagate, with the threshold of relative "
             "occurrence (valid range: [0-1]). Coadd object will have flag set if "
             "fraction of input visits in which it is flagged is greater than the "
             "threshold."),
        keytype=str,
        itemtype=float,
        default={
            "calib_psf_candidate": 0.2,
            "calib_psf_used": 0.2,
            "calib_psf_reserved": 0.2
        },
    )
    # Position columns used for the sourceTable_visit catalogs.
    x_column = pexConfig.Field(
        dtype=str,
        default="x",
        doc="Name of column with source x position (sourceTable_visit).",
    )
    y_column = pexConfig.Field(
        dtype=str,
        default="y",
        doc="Name of column with source y position (sourceTable_visit).",
    )
    # Position columns used for the finalized_src_table catalogs.
    finalized_x_column = pexConfig.Field(
        dtype=str,
        default="slot_Centroid_x",
        doc="Name of column with source x position (finalized_src_table).",
    )
    finalized_y_column = pexConfig.Field(
        dtype=str,
        default="slot_Centroid_y",
        doc="Name of column with source y position (finalized_src_table).",
    )
    match_radius = pexConfig.Field(
        doc="Source matching radius (arcsec)",
        dtype=float,
        default=0.2,
    )

    def validate(self):
        """Validate the configuration.

        Runs the base-class validation and then checks that no flag name
        appears in both ``source_flags`` and ``finalized_source_flags``.

        Raises
        ------
        ValueError
            Raised if the two flag dictionaries have a key in common.
        """
        super().validate()

        shared_names = set(self.source_flags) & set(self.finalized_source_flags)
        if not shared_names:
            return

        source_flags = self.source_flags.keys()
        finalized_source_flags = self.finalized_source_flags.keys()
        raise ValueError(f"The set of source_flags {source_flags} must not overlap "
                         f"with the finalized_source_flags {finalized_source_flags}")

93 

94 

class PropagateSourceFlagsTask(pipeBase.Task):
    """Task to propagate source flags to coadd objects.

    Flagged sources may come from a mix of two different types of source catalogs.
    The source_table catalogs from ``CalibrateTask`` contain flags for the first
    round of astrometry/photometry/psf fits.
    The finalized_source_table catalogs from ``FinalizeCalibrationTask`` contain
    flags from the second round of psf fitting.

    Parameters
    ----------
    schema : schema object providing ``addField``
        Schema that is extended with one ``Flag`` field per configured
        flag name (presumably an `lsst.afw.table.Schema` -- confirm with
        callers).
    **kwargs
        Additional keyword arguments forwarded to `lsst.pipe.base.Task`.
    """
    ConfigClass = PropagateSourceFlagsConfig

    def __init__(self, schema, **kwargs):
        pipeBase.Task.__init__(self, **kwargs)

        self.schema = schema
        # Register an output field for every flag that may be propagated.
        for f in self.config.source_flags:
            self.schema.addField(f, type="Flag", doc="Propagated from sources")
        for f in self.config.finalized_source_flags:
            self.schema.addField(f, type="Flag", doc="Propagated from finalized sources")

    def run(self, coadd_object_cat, ccd_inputs,
            source_table_handle_dict=None, finalized_source_table_handle_dict=None):
        """Propagate flags from single-frame sources to coadd objects.

        For every configured flag, the number of matched single-frame sources
        carrying the flag is counted per coadd object.  The flag is set on
        the coadd object when that count is greater than the configured
        threshold multiplied by the number of input detectors overlapping
        the object.  This task will match both "plain" source flags and
        "finalized" source flags.

        The flag columns of ``coadd_object_cat`` are modified in place;
        nothing is returned.

        Parameters
        ----------
        coadd_object_cat : `lsst.afw.table.SourceCatalog`
            Table of coadd objects.
        ccd_inputs : `lsst.afw.table.ExposureCatalog`
            Table of single-frame inputs to coadd.
        source_table_handle_dict : `dict` [`int`: `lsst.daf.butler.DeferredDatasetHandle`]
            Dict for sourceTable_visit handles (key is visit). May be None if
            ``config.source_flags`` has no entries.
        finalized_source_table_handle_dict : `dict` [`int`:
                                                     `lsst.daf.butler.DeferredDatasetHandle`]
            Dict for finalized_src_table handles (key is visit). May be None if
            ``config.finalized_source_flags`` has no entries.
        """
        if len(self.config.source_flags) == 0 and len(self.config.finalized_source_flags) == 0:
            return

        n_objects = len(coadd_object_cat)

        # We need the number of overlaps of individual detectors for each coadd source.
        # The following code is slow and inefficient, but can be made simpler in the future
        # case of cell-based coadds and so optimizing usage in afw is not a priority.
        num_overlaps = np.zeros(n_objects, dtype=np.int32)
        for i, obj in enumerate(coadd_object_cat):
            num_overlaps[i] = len(ccd_inputs.subsetContaining(obj.getCoord(), True))

        visits = np.unique(ccd_inputs["visit"])

        # The matcher works in degrees; the catalog coordinates are converted from radians.
        matcher = Matcher(np.rad2deg(coadd_object_cat["coord_ra"]),
                          np.rad2deg(coadd_object_cat["coord_dec"]))

        source_flag_counts = {f: np.zeros(n_objects, dtype=np.int32)
                              for f in self.config.source_flags}
        finalized_source_flag_counts = {f: np.zeros(n_objects, dtype=np.int32)
                                        for f in self.config.finalized_source_flags}

        # One entry per catalog type: (name, handles, per-flag counters, x column, y column).
        catalog_inputs = (
            ("sources", source_table_handle_dict, source_flag_counts,
             self.config.x_column, self.config.y_column),
            ("finalized_sources", finalized_source_table_handle_dict, finalized_source_flag_counts,
             self.config.finalized_x_column, self.config.finalized_y_column),
        )

        for name, handle_dict, flag_counts, x_col, y_col in catalog_inputs:
            # Skip this catalog type when there are no handles or no flags
            # configured for it.  The column list always contains the
            # visit/detector/position columns, so its length cannot be used
            # to detect "no flags"; testing it would read and match every
            # visit table for nothing.
            if handle_dict is None or not flag_counts:
                continue

            columns = self._get_source_table_column_names(x_col, y_col, list(flag_counts))

            for visit in visits:
                if visit not in handle_dict:
                    self.log.info("Visit %d not in input handle dict for %s", visit, name)
                    continue
                handle = handle_dict[visit]
                df = handle.get(parameters={"columns": columns})

                # Loop over all ccd_inputs rows for this visit.
                for row in ccd_inputs[ccd_inputs["visit"] == visit]:
                    detector = row["ccd"]
                    wcs = row.getWcs()
                    if wcs is None:
                        self.log.info("No WCS for visit %d detector %d, so can't match sources to "
                                      "propagate flags.  Skipping...", visit, detector)
                        continue

                    df_det = df[df["detector"] == detector]

                    if len(df_det) == 0:
                        continue

                    ra, dec = wcs.pixelToSkyArray(df_det[x_col].values,
                                                  df_det[y_col].values,
                                                  degrees=True)

                    try:
                        # The output from the matcher links
                        #  coadd_object_cat[i1] <-> df_det[i2]
                        # All objects within the match radius are matched.
                        # The radius is configured in arcsec and converted to degrees.
                        idx, i1, i2, d = matcher.query_radius(
                            ra,
                            dec,
                            self.config.match_radius/3600.,
                            return_indices=True
                        )
                    except IndexError:
                        # No matches.  Workaround a bug in older version of smatch.
                        self.log.info("Visit %d has no overlapping objects", visit)
                        continue

                    if len(i1) == 0:
                        # No matches (usually because detector does not overlap patch).
                        self.log.info("Visit %d has no overlapping objects", visit)
                        continue

                    # NOTE(review): in-place addition through a fancy index does
                    # not accumulate repeated indices, so an object matched to
                    # several sources of the same detector is incremented by
                    # only one of them.  Confirm this is the intended counting
                    # before switching to ``np.add.at``.
                    for flag in flag_counts:
                        flag_values = df_det[flag].values
                        flag_counts[flag][i1] += flag_values[i2].astype(np.int32)

        for flag in source_flag_counts:
            # Per-object threshold: configured fraction times the number of overlapping inputs.
            thresh = num_overlaps*self.config.source_flags[flag]
            object_flag = (source_flag_counts[flag] > thresh)
            coadd_object_cat[flag] = object_flag
            self.log.info("Propagated %d sources with flag %s", object_flag.sum(), flag)

        for flag in finalized_source_flag_counts:
            thresh = num_overlaps*self.config.finalized_source_flags[flag]
            object_flag = (finalized_source_flag_counts[flag] > thresh)
            coadd_object_cat[flag] = object_flag
            self.log.info("Propagated %d finalized sources with flag %s", object_flag.sum(), flag)

    def _get_source_table_column_names(self, x_column, y_column, flags):
        """Get the list of source table columns from the config.

        Parameters
        ----------
        x_column : `str`
            Name of column with x centroid.
        y_column : `str`
            Name of column with y centroid.
        flags : iterable of `str`
            Names of the flag columns to retrieve.

        Returns
        -------
        columns : `list` [`str`]
            Columns to read: ``"visit"``, ``"detector"``, the two position
            columns, followed by the flag columns.
        """
        columns = ["visit", "detector",
                   x_column, y_column]
        columns.extend(flags)

        return columns