Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 29%
196 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-14 15:50 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-14 15:50 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public API of this module. Every selector class defined in this file is
# listed here, including SkySourceSelector, which was previously omitted.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)
39import operator
40from typing import Optional, cast
42import numpy as np
43from lsst.pex.config import Field
44from lsst.pex.config.listField import ListField
46from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class FlagSelector(VectorAction):
    """Base selector that turns a set of flag columns into a boolean mask.

    A row is selected when every column named in ``selectWhenFalse``
    compares equal to 0 and every column named in ``selectWhenTrue``
    compares equal to 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Column names are reported as configured (templates are not
        # formatted here).
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build the mask of rows passing all configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding one entry per flag column.
        **kwargs
            Values substituted into the configured column names with
            `str.format` before they are looked up in ``data``.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column template with the value it must equal; the
        # "false" flags are evaluated first, then the "true" flags.
        requirements = [(template, 0) for template in self.selectWhenFalse]  # type: ignore
        requirements += [(template, 1) for template in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for template, expected in requirements:
            passing = np.array(data[template.format(**kwargs)] == expected)
            if mask is None:
                mask = passing
            else:
                mask &= passing  # type: ignore
        # The guard at the top guarantees at least one requirement, so the
        # mask cannot be None here.
        return cast(Vector, mask)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector with band-templated default flags.

    With the default (empty) ``bands`` configuration the band is taken
    from the keyword arguments of the call. The flag cuts are applied once
    per band and the per-band masks are combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Re-emit the parent's schema entries unchanged.
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every requested band.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are honoured only when
            no bands are configured on the action. All keyword arguments
            are forwarded to the parent call, with ``band`` set to the band
            currently being processed.

        Returns
        -------
        result : `Vector`
            The AND of the per-band flag masks.
        """
        # True only when the configured band list is empty.
        unconfigured = not self.bands and self.bands == []

        selected: tuple[str, ...]
        if unconfigured and "band" in kwargs:
            selected = (kwargs["band"],)
        elif unconfigured and "bands" in kwargs:
            selected = kwargs["bands"]
        elif self.bands:
            selected = tuple(self.bands)
        else:
            # No band information anywhere: format templates with an
            # empty band string.
            selected = ("",)

        combined: Optional[Vector] = None
        for name in selected:
            perBand = super().__call__(data, **{**kwargs, "band": name})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is no per-band loop here, so the parent's mask is the
        # final result; the former accumulate-into-None branch could never
        # execute and has been removed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which was
    # removed in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default is the most negative finite float, so -inf values are not
    # selected unless the minimum is lowered explicitly.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection containing the column named by ``key``.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)
