Coverage for python / lsst / pipe / tasks / propagateSourceFlags.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:53 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["PropagateSourceFlagsConfig", "PropagateSourceFlagsTask"] 

23 

24import numpy as np 

25 

26from smatch.matcher import Matcher 

27 

28import lsst.pex.config as pexConfig 

29import lsst.pipe.base as pipeBase 

30from lsst.afw.table import ExposureCatalog 

31 

32 

class PropagateSourceFlagsConfig(pexConfig.Config):
    """Configuration for propagating source flags to coadd objects.

    Two independent flag dictionaries are configured: ``source_flags`` and
    ``finalized_source_flags``. Each maps a flag name to a fractional
    threshold. `validate` rejects configurations in which the same flag name
    appears in both dictionaries.
    """

    # Flag name -> threshold for flags taken from the "plain" source tables.
    source_flags = pexConfig.DictField(
        doc=(
            "Source flags to propagate, with the threshold of relative "
            "occurrence (valid range: [0-1]). Coadd object will have flag "
            "set if fraction of input visits in which it is flagged is "
            "greater than the threshold."
        ),
        keytype=str,
        itemtype=float,
        default={
            "calib_astrometry_used": 0.2,
            "calib_photometry_used": 0.2,
            "calib_photometry_reserved": 0.2,
        },
    )
    # Flag name -> threshold for flags taken from the finalized source tables.
    finalized_source_flags = pexConfig.DictField(
        doc=(
            "Finalized source flags to propagate, with the threshold of "
            "relative occurrence (valid range: [0-1]). Coadd object will "
            "have flag set if fraction of input visits in which it is "
            "flagged is greater than the threshold."
        ),
        keytype=str,
        itemtype=float,
        default={
            "calib_psf_candidate": 0.2,
            "calib_psf_used": 0.2,
            "calib_psf_reserved": 0.2,
        },
    )
    # Pixel-position column names used for the plain source tables.
    x_column = pexConfig.Field(
        dtype=str,
        default="x",
        doc="Name of column with source x position (sourceTable_visit).",
    )
    y_column = pexConfig.Field(
        dtype=str,
        default="y",
        doc="Name of column with source y position (sourceTable_visit).",
    )
    # Pixel-position column names used for the finalized source tables.
    finalized_x_column = pexConfig.Field(
        dtype=str,
        default="slot_Centroid_x",
        doc="Name of column with source x position (finalized_src_table).",
    )
    finalized_y_column = pexConfig.Field(
        dtype=str,
        default="slot_Centroid_y",
        doc="Name of column with source y position (finalized_src_table).",
    )
    match_radius = pexConfig.Field(
        doc="Source matching radius (arcsec)",
        dtype=float,
        default=0.2,
    )

    def validate(self):
        """Run the base-class validation, then check the flag dictionaries.

        Raises
        ------
        ValueError
            Raised if any flag name is present in both ``source_flags`` and
            ``finalized_source_flags``.
        """
        super().validate()

        plain_names = set(self.source_flags)
        finalized_names = set(self.finalized_source_flags)
        if plain_names.isdisjoint(finalized_names):
            return

        # Report the full key views of both dictionaries in the message.
        plain_keys = self.source_flags.keys()
        finalized_keys = self.finalized_source_flags.keys()
        raise ValueError(
            f"The set of source_flags {plain_keys} must not overlap with "
            f"the finalized_source_flags {finalized_keys}"
        )

94 

95 

