Coverage for python / lsst / ap / association / filterDiaSourceCatalog.py: 30%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:53 +0000

1# This file is part of ap_association 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Names exported by ``from <module> import *``: the two tasks defined in this
# file and their configuration classes.
__all__ = (
    "FilterDiaSourceCatalogConfig",
    "FilterDiaSourceCatalogTask",
    "FilterDiaSourceReliabilityConfig",
    "FilterDiaSourceReliabilityTask",
)

28 

29import numpy as np 

30 

31import lsst.pex.config as pexConfig 

32import lsst.pipe.base as pipeBase 

33import lsst.pipe.base.connectionTypes as connTypes 

34from lsst.utils.timer import timeMethod 

35 

36 

class FilterDiaSourceCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "visit", "detector"),
    defaultTemplates={"coaddName": "deep", "fakesType": ""},
):
    """Connections class for FilterDiaSourceCatalogTask.

    Two of the outputs are optional and are dropped in ``__init__``
    depending on the task configuration: ``rejectedDiaSources`` and
    ``longTrailedSources``.
    """

    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )

    diffImVisitInfo = connTypes.Input(
        doc="VisitInfo of diffIm.",
        name="{fakesType}{coaddName}Diff_differenceExp.visitInfo",
        storageClass="VisitInfo",
        dimensions=("instrument", "visit", "detector"),
    )

    filteredDiaSourceCat = connTypes.Output(
        doc="Output catalog of DiaSources after filtering.",
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )

    rejectedDiaSources = connTypes.Output(
        doc="Optional output storing all the rejected DiaSources.",
        name="{fakesType}{coaddName}Diff_rejectedDiaSrc",
        storageClass="SourceCatalog",
        # A tuple, like every other connection in this class.
        dimensions=("instrument", "visit", "detector"),
    )

    longTrailedSources = connTypes.Output(
        doc="Optional output temporarily storing long trailed diaSources.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="ArrowAstropy",
        name="{fakesType}{coaddName}Diff_longTrailedSrc",
    )

    def __init__(self, *, config=None):
        """Drop the optional outputs that the configuration disables.

        Parameters
        ----------
        config : `FilterDiaSourceCatalogConfig`, optional
            Task configuration; forwarded to the base class, which makes it
            available as ``self.config``.
        """
        super().__init__(config=config)
        if not self.config.doWriteRejectedSkySources:
            self.outputs.remove("rejectedDiaSources")
        # The trailed-source table is only produced when the trail filter
        # runs AND writing it is enabled. Test both flags in a single
        # condition so the output is removed at most once; removing it a
        # second time would fail if ``outputs`` behaves like a `set`
        # (``set.remove`` raises `KeyError` for a missing element).
        if not (self.config.doTrailedSourceFilter and self.config.doWriteTrailedSources):
            self.outputs.remove("longTrailedSources")

87 

88 

class FilterDiaSourceCatalogConfig(
    pipeBase.PipelineTaskConfig, pipelineConnections=FilterDiaSourceCatalogConnections
):
    """Config class for FilterDiaSourceCatalogTask."""

    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
        "removed before storing the output DiaSource catalog.",
    )

    doWriteRejectedSkySources = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Store the output DiaSource catalog containing all the rejected "
        "sky sources."
    )

    badFlagList = pexConfig.ListField(
        dtype=str,
        default=[
            "slot_Centroid_flag",
            "base_PixelFlags_flag_crCenter",
            "base_PixelFlags_flag_high_varianceCenterAll"
        ],
        doc="List of flags which cause a source to be removed.",
    )

    doRemoveNegativeDirectImageSources = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Remove DIASources with negative scienceFlux/scienceFluxErr "
        "according to a configurable threshold.",
    )

    minAllowedDirectSnr = pexConfig.Field(
        dtype=float,
        doc="Minimum allowed ratio of scienceFlux/scienceFluxErr.",
        default=-2.0,
    )

    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment of a multi-line ``doc`` except the last must end with a space.
    doTrailedSourceFilter = pexConfig.Field(
        doc="Run trailedSourceFilter to remove long trailed sources from the "
        "diaSource output catalog.",
        dtype=bool,
        default=True,
    )

    doWriteTrailedSources = pexConfig.Field(
        doc="Write trailed diaSources to a table.",
        dtype=bool,
        default=True,
        deprecated="Trailed sources will not be written out during production."
    )

    max_trail_length = pexConfig.Field(
        dtype=float,
        doc="Length of long trailed sources to remove from the input catalog, "
        "in arcseconds per second. Default comes from DMTN-199, which "
        "requires removal of sources with trails longer than 10 "
        "degrees/day, which is 36000/3600/24 arcsec/second, or roughly "
        "0.416 arcseconds per second.",
        default=36000/3600.0/24.0,
    )

    estimatedPixelScale = pexConfig.Field(
        dtype=float,
        doc="Approximate plate scale, in arcseconds/pixel. "
        "Used to convert trail length if the catalog calculation fails.",
        default=0.2,
    )

