Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 35%
186 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-06 02:06 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public selector actions exported by this module.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)
38import operator
39from typing import Optional, cast
41import numpy as np
42from lsst.pex.config import Field
43from lsst.pex.config.listField import ListField
45from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class FlagSelector(VectorAction):
    """Build a boolean selection mask from a set of flag columns.

    A row passes only if every column listed in ``selectWhenFalse`` equals
    0 and every column listed in ``selectWhenTrue`` equals 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Names may still hold ``str.format`` placeholders (e.g. ``{band}``);
        # they are only resolved when the action is called.
        columnNames = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columnNames)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Combine all configured flag cuts into a single mask.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from column name to column values.
        **kwargs
            Values used to fill ``str.format`` placeholders in the
            configured column names.

        Returns
        -------
        result : `Vector`
            Boolean mask that is True where every flag cut is satisfied.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue``
            names any column.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair every column with the value it must equal; the "when False"
        # cuts are evaluated before the "when True" cuts.
        requirements = [(name, 0) for name in self.selectWhenFalse]
        requirements += [(name, 1) for name in self.selectWhenTrue]

        combined: Optional[Vector] = None
        for name, required in requirements:
            passing = np.array(data[name.format(**kwargs)] == required)
            if combined is None:
                combined = passing
            else:
                combined &= passing  # type: ignore
        # The guard above ensures at least one column was processed.
        return cast(Vector, combined)
100class CoaddPlotFlagSelector(FlagSelector):
101 bands = ListField[str](
102 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
103 default=["i"],
104 )
106 def getInputSchema(self) -> KeyedDataSchema:
107 yield from super().getInputSchema()
109 def __call__(self, data: KeyedData, **kwargs) -> Vector:
110 result: Optional[Vector] = None
111 bands: tuple[str, ...]
112 match kwargs:
113 case {"band": band}:
114 bands = (band,)
115 case {"bands": bands} if not self.bands:
116 bands = bands
117 case _ if self.bands:
118 bands = tuple(self.bands)
119 case _:
120 bands = ("",)
121 for band in bands:
122 temp = super().__call__(data, **(kwargs | dict(band=band)))
123 if result is not None:
124 result &= temp # type: ignore
125 else:
126 result = temp
127 return cast(Vector, result)
129 def setDefaults(self):
130 self.selectWhenFalse = [
131 "{band}_psfFlux_flag",
132 "{band}_pixelFlags_saturatedCenter",
133 "{band}_extendedness_flag",
134 "xy_flag",
135 ]
136 self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured visit-level flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from column name to column values.
        **kwargs
            Values used to format the configured column names.

        Returns
        -------
        result : `Vector`
            Boolean mask that is True where every flag cut is satisfied.
        """
        # The default column names carry no ``{band}`` placeholder, so a
        # single pass of the base-class selection is sufficient.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # pex_config requires overrides to chain to the base implementation.
        super().setDefaults()
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
166class SnSelector(VectorAction):
167 """Selects points that have S/N > threshold in the given flux type"""
169 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
170 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
171 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
172 uncertaintySuffix = Field[str](
173 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
174 )
175 bands = ListField[str](
176 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
177 default=[],
178 )
180 def getInputSchema(self) -> KeyedDataSchema:
181 yield (fluxCol := self.fluxType), Vector
182 yield f"{fluxCol}{self.uncertaintySuffix}", Vector
184 def __call__(self, data: KeyedData, **kwargs) -> Vector:
185 """Makes a mask of objects that have S/N greater than
186 self.threshold in self.fluxType
187 Parameters
188 ----------
189 df : `Tabular`
190 Returns
191 -------
192 result : `Vector`
193 A mask of the objects that satisfy the given
194 S/N cut.
195 """
196 mask: Optional[Vector] = None
197 bands: tuple[str, ...]
198 match kwargs:
199 case {"band": band}:
200 bands = (band,)
201 case {"bands": bands} if not self.bands:
202 bands = bands
203 case _ if self.bands:
204 bands = tuple(self.bands)
205 case _:
206 bands = ("",)
207 for band in bands:
208 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
209 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}"
210 vec = cast(Vector, data[fluxCol]) / data[errCol]
211 temp = (vec > self.threshold) & (vec < self.maxSN)
212 if mask is not None:
213 mask &= temp # type: ignore
214 else:
215 mask = temp
217 # It should not be possible for mask to be a None now
218 return np.array(cast(Vector, mask))
221class SkyObjectSelector(FlagSelector):
222 """Selects sky objects in the given band(s)"""
224 bands = ListField[str](
225 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
226 default=["i"],
227 )
229 def getInputSchema(self) -> KeyedDataSchema:
230 yield from super().getInputSchema()
232 def __call__(self, data: KeyedData, **kwargs) -> Vector:
233 result: Optional[Vector] = None
234 bands: tuple[str, ...]
235 match kwargs:
236 case {"band": band}:
237 bands = (band,)
238 case {"bands": bands} if not self.bands:
239 bands = bands
240 case _ if self.bands:
241 bands = tuple(self.bands)
242 case _:
243 bands = ("",)
244 for band in bands:
245 temp = super().__call__(data, **(kwargs | dict(band=band)))
246 if result is not None:
247 result &= temp # type: ignore
248 else:
249 result = temp
250 return cast(Vector, result)
252 def setDefaults(self):
253 self.selectWhenFalse = [
254 "{band}_pixelFlags_edge",
255 ]
256 self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    By default this keeps rows where ``sky_source`` is set and
    ``pixelFlags_edge`` is not.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured sky-source flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from column name to column values.
        **kwargs
            Values used to format the configured column names.

        Returns
        -------
        result : `Vector`
            Boolean mask that is True where every flag cut is satisfied.
        """
        # A single pass of the base-class selection; no per-band looping.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # pex_config requires overrides to chain to the base implementation.
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class ExtendednessSelector(VectorAction):
    """Return the extendedness column, as a base for extendedness cuts.

    The column read is ``vectorKey`` with any ``str.format`` placeholders
    (such as ``{band}``) filled from the call's keyword arguments.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)
