Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 34%
157 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-04 03:18 -0700
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-04 03:18 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public selector actions exported by this module.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
)
36from typing import Optional, cast
38import numpy as np
39from lsst.pex.config import Field
40from lsst.pex.config.listField import ListField
42from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is kept only when every column listed in ``selectWhenFalse``
    equals 0 and every column listed in ``selectWhenTrue`` equals 1.
    Column names may contain format placeholders which are filled from
    the keyword arguments given at call time.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return (column name, Vector) pairs for every configured flag."""
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            Data holding the flag columns.
        **kwargs
            Values substituted into format placeholders in the column names.

        Returns
        -------
        result : `Vector`
            Boolean mask of the objects that satisfy all flag cuts.

        Raises
        ------
        RuntimeError
            Raised when neither ``selectWhenFalse`` nor ``selectWhenTrue``
            names any column.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Each configured column paired with the value it must equal to pass;
        # the "False" columns are checked first, then the "True" columns.
        requirements = [(name, 0) for name in self.selectWhenFalse]
        requirements += [(name, 1) for name in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for name, required in requirements:
            passes = np.array(data[name.format(**kwargs)] == required)
            if mask is None:
                mask = passes
            else:
                mask &= passes  # type: ignore
        # The guard above ensures at least one requirement, so mask is set.
        return cast(Vector, mask)
97class CoaddPlotFlagSelector(FlagSelector):
98 bands = ListField[str](
99 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
100 default=["i"],
101 )
103 def getInputSchema(self) -> KeyedDataSchema:
104 yield from super().getInputSchema()
106 def __call__(self, data: KeyedData, **kwargs) -> Vector:
107 result: Optional[Vector] = None
108 bands: tuple[str, ...]
109 match kwargs:
110 case {"band": band}:
111 bands = (band,)
112 case {"bands": bands} if not self.bands:
113 bands = bands
114 case _ if self.bands:
115 bands = tuple(self.bands)
116 case _:
117 bands = ("",)
118 for band in bands:
119 temp = super().__call__(data, **(kwargs | dict(band=band)))
120 if result is not None:
121 result &= temp # type: ignore
122 else:
123 result = temp
124 return cast(Vector, result)
126 def setDefaults(self):
127 self.selectWhenFalse = [
128 "{band}_psfFlux_flag",
129 "{band}_pixelFlags_saturatedCenter",
130 "{band}_extendedness_flag",
131 "xy_flag",
132 ]
133 self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the base-class flag mask.

        The default column names contain no ``{band}`` placeholder, so a
        single evaluation of the base-class selection is sufficient.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
163class SnSelector(VectorAction):
164 """Selects points that have S/N > threshold in the given flux type"""
166 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
167 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
168 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
169 uncertaintySuffix = Field[str](
170 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
171 )
172 bands = ListField[str](
173 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
174 default=[],
175 )
177 def getInputSchema(self) -> KeyedDataSchema:
178 yield (fluxCol := self.fluxType), Vector
179 yield f"{fluxCol}{self.uncertaintySuffix}", Vector
181 def __call__(self, data: KeyedData, **kwargs) -> Vector:
182 """Makes a mask of objects that have S/N greater than
183 self.threshold in self.fluxType
184 Parameters
185 ----------
186 df : `Tabular`
187 Returns
188 -------
189 result : `Vector`
190 A mask of the objects that satisfy the given
191 S/N cut.
192 """
193 mask: Optional[Vector] = None
194 bands: tuple[str, ...]
195 match kwargs:
196 case {"band": band}:
197 bands = (band,)
198 case {"bands": bands} if not self.bands:
199 bands = bands
200 case _ if self.bands:
201 bands = tuple(self.bands)
202 case _:
203 bands = ("",)
204 for band in bands:
205 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
206 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}"
207 vec = cast(Vector, data[fluxCol]) / data[errCol]
208 temp = (vec > self.threshold) & (vec < self.maxSN)
209 if mask is not None:
210 mask &= temp # type: ignore
211 else:
212 mask = temp
214 # It should not be possible for mask to be a None now
215 return np.array(cast(Vector, mask))
218class SkyObjectSelector(FlagSelector):
219 """Selects sky objects in the given band(s)"""
221 bands = ListField[str](
222 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
223 default=["i"],
224 )
226 def getInputSchema(self) -> KeyedDataSchema:
227 yield from super().getInputSchema()
229 def __call__(self, data: KeyedData, **kwargs) -> Vector:
230 result: Optional[Vector] = None
231 bands: tuple[str, ...]
232 match kwargs:
233 case {"band": band}:
234 bands = (band,)
235 case {"bands": bands} if not self.bands:
236 bands = bands
237 case _ if self.bands:
238 bands = tuple(self.bands)
239 case _:
240 bands = ("",)
241 for band in bands:
242 temp = super().__call__(data, **(kwargs | dict(band=band)))
243 if result is not None:
244 result &= temp # type: ignore
245 else:
246 result = temp
247 return cast(Vector, result)
249 def setDefaults(self):
250 self.selectWhenFalse = [
251 "{band}_pixelFlags_edge",
252 ]
253 self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the base-class flag mask.

        The default column names contain no ``{band}`` placeholder, so a
        single evaluation of the base-class selection is sufficient.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class ExtendednessSelector(VectorAction):
    """Fetch the extendedness vector used by extendedness-based selectors.

    The column read is ``vectorKey`` with its format placeholders (for
    example ``{band}``) filled from the call's keyword arguments.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)
class StarSelector(ExtendednessSelector):
    """Select unresolved (point-like) sources by extendedness.

    A source is selected when its extendedness is non-negative and no
    greater than ``extendedness_maximum``.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # ``<=`` honors the documented inclusive maximum. With equal
        # thresholds, every finite non-negative value is then selected by
        # exactly one of this selector and GalaxySelector (whose minimum is
        # exclusive).
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """Select resolved (extended) sources by extendedness.

    A source is selected when its extendedness is strictly greater than
    ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # Return a plain ndarray mask, matching StarSelector, regardless of
        # the container type of the input column.
        return np.array(cast(Vector, extendedness > self.extendedness_minimum))
class UnknownSelector(ExtendednessSelector):
    """Select sources whose extendedness equals the value 9.

    NOTE(review): 9 appears to be a sentinel meaning "classification
    unknown"; confirm against the producer of the extendedness column.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # Return a plain ndarray mask, matching StarSelector, regardless of
        # the container type of the input column.
        return np.array(cast(Vector, extendedness == 9))
class VectorSelector(VectorAction):
    """Return a boolean vector stored in the input data as a selection mask.

    The key looked up is ``vectorKey`` with its format placeholders filled
    from the call's keyword arguments.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        resolvedKey = self.vectorKey.format(**kwargs)
        mask = data[resolvedKey]
        return cast(Vector, mask)