160 

161 

class FilterDiaSourceCatalogTask(pipeBase.PipelineTask):
    """Filter sources from a DiaSource catalog.

    Depending on the configuration, sources are rejected because they are
    sky sources, carry one of the configured bad flags, have a direct-image
    signal-to-noise below the configured minimum, or are long or
    ill-constrained trails.
    """

    ConfigClass = FilterDiaSourceCatalogConfig
    _DefaultName = "filterDiaSourceCatalog"

    @timeMethod
    def run(self, diaSourceCat, diffImVisitInfo):
        """Filter sources from the supplied DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffImVisitInfo : `lsst.afw.image.VisitInfo`
            VisitInfo for the difference image corresponding to diaSourceCat.

        Returns
        -------
        filterResults : `lsst.pipe.base.Struct`

            ``filteredDiaSourceCat`` : `lsst.afw.table.SourceCatalog`
                The catalog of filtered sources.
            ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog`
                The catalog of sources rejected by the sky-source, bad-flag
                and direct-image S/N criteria. Sources rejected only for
                their trail are not included here.
            ``longTrailedSources`` : `astropy.table.Table`
                DiaSources selected by the trail filter (trail lengths
                greater than max_trail_length*exposure_time, or flagged
                trails). Only present if both ``doTrailedSourceFilter`` and
                ``doWriteTrailedSources`` are set.
        """
        exposure_time = diffImVisitInfo.exposureTime
        rejected_mask = np.zeros(len(diaSourceCat), dtype=bool)

        # Each log message counts only the sources newly rejected by that
        # criterion, i.e. those not already rejected by an earlier one.
        if self.config.doRemoveSkySources:
            sky_source_column = diaSourceCat["sky_source"]
            self.log.info(f"Filtered {np.sum(sky_source_column & ~rejected_mask)} sky sources.")
            rejected_mask |= sky_source_column

        for flag in self.config.badFlagList:
            flag_mask = diaSourceCat[flag]
            self.log.info(f"Filtered {np.sum(flag_mask & ~rejected_mask)} sources with flag {flag}.")
            rejected_mask |= flag_mask

        if self.config.doRemoveNegativeDirectImageSources:
            direct_snr = (diaSourceCat["ip_diffim_forced_PsfFlux_instFlux"]
                          / diaSourceCat["ip_diffim_forced_PsfFlux_instFluxErr"])
            too_negative = direct_snr < self.config.minAllowedDirectSnr
            self.log.info(f"Filtered {np.sum(too_negative & ~rejected_mask)} negative direct sources.")
            rejected_mask |= too_negative

        # Snapshot the rejected sources before the trail mask is folded in:
        # trailed sources are reported through their own output.
        outputs = {"rejectedDiaSources": diaSourceCat[rejected_mask].copy(deep=True)}

        if self.config.doTrailedSourceFilter:
            trail_mask = self._check_dia_source_trail(diaSourceCat, exposure_time)
            longTrailedDiaSources = diaSourceCat[trail_mask].copy(deep=True)
            rejected_mask |= trail_mask

            self.log.info("%i DiaSources exceed max_trail_length %f arcseconds per second, "
                          "dropping from source catalog.",
                          len(longTrailedDiaSources), self.config.max_trail_length)
            self.metadata.add("num_filtered", len(longTrailedDiaSources))

            if self.config.doWriteTrailedSources:
                outputs["longTrailedSources"] = longTrailedDiaSources.asAstropy()

        outputs["filteredDiaSourceCat"] = diaSourceCat[~rejected_mask].copy(deep=True)
        return pipeBase.Struct(**outputs)

    def _check_dia_source_trail(self, dia_sources, exposure_time):
        """Find DiaSources that have long trails or trails with indeterminant
        end points.

        Return a mask of sources with lengths greater than
        (``config.max_trail_length`` multiplied by the exposure time)
        arcseconds.
        Additionally, set mask if
        ``ext_trailedSources_Naive_flag_off_image`` is set or if
        ``ext_trailedSources_Naive_flag_suspect_long_trail`` and
        ``ext_trailedSources_Naive_flag_edge`` are both set.

        Parameters
        ----------
        dia_sources : `lsst.afw.table.SourceCatalog`
            Input diaSources to check for trail lengths (``run`` passes the
            difference-image source catalog).
        exposure_time : `float`
            Exposure time from difference image.

        Returns
        -------
        trail_mask : array-like of `bool`
            Boolean mask for diaSources which are greater than the
            cutoff length or have trails which extend beyond the edge of the
            detector (off_image set). Also checks if both
            suspect_long_trail and edge are set and masks those sources out.
        """
        pixelScale = self._estimate_pixel_scale(dia_sources)
        # max_trail_length [arcsec/s] * exposure_time / pixelScale
        # [arcsec/pixel] gives the cutoff in pixels, which is compared
        # directly with the measured trail-length column.
        trail_mask = (dia_sources["ext_trailedSources_Naive_length"]
                      >= (self.config.max_trail_length*exposure_time/pixelScale))
        trail_mask |= dia_sources['ext_trailedSources_Naive_flag_off_image']
        trail_mask |= (dia_sources['ext_trailedSources_Naive_flag_suspect_long_trail']
                       & dia_sources['ext_trailedSources_Naive_flag_edge'])

        return trail_mask

    def _estimate_pixel_scale(self, catalog):
        """Quickly calculate the pixel scale from catalog values

        Will return a fallback value from the task config if there is any
        error, if the result is not finite, or if the result differs from
        the configured estimate by more than 10%.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.

        Returns
        -------
        scale : `float`
            Pixel scale of the catalog, in arcseconds/pixel
        """
        fallback = self.config.estimatedPixelScale
        # Two distinct sources are needed to measure a separation.
        if len(catalog) < 2:
            return fallback
        try:
            coordKey = catalog.getCoordKey()
            decVals = catalog[coordKey.getDec()]  # in radians
            raVals = catalog[coordKey.getRa()]  # in radians
            xVals = catalog.getX()
            yVals = catalog.getY()
            # Find two points that are well separated for the calculation
            # Start with a point near one edge, and find the furthest point
            # from there
            iMin = np.argmin(xVals)
            dist = np.sqrt((xVals[iMin] - xVals)**2 + (yVals[iMin] - yVals)**2)
            iMax = np.argmax(dist)
            # Use the spherical law of cosines:
            t1 = np.sin(decVals[iMin])*np.sin(decVals[iMax])
            t2 = np.cos(decVals[iMin])*np.cos(decVals[iMax])*np.cos(raVals[iMin] - raVals[iMax])
            separation = np.arccos(t1 + t2)  # in radians
            # dist[iMax] is the largest pixel distance found above.
            scale = separation/dist[iMax]*3600*180/np.pi  # convert to arcseconds/pixel
        except Exception as e:
            self.log.warning("Error encountered estimating the pixel scale from the catalog: %s", e)
            return fallback

        # A NaN scale (NaN centroids or coordinates, coincident sources
        # giving 0/0, or an arccos argument rounded above 1) compares False
        # in the tolerance test below and would otherwise be returned,
        # turning the trail-length cutoff into NaN and disabling that cut.
        if not np.isfinite(scale):
            self.log.warning("Calculated pixel scale is not finite (%s). Falling back on estimate.", scale)
            return fallback
        if abs(scale - fallback)/fallback > 0.1:
            self.log.warning(f"Calculated pixel scale of {scale} too different from estimated value "
                             f"{self.config.estimatedPixelScale}. Falling back on estimate.")
            return fallback
        return scale