class StarSelector(ExtendednessSelector):
    """Select unresolved objects by their extendedness value.

    An object passes when ``0 <= extendedness <= extendedness_maximum``.
    """

    # ``Field[float]`` already fixes the dtype; the redundant
    # ``dtype=float`` keyword has been dropped.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on the config field.
        # This also makes the cut complementary to GalaxySelector's
        # exclusive minimum when both thresholds are equal.
        isStar = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isStar))
class GalaxySelector(ExtendednessSelector):
    """Select resolved objects by their extendedness value.

    An object passes when ``extendedness > extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # Return a plain numpy array, matching StarSelector.
        return np.array(cast(Vector, extendedness > self.extendedness_minimum))
class UnknownSelector(ExtendednessSelector):
    """Select objects whose extendedness value equals 9.

    NOTE(review): 9 is presumably a sentinel for objects without a
    star/galaxy classification; confirm against the producing pipeline.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # Return a plain numpy array, matching StarSelector.
        return np.array(cast(Vector, extendedness == 9))
class VectorSelector(VectorAction):
    """Use a boolean column of the input data directly as a selection mask.

    ``vectorKey`` may contain ``str.format`` placeholders, which are filled
    from the call's keyword arguments before the lookup.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        selection = data[key]
        return cast(Vector, selection)
class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold.

    The mask is ``getattr(operator, op)(data[vectorKey], threshold)``; for
    example ``op="lt"`` keeps rows whose value is below ``threshold``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the selected column against ``threshold``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from column name to column values.
        **kwargs
            Values used to fill ``str.format`` placeholders in
            ``vectorKey``.

        Returns
        -------
        result : `Vector`
            Result of applying the ``operator`` function named by ``op``.

        Raises
        ------
        AttributeError
            Raised if ``op`` is not a name in the `operator` module.
        """
        # Resolve placeholders such as ``{band}``, as the other selectors do.
        key = self.vectorKey.format(**kwargs)
        comparison = getattr(operator, self.op)
        mask = comparison(data[key], self.threshold)
        return cast(Vector, mask)
349class BandSelector(VectorAction):
350 """Makes a mask for sources observed in a specified set of bands."""
352 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
353 bands = ListField[str](
354 doc="The bands to select. `None` indicates no band selection applied.",
355 default=[],
356 )
358 def getInputSchema(self) -> KeyedDataSchema:
359 return ((self.vectorKey, Vector),)
361 def __call__(self, data: KeyedData, **kwargs) -> Vector:
362 bands: Optional[tuple[str, ...]]
363 match kwargs:
364 case {"band": band}:
365 bands = (band,)
366 case {"bands": bands} if not self.bands:
367 bands = bands
368 case _ if self.bands:
369 bands = tuple(self.bands)
370 case _:
371 bands = None
372 if bands:
373 mask = np.in1d(data[self.vectorKey], bands)
374 else:
375 # No band selection is applied, i.e., select all rows
376 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
377 return cast(Vector, mask)