class PropagateSourceFlagsTask(pipeBase.Task):
    """Task to propagate source flags to coadd objects.

    Flagged sources may come from a mix of two different types of source
    catalogs. The source_table catalogs from ``CalibrateTask`` contain flags
    for the first round of astrometry/photometry/psf fits.
    The finalized_source_table catalogs from ``FinalizeCalibrationTask``
    contain flags from the second round of psf fitting.

    Parameters
    ----------
    schema : schema object providing ``addField``
        Schema to which one ``Flag`` field is added for every key in
        ``config.source_flags`` and ``config.finalized_source_flags``.
        Presumably an `lsst.afw.table.Schema`; confirm against callers.
    **kwargs
        Additional keyword arguments forwarded to `lsst.pipe.base.Task`.
    """
    ConfigClass = PropagateSourceFlagsConfig

    def __init__(self, schema, **kwargs):
        pipeBase.Task.__init__(self, **kwargs)

        self.schema = schema
        for f in self.config.source_flags:
            self.schema.addField(f, type="Flag", doc="Propagated from sources")
        for f in self.config.finalized_source_flags:
            self.schema.addField(f, type="Flag", doc="Propagated from finalized sources")

    def run(self, coadd_object_cat, ccd_inputs,
            source_table_handle_dict=None, finalized_source_table_handle_dict=None,
            visit_summary_handle_dict=None):
        """Propagate flags from single-frame sources to coadd objects.

        Sources from each input visit/detector are matched to the coadd
        objects within ``config.match_radius``. For every configured flag,
        a coadd object has the flag set when the number of matched flagged
        sources exceeds the configured threshold multiplied by the number of
        ``ccd_inputs`` records that contain the object. This task will match
        both "plain" source flags and "finalized" source flags.

        The flag columns of ``coadd_object_cat`` are modified in place and
        nothing is returned. If neither ``config.source_flags`` nor
        ``config.finalized_source_flags`` has entries, the method returns
        immediately without touching the catalog.

        Parameters
        ----------
        coadd_object_cat : `lsst.afw.table.SourceCatalog`
            Table of coadd objects.
        ccd_inputs : `lsst.afw.table.ExposureCatalog`
            Table of single-frame inputs to coadd.
        source_table_handle_dict : `dict` [`int`: `lsst.daf.butler.DeferredDatasetHandle`], optional
            Dict for sourceTable_visit handles (key is visit). May be None if
            ``config.source_flags`` has no entries.
        finalized_source_table_handle_dict : `dict` [`int`:
                `lsst.daf.butler.DeferredDatasetHandle`], optional
            Dict for finalized_src_table handles (key is visit). May be None if
            ``config.finalized_source_flags`` has no entries.
        visit_summary_handle_dict : `dict` [`int`: `lsst.daf.butler.DeferredDatasetHandle`], optional
            Dict for visitSummary handles (key is visit). If None, using WCS
            from the ccd_inputs will be attempted.
        """
        if len(self.config.source_flags) == 0 and len(self.config.finalized_source_flags) == 0:
            return

        source_columns = self._get_source_table_column_names(
            self.config.x_column,
            self.config.y_column,
            self.config.source_flags.keys()
        )
        finalized_columns = self._get_source_table_column_names(
            self.config.finalized_x_column,
            self.config.finalized_y_column,
            self.config.finalized_source_flags.keys(),
        )

        # We need the number of overlaps of individual detectors for each coadd source.
        # The following code is slow and inefficient, but can be made simpler in the
        # case of cell-based coadds and so optimizing usage in afw is not a priority.
        num_overlaps = np.zeros(len(coadd_object_cat), dtype=np.int32)
        if isinstance(ccd_inputs, ExposureCatalog):
            for i, obj in enumerate(coadd_object_cat):
                num_overlaps[i] = len(ccd_inputs.subsetContaining(obj.getCoord(), True))

            visits = np.unique(ccd_inputs["visit"])
        else:  # StitchedExposureCatalog
            for i, obj in enumerate(coadd_object_cat):
                # The cell-based coadd inputs can be queried by centroid
                # on the coadd instead of sky coordinates.
                num_overlaps[i] = len(ccd_inputs.subsetContaining(obj.getCentroid()))

            visits = np.unique([ccd_input.visit for ccd_input in ccd_inputs])

        # The matcher is built once on the coadd object positions (converted
        # from radians to degrees) and queried per detector below.
        matcher = Matcher(np.rad2deg(coadd_object_cat["coord_ra"]),
                          np.rad2deg(coadd_object_cat["coord_dec"]))

        # Per-flag counters, one entry per coadd object.
        source_flag_counts = {f: np.zeros(len(coadd_object_cat), dtype=np.int32)
                              for f in self.config.source_flags}
        finalized_source_flag_counts = {f: np.zeros(len(coadd_object_cat), dtype=np.int32)
                                        for f in self.config.finalized_source_flags}

        # The plain and finalized catalogs are processed by the same loop;
        # these parallel lists hold the per-catalog inputs.
        handles_list = [source_table_handle_dict, finalized_source_table_handle_dict]
        columns_list = [source_columns, finalized_columns]
        counts_list = [source_flag_counts, finalized_source_flag_counts]
        x_column_list = [self.config.x_column, self.config.finalized_x_column]
        y_column_list = [self.config.y_column, self.config.finalized_y_column]
        name_list = ["sources", "finalized_sources"]

        for handle_dict, columns, flag_counts, x_col, y_col, name in zip(handles_list,
                                                                         columns_list,
                                                                         counts_list,
                                                                         x_column_list,
                                                                         y_column_list,
                                                                         name_list):
            # ``columns`` always contains the visit/detector/x/y entries, so
            # its length cannot tell whether any flags were requested. Test
            # the flag counters instead, so that no table is read or matched
            # for a catalog type that has no flags to propagate.
            if handle_dict is not None and len(flag_counts) > 0:
                for visit in visits:
                    if visit not in handle_dict:
                        self.log.info("Visit %d not in input handle dict for %s", visit, name)
                        continue
                    handle = handle_dict[visit]
                    tbl = handle.get(parameters={"columns": columns})
                    if visit_summary_handle_dict is not None:
                        visit_summary = visit_summary_handle_dict[visit].get()

                    # Loop over all ccd_inputs rows for this visit.
                    for row in ccd_inputs:
                        if row["visit"] != visit:
                            continue
                        detector = row["ccd"]
                        if visit_summary_handle_dict is None:
                            wcs = row.getWcs()
                        else:
                            wcs = visit_summary.find(detector).getWcs()
                        if wcs is None:
                            self.log.info("No WCS for visit %d detector %d, so can't match sources to "
                                          "propagate flags.  Skipping...", visit, detector)
                            continue

                        tbl_det = tbl[tbl["detector"] == detector]

                        if len(tbl_det) == 0:
                            continue

                        ra, dec = wcs.pixelToSkyArray(np.asarray(tbl_det[x_col]),
                                                      np.asarray(tbl_det[y_col]),
                                                      degrees=True)

                        try:
                            # The output from the matcher links
                            #  coadd_object_cat[i1] <-> df_det[i2]
                            # All objects within the match radius are matched.
                            idx, i1, i2, d = matcher.query_radius(
                                ra,
                                dec,
                                self.config.match_radius/3600.,
                                return_indices=True
                            )
                        except IndexError:
                            # No matches.  Workaround a bug in older version of smatch.
                            self.log.info("Visit %d has no overlapping objects", visit)
                            continue

                        if len(i1) == 0:
                            # No matches (usually because detector does not overlap patch).
                            self.log.info("Visit %d has no overlapping objects", visit)
                            continue

                        # NOTE(review): with fancy-index ``+=`` an index that
                        # occurs more than once in ``i1`` is incremented only
                        # once per detector. Confirm this is intended;
                        # ``np.add.at`` would count every match.
                        for flag in flag_counts:
                            flag_values = np.asarray(tbl_det[flag])
                            flag_counts[flag][i1] += flag_values[i2].astype(np.int32)

        # Set each flag where the count exceeds threshold * number of overlaps.
        for flag in source_flag_counts:
            thresh = num_overlaps*self.config.source_flags[flag]
            object_flag = (source_flag_counts[flag] > thresh)
            coadd_object_cat[flag] = object_flag
            self.log.info("Propagated %d sources with flag %s", object_flag.sum(), flag)

        for flag in finalized_source_flag_counts:
            thresh = num_overlaps*self.config.finalized_source_flags[flag]
            object_flag = (finalized_source_flag_counts[flag] > thresh)
            coadd_object_cat[flag] = object_flag
            self.log.info("Propagated %d finalized sources with flag %s", object_flag.sum(), flag)

    def _get_source_table_column_names(self, x_column, y_column, flags):
        """Get the list of source table columns from the config.

        Parameters
        ----------
        x_column : `str`
            Name of column with x centroid.
        y_column : `str`
            Name of column with y centroid.
        flags : iterable of `str`
            Names of the flag columns to retrieve.

        Returns
        -------
        columns : `list` [`str`]
            Columns to read: ``"visit"``, ``"detector"``, the two position
            columns, then the flag names in iteration order.
        """
        columns = ["visit", "detector",
                   x_column, y_column]
        columns.extend(flags)

        return columns