Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 29%
196 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 10:49 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 10:49 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public names exported by ``from ... import *``.  Every selector class
# defined in this module is listed, including SkySourceSelector.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)
39import operator
40from typing import Optional, cast
42import numpy as np
43from lsst.pex.config import Field
44from lsst.pex.config.listField import ListField
46from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class FlagSelector(VectorAction):
    """Base selector that turns a set of flag columns into a boolean mask.

    A row is selected when every column named in ``selectWhenFalse``
    compares equal to 0 and every column named in ``selectWhenTrue``
    compares equal to 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column name, Vector)`` pairs for all configured flags.

        Returns
        -------
        schema : `KeyedDataSchema`
            One entry per name in ``selectWhenFalse`` followed by one entry
            per name in ``selectWhenTrue``.  The names are reported as
            configured, without any ``str.format`` substitution.
        """
        columnNames = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columnNames)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build the mask of rows that pass all configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are looked up.
        **kwargs
            Values substituted into each configured column name with
            ``str.format`` before the lookup.

        Returns
        -------
        result : `Vector`
            The logical AND of the per-column comparisons.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair every column template with the value it must equal; the
        # "false" flags are processed before the "true" flags.
        requirements = [(template, 0) for template in self.selectWhenFalse]  # type: ignore
        requirements += [(template, 1) for template in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for template, expected in requirements:
            passed = np.array(data[template.format(**kwargs)] == expected)
            if mask is None:
                mask = passed
            else:
                mask &= passed  # type: ignore

        # The guard at the top guarantees at least one column was processed,
        # so the mask cannot be None here.
        return cast(Vector, mask)
class CoaddPlotFlagSelector(FlagSelector):
    """Apply the `FlagSelector` cuts once per band and AND the results.

    The bands are chosen in this order: a ``band`` keyword argument, then a
    ``bands`` keyword argument (only when the ``bands`` config field is
    empty), then the ``bands`` config field, and finally a single empty
    band name.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent `FlagSelector` unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent selector's mask over the chosen bands.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are looked up.
        **kwargs
            Forwarded to the parent selector with ``band`` set to the band
            currently being processed.

        Returns
        -------
        result : `Vector`
            The combined mask.
        """
        selectedBands: tuple[str, ...]
        if "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for bandName in selectedBands:
            perBand = super().__call__(data, **{**kwargs, "band": bandName})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns for this selector."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent `FlagSelector` unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent `FlagSelector`.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are looked up.
        **kwargs
            Forwarded unchanged to the parent selector.

        Returns
        -------
        result : `Vector`
            The mask of rows that satisfy the configured flag cuts.
        """
        # There is no per-band loop here, so the parent's mask is the final
        # answer; no accumulator is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns for this selector."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    column = Field[str](doc="Column to select from")
    # ``np.inf`` is used instead of the ``np.Inf`` alias, which was removed in
    # NumPy 2.0 and would make this class body fail at import time.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the first representable float above -inf.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(column, Vector)`` entry this action reads."""
        yield self.column, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which ``column`` is looked up.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.column])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)
