Coverage for python / lsst / ap / association / filterDiaSourceCatalog.py: 30%
119 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:05 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:05 +0000
1# This file is part of ap_association
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = (
23 "FilterDiaSourceCatalogConfig",
24 "FilterDiaSourceCatalogTask",
25 "FilterDiaSourceReliabilityConfig",
26 "FilterDiaSourceReliabilityTask"
27)
29import numpy as np
31import lsst.pex.config as pexConfig
32import lsst.pipe.base as pipeBase
33import lsst.pipe.base.connectionTypes as connTypes
34from lsst.utils.timer import timeMethod
class FilterDiaSourceCatalogConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("instrument", "visit", "detector"),
    defaultTemplates={"coaddName": "deep", "fakesType": ""},
):
    """Connections class for FilterDiaSourceCatalogTask.

    The ``rejectedDiaSources`` and ``longTrailedSources`` outputs are
    optional: ``__init__`` removes them from the set of outputs depending on
    the ``doWriteRejectedSkySources``, ``doTrailedSourceFilter`` and
    ``doWriteTrailedSources`` config options.
    """

    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )

    diffImVisitInfo = connTypes.Input(
        doc="VisitInfo of diffIm.",
        name="{fakesType}{coaddName}Diff_differenceExp.visitInfo",
        storageClass="VisitInfo",
        dimensions=("instrument", "visit", "detector"),
    )

    filteredDiaSourceCat = connTypes.Output(
        doc="Output catalog of DiaSources after filtering.",
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )

    rejectedDiaSources = connTypes.Output(
        doc="Optional output storing all the rejected DiaSources.",
        name="{fakesType}{coaddName}Diff_rejectedDiaSrc",
        storageClass="SourceCatalog",
        # A tuple, like every other connection in this class.
        dimensions=("instrument", "visit", "detector"),
    )

    longTrailedSources = connTypes.Output(
        doc="Optional output temporarily storing long trailed diaSources.",
        dimensions=("instrument", "visit", "detector"),
        storageClass="ArrowAstropy",
        name="{fakesType}{coaddName}Diff_longTrailedSrc",
    )

    def __init__(self, *, config=None):
        """Drop the optional outputs that the configuration disables.

        Parameters
        ----------
        config : `FilterDiaSourceCatalogConfig`, optional
            Task configuration; forwarded to the base class, after which it
            is read back through ``self.config``.
        """
        super().__init__(config=config)
        if not self.config.doWriteRejectedSkySources:
            self.outputs.remove("rejectedDiaSources")
        # The trailed-source table is only produced when the trail filter
        # runs AND writing it is enabled. Both options are tested in a single
        # condition so the connection is removed at most once; two separate
        # ``if`` blocks would attempt a second removal when both options are
        # disabled (presumably a KeyError if ``outputs`` is a plain set --
        # TODO confirm against pipe_base).
        if not (self.config.doTrailedSourceFilter and self.config.doWriteTrailedSources):
            self.outputs.remove("longTrailedSources")
class FilterDiaSourceCatalogConfig(
    pipeBase.PipelineTaskConfig, pipelineConnections=FilterDiaSourceCatalogConnections
):
    """Config class for FilterDiaSourceCatalogTask."""

    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog.",
    )

    # When False, the connections class removes the ``rejectedDiaSources``
    # output.
    doWriteRejectedSkySources = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Store the output DiaSource catalog containing all the rejected "
            "sky sources."
    )

    badFlagList = pexConfig.ListField(
        dtype=str,
        default=[
            "slot_Centroid_flag",
            "base_PixelFlags_flag_crCenter",
            "base_PixelFlags_flag_high_varianceCenterAll"
        ],
        doc="List of flags which cause a source to be removed.",
    )

    doRemoveNegativeDirectImageSources = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Remove DIASources with negative scienceFlux/scienceFluxErr "
            "according to a configurable threshold.",
    )

    minAllowedDirectSnr = pexConfig.Field(
        dtype=float,
        doc="Minimum allowed ratio of scienceFlux/scienceFluxErr.",
        default=-2.0,
    )

    # NOTE: adjacent string literals are concatenated verbatim, so every
    # continued ``doc`` fragment below must end with an explicit space.
    doTrailedSourceFilter = pexConfig.Field(
        doc="Run trailedSourceFilter to remove long trailed sources from the "
            "diaSource output catalog.",
        dtype=bool,
        default=True,
    )

    doWriteTrailedSources = pexConfig.Field(
        doc="Write trailed diaSources sources to a table.",
        dtype=bool,
        default=True,
        deprecated="Trailed sources will not be written out during production."
    )

    max_trail_length = pexConfig.Field(
        dtype=float,
        doc="Length of long trailed sources to remove from the input catalog, "
            "in arcseconds per second. Default comes from DMTN-199, which "
            "requires removal of sources with trails longer than 10 "
            "degrees/day, which is 36000/3600/24 arcsec/second, or roughly "
            "0.416 arcseconds per second.",
        default=36000/3600.0/24.0,
    )

    estimatedPixelScale = pexConfig.Field(
        dtype=float,
        doc="Approximate plate scale, in arcseconds/pixel. "
            "Used to convert trail length if the catalog calculation fails.",
        default=0.2,
    )