319 

320 

class FilterDiaSourceReliabilityConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "visit", "detector"),
    defaultTemplates={"coaddName": "deep", "fakesType": ""}
):
    """Connections class for FilterDiaSourceReliabilityTask.

    Takes a DiaSource catalog plus a matching catalog of reliability scores
    and produces one catalog of accepted and one of rejected sources.
    """

    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    reliability = connTypes.Input(
        doc="Reliability (e.g. real/bogus) classification of diaSourceCat sources.",
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
    )
    filteredDiaSources = connTypes.Output(
        doc="Accepted diaSource catalog filtered by reliability score.",
        name="dia_source_high_reliability",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    rejectedDiaSources = connTypes.Output(
        doc="Rejected diaSource catalog with low reliability scores.",
        name="dia_source_low_reliability",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )

350 

351 

class FilterDiaSourceReliabilityConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=FilterDiaSourceReliabilityConnections,
):
    """Settings for FilterDiaSourceReliabilityTask.

    The only option is the reliability threshold used to split the input
    catalog into accepted and rejected sources.
    """

    minReliability = pexConfig.Field(
        dtype=float,
        default=0.0,
        doc="Minimum reliability score to keep a source in the DiaSource catalog.",
    )

361 

362 

class FilterDiaSourceReliabilityTask(pipeBase.PipelineTask):
    """Filter DiaSource catalog by reliability score.

    The reliability scores are copied into the ``reliability`` column of the
    DiaSource catalog, which is then split at ``config.minReliability``.
    """

    ConfigClass = FilterDiaSourceReliabilityConfig
    _DefaultName = "filterDiaSourceReliability"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch the inputs, call `run`, and store its outputs."""
        inputs = butlerQC.get(inputRefs)
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, diaSourceCat, reliability):
        """Run the task to filter DiaSources by reliability.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of DiaSources produced during image differencing. Its
            ``reliability`` column is overwritten in place with the scores.
        reliability : `lsst.afw.table.BaseCatalog`
            Reliability (e.g. real/bogus) classification of the sources in
            ``diaSourceCat``; read through its ``id`` and ``score`` columns.
            Must list the same ids in the same order as ``diaSourceCat``.

        Returns
        -------
        filteredResults : `lsst.pipe.base.Struct`

            ``filteredDiaSources`` : `lsst.afw.table.SourceCatalog`
                Catalog of unstandardized DiaSources filtered by reliability
                score.
            ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog`
                Catalog of unstandardized DiaSources that were rejected due
                to low reliability scores.

        Raises
        ------
        ValueError
            Raised if the ids of the two catalogs do not match.
        """
        # `np.array_equal` also compares shapes, so catalogs of different
        # length are reported as a mismatch instead of being broadcast
        # against each other (a length-1 catalog would otherwise compare
        # element-wise against every source).
        if not np.array_equal(diaSourceCat['id'], reliability['id']):
            # If the identifiers do not match, we cannot filter reliably.
            raise ValueError(
                "Reliability ids do not match DiaSource ids.")

        # Copy the scores in the output catalog
        diaSourceCat['reliability'] = reliability['score']

        # Filter the DiaSource catalog by reliability score
        low_reliability = diaSourceCat["reliability"] < self.config.minReliability
        rejectedDiaSources = diaSourceCat[low_reliability].copy(deep=True)
        filteredDiaSources = diaSourceCat[~low_reliability].copy(deep=True)

        self.log.info(f"Filtered {np.sum(low_reliability)} sources with low reliability.")

        return pipeBase.Struct(
            filteredDiaSources=filteredDiaSources,
            rejectedDiaSources=rejectedDiaSources
        )