class SnSelector(VectorAction):
    """Selects points that have S/N > threshold (and < maxSN) in the given
    flux type.
    """

    fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
    threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
    maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
    uncertaintySuffix = Field[str](
        doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
    )
    bands = ListField[str](
        doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the flux column and its uncertainty column.

        The names are reported as configured, without any ``str.format``
        substitution.
        """
        fluxCol = self.fluxType
        yield fluxCol, Vector
        yield f"{fluxCol}{self.uncertaintySuffix}", Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Makes a mask of objects that have S/N greater than
        self.threshold and less than self.maxSN in self.fluxType.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flux and uncertainty columns are looked
            up.
        **kwargs
            Values substituted into the column names.  A ``band`` entry
            selects that single band; a ``bands`` entry is used only when
            the ``bands`` config field is empty.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given S/N cut in every
            selected band.
        """
        bands: tuple[str, ...]
        if "band" in kwargs:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        mask: Optional[Vector] = None
        for band in bands:
            # The flux name and the uncertainty suffix are formatted with the
            # same mapping, so both can refer to the band being processed.
            formatArgs = kwargs | dict(band=band)
            fluxCol = self.fluxType.format(**formatArgs)
            errCol = f"{fluxCol}{self.uncertaintySuffix.format(**formatArgs)}"
            vec = cast(Vector, data[fluxCol]) / data[errCol]
            temp = (vec > self.threshold) & (vec < self.maxSN)
            if mask is not None:
                mask &= temp  # type: ignore
            else:
                mask = temp

        # It should not be possible for mask to be a None now
        return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)"""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent `FlagSelector` unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent selector's mask over the chosen bands.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are looked up.
        **kwargs
            Forwarded to the parent selector with ``band`` set to the band
            currently being processed.  A ``band`` entry selects that single
            band; a ``bands`` entry is used only when the ``bands`` config
            field is empty.

        Returns
        -------
        result : `Vector`
            The combined mask.
        """
        chosen: tuple[str, ...]
        if "band" in kwargs:
            chosen = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            chosen = kwargs["bands"]
        elif self.bands:
            chosen = tuple(self.bands)
        else:
            chosen = ("",)

        accumulated: Optional[Vector] = None
        for current in chosen:
            bandMask = super().__call__(data, **{**kwargs, "band": current})
            if accumulated is None:
                accumulated = bandMask
            else:
                accumulated &= bandMask  # type: ignore
        return cast(Vector, accumulated)

    def setDefaults(self):
        """Set the default flag columns for this selector."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent `FlagSelector` unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent `FlagSelector`.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are looked up.
        **kwargs
            Forwarded unchanged to the parent selector.

        Returns
        -------
        result : `Vector`
            The mask of rows that satisfy the configured flag cuts.
        """
        # There is no per-band loop here, so the parent's mask is the final
        # answer; no accumulator is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns for this selector."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class ExtendednessSelector(VectorAction):
    """Look up the extendedness vector named by ``vectorKey``.

    Subclasses compare the returned vector against their own limits to
    build a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads.

        The key is reported as configured, without ``str.format``
        substitution.
        """
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the vector is looked up.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            The looked-up vector, unmodified.
        """
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """Select rows whose extendedness lies in ``[0, extendedness_maximum]``."""

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness vector is looked up.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            True where ``0 <= extendedness <= extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as the field documentation states.
        # NOTE(review): this was a strict ``<`` before; confirm that
        # inclusive is the intended behaviour.
        isUnresolved = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isUnresolved))
class GalaxySelector(ExtendednessSelector):
    """Select rows whose extendedness is strictly above
    ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as resolved.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness vector is looked up.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            True where ``extendedness > extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum
        return cast(Vector, isResolved)  # type: ignore
class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value equals 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask that is True where the extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness vector is looked up.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            The element-wise comparison of the extendedness vector with 9.
        """
        # NOTE(review): 9 is presumably the sentinel for an undetermined
        # classification -- confirm against the catalog schema.
        return super().__call__(data, **kwargs) == 9
class VectorSelector(VectorAction):
    """Return a vector stored in KeyedData so that it can be used directly
    as a selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads.

        The key is reported as configured, without ``str.format``
        substitution.
        """
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the vector is looked up.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            The looked-up vector, unmodified.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        selection = data[resolvedKey]
        return cast(Vector, selection)
class ThresholdSelector(VectorAction):
    """Build a mask by comparing a column against a threshold.

    The comparison is the function named ``op`` in the standard-library
    `operator` module, called as ``op(column, threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured operator to the column and the threshold.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which ``vectorKey`` is looked up.  The key is used
            as configured; ``kwargs`` are not substituted into it.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            The result of ``op(data[vectorKey], threshold)``.
        """
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
378class BandSelector(VectorAction):
379 """Makes a mask for sources observed in a specified set of bands."""
381 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
382 bands = ListField[str](
383 doc="The bands to select. `None` indicates no band selection applied.",
384 default=[],
385 )
387 def getInputSchema(self) -> KeyedDataSchema:
388 return ((self.vectorKey, Vector),)
390 def __call__(self, data: KeyedData, **kwargs) -> Vector:
391 bands: Optional[tuple[str, ...]]
392 match kwargs:
393 case {"band": band}:
394 bands = (band,)
395 case {"bands": bands} if not self.bands:
396 bands = bands
397 case _ if self.bands:
398 bands = tuple(self.bands)
399 case _:
400 bands = None
401 if bands:
402 mask = np.in1d(data[self.vectorKey], bands)
403 else:
404 # No band selection is applied, i.e., select all rows
405 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
406 return cast(Vector, mask)