class FilterDiaSourceCatalogTask(pipeBase.PipelineTask):
    """Filter sources from a DiaSource catalog.

    Depending on the configuration, sources are rejected when they are sky
    sources, carry any flag listed in ``config.badFlagList``, have a direct
    image signal-to-noise below ``config.minAllowedDirectSnr``, or are long
    trailed sources.
    """

    ConfigClass = FilterDiaSourceCatalogConfig
    _DefaultName = "filterDiaSourceCatalog"

    @timeMethod
    def run(self, diaSourceCat, diffImVisitInfo):
        """Filter sources from the supplied DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffImVisitInfo : `lsst.afw.image.VisitInfo`
            VisitInfo for the difference image corresponding to diaSourceCat.

        Returns
        -------
        filterResults : `lsst.pipe.base.Struct`

            ``filteredDiaSourceCat`` : `lsst.afw.table.SourceCatalog`
                The catalog of filtered sources.
            ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog`
                The catalog of sources rejected by the sky-source, bad-flag
                and direct-image signal-to-noise cuts. Sources rejected only
                by the trail filter are not included here.
            ``longTrailedSources`` : `astropy.table.Table`
                DiaSources selected by the trail filter. Only present when
                both ``config.doTrailedSourceFilter`` and
                ``config.doWriteTrailedSources`` are set.
        """
        exposure_time = diffImVisitInfo.exposureTime
        rejected_mask = np.zeros(len(diaSourceCat), dtype=bool)

        # Each stage logs only the sources it newly rejects, i.e. those not
        # already caught by an earlier stage.
        if self.config.doRemoveSkySources:
            sky_source_column = diaSourceCat["sky_source"]
            self.log.info(f"Filtered {np.sum(sky_source_column & ~rejected_mask)} sky sources.")
            rejected_mask |= sky_source_column

        for flag in self.config.badFlagList:
            flag_mask = diaSourceCat[flag]
            self.log.info(f"Filtered {np.sum(flag_mask & ~rejected_mask)} sources with flag {flag}.")
            rejected_mask |= flag_mask

        if self.config.doRemoveNegativeDirectImageSources:
            direct_snr = (diaSourceCat["ip_diffim_forced_PsfFlux_instFlux"]
                          / diaSourceCat["ip_diffim_forced_PsfFlux_instFluxErr"])
            too_negative = direct_snr < self.config.minAllowedDirectSnr
            self.log.info(f"Filtered {np.sum(too_negative & ~rejected_mask)} negative direct sources.")
            rejected_mask |= too_negative

        if not self.config.doTrailedSourceFilter:
            rejectedSources = diaSourceCat[rejected_mask].copy(deep=True)
            diaSourceCat = diaSourceCat[~rejected_mask].copy(deep=True)
            return pipeBase.Struct(filteredDiaSourceCat=diaSourceCat,
                                   rejectedDiaSources=rejectedSources)

        trail_mask = self._check_dia_source_trail(diaSourceCat, exposure_time)
        longTrailedDiaSources = diaSourceCat[trail_mask].copy(deep=True)
        # Snapshot the rejected catalog BEFORE folding in the trail mask, so
        # that trailed sources are reported separately from the other
        # rejections.
        rejectedSources = diaSourceCat[rejected_mask].copy(deep=True)
        rejected_mask |= trail_mask
        diaSourceCat = diaSourceCat[~rejected_mask].copy(deep=True)

        self.log.info("%i DiaSources exceed max_trail_length %f arcseconds per second, "
                      "dropping from source catalog.",
                      len(longTrailedDiaSources), self.config.max_trail_length)
        self.metadata.add("num_filtered", len(longTrailedDiaSources))

        results = dict(filteredDiaSourceCat=diaSourceCat,
                       rejectedDiaSources=rejectedSources)
        if self.config.doWriteTrailedSources:
            results["longTrailedSources"] = longTrailedDiaSources.asAstropy()
        return pipeBase.Struct(**results)

    def _check_dia_source_trail(self, dia_sources, exposure_time):
        """Find DiaSources that have long trails or trails with indeterminate
        end points.

        Return a mask of sources whose ``ext_trailedSources_Naive_length`` is
        greater than or equal to ``config.max_trail_length`` multiplied by
        the exposure time and divided by the estimated pixel scale.
        Additionally, set the mask if
        ``ext_trailedSources_Naive_flag_off_image`` is set or if
        ``ext_trailedSources_Naive_flag_suspect_long_trail`` and
        ``ext_trailedSources_Naive_flag_edge`` are both set.

        Parameters
        ----------
        dia_sources : `lsst.afw.table.SourceCatalog`
            Input diaSources to check for trail lengths.
        exposure_time : `float`
            Exposure time from difference image.

        Returns
        -------
        trail_mask : array of `bool`
            Boolean mask for diaSources which reach the cutoff length or have
            trails which extend beyond the edge of the detector (off_image
            set). Also set for sources where both suspect_long_trail and
            edge are set.
        """
        pixelScale = self._estimate_pixel_scale(dia_sources)
        # max_trail_length is in arcsec/second; dividing by the pixel scale
        # (arcsec/pixel) expresses the cutoff in pixels.
        trail_mask = (dia_sources["ext_trailedSources_Naive_length"]
                      >= (self.config.max_trail_length*exposure_time/pixelScale))
        trail_mask |= dia_sources['ext_trailedSources_Naive_flag_off_image']
        trail_mask |= (dia_sources['ext_trailedSources_Naive_flag_suspect_long_trail']
                       & dia_sources['ext_trailedSources_Naive_flag_edge'])

        return trail_mask

    def _estimate_pixel_scale(self, catalog):
        """Quickly calculate the pixel scale from catalog values.

        Returns the fallback value ``config.estimatedPixelScale`` if the
        catalog has fewer than two sources, if any error is raised during the
        calculation, if the result is not finite, or if the result differs
        from the fallback by more than 10%.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.

        Returns
        -------
        scale : `float`
            Pixel scale of the catalog, in arcseconds/pixel.
        """
        nSrc = len(catalog)
        if nSrc < 2:
            return self.config.estimatedPixelScale
        try:
            coordKey = catalog.getCoordKey()
            decVals = catalog[coordKey.getDec()]  # in radians
            raVals = catalog[coordKey.getRa()]  # in radians
            xVals = catalog.getX()
            yVals = catalog.getY()
            # Find two points that are well separated for the calculation
            # Start with a point near one edge, and find the furthest point
            # from there
            iMin = np.argmin(xVals)
            dist = np.sqrt((xVals[iMin] - xVals)**2 + (yVals[iMin] - yVals)**2)
            iMax = np.argmax(dist)
            # Use the spherical law of cosines:
            t1 = np.sin(decVals[iMin])*np.sin(decVals[iMax])
            t2 = np.cos(decVals[iMin])*np.cos(decVals[iMax])*np.cos(raVals[iMin] - raVals[iMax])
            separation = np.arccos(t1 + t2)  # in radians
            # dist[iMax] is the maximum distance by construction.
            scale = separation/dist[iMax]*3600*180/np.pi  # convert to arcseconds/pixel
        except Exception as e:
            self.log.warning("Error encountered estimating the pixel scale from the catalog: %s", e)
            return self.config.estimatedPixelScale

        # A zero maximum distance (0/0) or an arccos argument rounded above 1
        # yields NaN/inf without raising. NaN compares False against the
        # tolerance below, so it must be rejected explicitly or it would be
        # returned and silently disable the trail-length cut.
        if not np.isfinite(scale):
            self.log.warning("Calculated pixel scale %s is not finite. Falling back on estimate.", scale)
            return self.config.estimatedPixelScale
        if abs(scale - self.config.estimatedPixelScale)/self.config.estimatedPixelScale > 0.1:
            self.log.warning(f"Calculated pixel scale of {scale} too different from estimated value "
                             f"{self.config.estimatedPixelScale}. Falling back on estimate.")
            return self.config.estimatedPixelScale
        return scale