198class SnSelector(VectorAction):
199 """Selects points that have S/N > threshold in the given flux type"""
201 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
202 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
203 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
204 uncertaintySuffix = Field[str](
205 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
206 )
207 bands = ListField[str](
208 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
209 default=[],
210 )
212 def getInputSchema(self) -> KeyedDataSchema:
213 yield (fluxCol := self.fluxType), Vector
214 yield f"{fluxCol}{self.uncertaintySuffix}", Vector
216 def __call__(self, data: KeyedData, **kwargs) -> Vector:
217 """Makes a mask of objects that have S/N greater than
218 self.threshold in self.fluxType
219 Parameters
220 ----------
221 data : `KeyedData`
222 Returns
223 -------
224 result : `Vector`
225 A mask of the objects that satisfy the given
226 S/N cut.
227 """
228 mask: Optional[Vector] = None
229 bands: tuple[str, ...]
230 match kwargs:
231 case {"band": band} if not self.bands and self.bands == []:
232 bands = (band,)
233 case {"bands": bands} if not self.bands and self.bands == []:
234 bands = bands
235 case _ if self.bands:
236 bands = tuple(self.bands)
237 case _:
238 bands = ("",)
239 for band in bands:
240 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
241 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}"
242 vec = cast(Vector, data[fluxCol]) / data[errCol]
243 temp = (vec > self.threshold) & (vec < self.maxSN)
244 if mask is not None:
245 mask &= temp # type: ignore
246 else:
247 mask = temp
249 # It should not be possible for mask to be a None now
250 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The flag cuts are applied once per band and the per-band masks are
    combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Re-emit the parent's schema entries unchanged.
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the sky-object flag cuts in every requested band.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are honoured only when
            the configured band list is empty. All keyword arguments are
            forwarded to the parent call, with ``band`` set to the band
            currently being processed.

        Returns
        -------
        result : `Vector`
            The AND of the per-band flag masks.
        """
        # True only when the configured band list is empty.
        unconfigured = not self.bands and self.bands == []

        selected: tuple[str, ...]
        if unconfigured and "band" in kwargs:
            selected = (kwargs["band"],)
        elif unconfigured and "bands" in kwargs:
            selected = kwargs["bands"]
        elif self.bands:
            selected = tuple(self.bands)
        else:
            # No band information anywhere: format templates with an
            # empty band string.
            selected = ("",)

        combined: Optional[Vector] = None
        for name in selected:
            perBand = super().__call__(data, **{**kwargs, "band": name})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is no per-band loop here, so the parent's mask is the
        # final result; the former accumulate-into-None branch could never
        # execute and has been removed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class ExtendednessSelector(VectorAction):
    """Base action that fetches the extendedness vector from the data.

    Subclasses turn the returned vector into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # The key is reported as configured (the template is not formatted).
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness vector.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The entry of ``data`` stored under the formatted key.
        """
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """Select rows whose extendedness lies in
    ``[0, extendedness_maximum]``.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the extendedness column.
        **kwargs
            Values used to format the extendedness column name.

        Returns
        -------
        result : `Vector`
            True where ``0 <= extendedness <= extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # ``<=`` makes the upper bound inclusive, as the field
        # documentation states (the comparison was previously strict).
        unresolved = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, unresolved))
class GalaxySelector(ExtendednessSelector):
    """Select rows whose extendedness is strictly greater than
    ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as resolved.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the extendedness column.
        **kwargs
            Values used to format the extendedness column name.

        Returns
        -------
        result : `Vector`
            True where ``extendedness > extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum
        return cast(Vector, resolved)  # type: ignore
class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value equals 9.

    NOTE(review): 9 is presumably the sentinel for an undetermined
    classification; confirm against the catalog schema.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask that is True where the extendedness equals 9."""
        return super().__call__(data, **kwargs) == 9
class VectorSelector(VectorAction):
    """Return a vector stored in KeyedData so it can be used directly as
    a selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        # The key is reported as configured (the template is not formatted).
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Look up and return the configured vector.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the mask vector.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The entry of ``data`` stored under the formatted key.
        """
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])
class ThresholdSelector(VectorAction):
    """Build a mask by comparing a column against a threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured operator to the column and threshold.

        Parameters
        ----------
        data : `KeyedData`
            Keyed collection holding the column named by ``vectorKey``.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            The result of ``op(column, threshold)``.
        """
        # ``op`` names a function in the standard `operator` module; the
        # column is the left operand and the threshold the right.
        compare = getattr(operator, self.op)
        result = compare(data[self.vectorKey], self.threshold)
        return cast(Vector, result)
381class BandSelector(VectorAction):
382 """Makes a mask for sources observed in a specified set of bands."""
384 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
385 bands = ListField[str](
386 doc="The bands to select. `None` indicates no band selection applied.",
387 default=[],
388 )
390 def getInputSchema(self) -> KeyedDataSchema:
391 return ((self.vectorKey, Vector),)
393 def __call__(self, data: KeyedData, **kwargs) -> Vector:
394 bands: Optional[tuple[str, ...]]
395 match kwargs:
396 case {"band": band} if not self.bands and self.bands == []:
397 bands = (band,)
398 case {"bands": bands} if not self.bands and self.bands == []:
399 bands = bands
400 case _ if self.bands:
401 bands = tuple(self.bands)
402 case _:
403 bands = None
404 if bands:
405 mask = np.in1d(data[self.vectorKey], bands)
406 else:
407 # No band selection is applied, i.e., select all rows
408 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
409 return cast(Vector, mask)