class FilterDiaSourceReliabilityConnections(
    pipeBase.PipelineTaskConnections,
    defaultTemplates={"coaddName": "deep", "fakesType": ""},
    dimensions=("instrument", "visit", "detector"),
):
    """Connections class for FilterDiaSourceReliabilityTask.

    Two inputs (the DiaSource catalog and its reliability scores) and two
    outputs (the accepted and the rejected DiaSources), all defined per
    instrument, visit and detector.
    """

    # Inputs.
    diaSourceCat = connTypes.Input(
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Catalog of DiaSources produced during image differencing.",
    )
    reliability = connTypes.Input(
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Reliability (e.g. real/bogus) classificiation of diaSourceCat sources.",
    )

    # Outputs.
    filteredDiaSources = connTypes.Output(
        name="dia_source_high_reliability",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Accepted diaSource catalog filtered by reliability score.",
    )
    rejectedDiaSources = connTypes.Output(
        name="dia_source_low_reliability",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        doc="Rejected diaSource catalog with low reliability scores.",
    )
class FilterDiaSourceReliabilityConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=FilterDiaSourceReliabilityConnections,
):
    """Configuration for the FilterDiaSourceReliabilityTask."""

    # Threshold on the per-source reliability score; the default of 0.0 is
    # the lowest score that is kept.
    minReliability = pexConfig.Field(
        dtype=float,
        default=0.0,
        doc="Minimum reliability score to keep a source in the DiaSource catalog.",
    )
class FilterDiaSourceReliabilityTask(pipeBase.PipelineTask):
    """Filter a DiaSource catalog by reliability score.

    Sources whose reliability score is below ``config.minReliability`` are
    split off into a separate catalog of rejected sources.
    """

    ConfigClass = FilterDiaSourceReliabilityConfig
    _DefaultName = "filterDiaSourceReliability"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, run the task, and store the outputs.

        Parameters
        ----------
        butlerQC
            Quantum context used to ``get`` the inputs and ``put`` the
            outputs.
        inputRefs
            References to the input datasets; every loaded input is passed to
            `run` as a keyword argument.
        outputRefs
            References to the output datasets.
        """
        inputs = butlerQC.get(inputRefs)
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, diaSourceCat, reliability):
        """Run the task to filter DiaSources by reliability.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of DiaSources produced during image differencing. Its
            ``reliability`` column is overwritten in place with the scores
            from ``reliability``.
        reliability : catalog
            Reliability (e.g. real/bogus) classification of the sources in
            ``diaSourceCat``, read through its ``id`` and ``score`` columns.
            Must list the same ids, in the same order, as ``diaSourceCat``.

        Returns
        -------
        filteredResults : `lsst.pipe.base.Struct`

            ``filteredDiaSources`` : `lsst.afw.table.SourceCatalog`
                Catalog of DiaSources with a reliability score of at least
                ``config.minReliability``.
            ``rejectedDiaSources`` : `lsst.afw.table.SourceCatalog`
                Catalog of DiaSources that were rejected due to low
                reliability scores.

        Raises
        ------
        ValueError
            Raised if the ids in ``reliability`` do not match the ids in
            ``diaSourceCat``, including when the two catalogs differ in
            length.
        """
        # ``np.array_equal`` compares shapes as well as values, so catalogs
        # of different length are reported as a mismatch instead of being
        # broadcast against each other by ``==``.
        if not np.array_equal(diaSourceCat['id'], reliability['id']):
            # If the identifiers do not match, we cannot filter reliably.
            raise ValueError(
                "Reliability ids do not match DiaSource ids.")

        # Copy the scores in the output catalog
        diaSourceCat['reliability'] = reliability['score']

        # Filter the DiaSource catalog by reliability score
        low_reliability = diaSourceCat["reliability"] < self.config.minReliability
        rejectedDiaSources = diaSourceCat[low_reliability].copy(deep=True)
        filteredDiaSources = diaSourceCat[~low_reliability].copy(deep=True)

        self.log.info(f"Filtered {np.sum(low_reliability)} sources with low reliability.")

        return pipeBase.Struct(
            filteredDiaSources=filteredDiaSources,
            rejectedDiaSources=rejectedDiaSources
        )