Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 29%
208 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-28 03:17 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-28 03:17 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Names exported by ``from <this module> import *``; each entry is a
# selector class defined further down in this module.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)
41import operator
42from typing import Optional, cast
44import numpy as np
45from lsst.pex.config import Field
46from lsst.pex.config.listField import ListField
48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class FlagSelector(VectorAction):
    """Base selector that turns a set of flag columns into one boolean mask.

    A row is selected only when every column listed in ``selectWhenFalse``
    compares equal to 0 and every column listed in ``selectWhenTrue``
    compares equal to 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column name, Vector)`` pairs for every configured flag.

        Names are emitted exactly as configured (placeholders such as
        ``{band}`` are not filled in here), with the ``selectWhenFalse``
        entries first.
        """
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build the mask of rows that pass every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data from which the flag columns are read.
        **kwargs
            Values substituted into each configured column name with
            `str.format` before the column is looked up in ``data``.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy all of the flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        mask: Optional[Vector] = None
        # Each entry pairs a list of column templates with the value a row
        # must hold in those columns to be kept.
        cuts = ((self.selectWhenFalse, 0), (self.selectWhenTrue, 1))
        for templates, wanted in cuts:
            for template in templates:  # type: ignore
                passed = np.array(data[template.format(**kwargs)] == wanted)
                if mask is None:
                    mask = passed
                else:
                    mask &= passed  # type: ignore

        # The emptiness check above guarantees at least one column was used.
        return cast(Vector, mask)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector whose column names are filled in once per band.

    The band list comes from the ``bands`` config when it is non-empty.
    When the config is an empty list, a ``band`` keyword argument (one band)
    or a ``bands`` keyword argument is used instead; if neither is given a
    single empty-string band is used. The per-band masks are AND-ed.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the parent flag selection for each band and AND the masks.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data from which the flag columns are read.
        **kwargs
            Formatting values for the column names; ``band`` is overridden
            with the band currently being processed.

        Returns
        -------
        result : `Vector`
            The combined mask. If an empty ``bands`` sequence is supplied
            through ``kwargs`` no band is processed and `None` is returned.
        """
        configIsEmptyList = not self.bands and self.bands == []

        selectedBands: tuple[str, ...]
        if configIsEmptyList and "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif configIsEmptyList and "bands" in kwargs:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for currentBand in selectedBands:
            perBand = super().__call__(data, **(kwargs | dict(band=currentBand)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default per-band and band-independent flag columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent flag mask for ``data``.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only one mask to compute here, so the accumulate-with-&=
        # scaffolding used by the per-band selectors was dead code.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns that must be unset for selection."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used instead of the ``np.Inf`` alias, which NumPy 2.0
    # removed; the lowercase name exists in every NumPy release.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the float adjacent to -inf in the direction of
    # zero, so a value of -inf itself is not selected by default.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(key, Vector)`` entry this action reads."""
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the column named by ``key``.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)
200class SnSelector(VectorAction):
201 """Selects points that have S/N > threshold in the given flux type"""
203 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
204 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
205 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
206 uncertaintySuffix = Field[str](
207 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
208 )
209 bands = ListField[str](
210 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
211 default=[],
212 )
214 def getInputSchema(self) -> KeyedDataSchema:
215 yield (fluxCol := self.fluxType), Vector
216 yield f"{fluxCol}{self.uncertaintySuffix}", Vector
218 def __call__(self, data: KeyedData, **kwargs) -> Vector:
219 """Makes a mask of objects that have S/N greater than
220 self.threshold in self.fluxType
221 Parameters
222 ----------
223 data : `KeyedData`
224 Returns
225 -------
226 result : `Vector`
227 A mask of the objects that satisfy the given
228 S/N cut.
229 """
230 mask: Optional[Vector] = None
231 bands: tuple[str, ...]
232 match kwargs:
233 case {"band": band} if not self.bands and self.bands == []:
234 bands = (band,)
235 case {"bands": bands} if not self.bands and self.bands == []:
236 bands = bands
237 case _ if self.bands:
238 bands = tuple(self.bands)
239 case _:
240 bands = ("",)
241 for band in bands:
242 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
243 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}"
244 vec = cast(Vector, data[fluxCol]) / data[errCol]
245 temp = (vec > self.threshold) & (vec < self.maxSN)
246 if mask is not None:
247 mask &= temp # type: ignore
248 else:
249 mask = temp
251 # It should not be possible for mask to be a None now
252 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The flag selection of the parent class is evaluated once per band and
    the resulting masks are AND-ed together.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the parent flag selection for each band and AND the masks.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data from which the flag columns are read.
        **kwargs
            Formatting values for the column names; ``band`` is overridden
            with the band currently being processed. ``band`` / ``bands``
            choose the bands only when the ``bands`` config is an empty list.

        Returns
        -------
        result : `Vector`
            The combined mask. If an empty ``bands`` sequence is supplied
            through ``kwargs`` no band is processed and `None` is returned.
        """
        configIsEmptyList = not self.bands and self.bands == []

        selectedBands: tuple[str, ...]
        if configIsEmptyList and "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif configIsEmptyList and "bands" in kwargs:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for currentBand in selectedBands:
            perBand = super().__call__(data, **(kwargs | dict(band=currentBand)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns for picking out sky objects."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent flag mask for ``data``.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # A single mask is computed, so there is nothing to accumulate; the
        # former ``result &= temp`` branch could never run.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns for picking out sky sources."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent flag mask for ``data``.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # A single mask is computed, so there is nothing to accumulate; the
        # former ``result &= temp`` branch could never run.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel-flag columns that must be unset."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """Base action that looks up the extendedness column.

    On its own this returns the raw column named by ``vectorKey``;
    subclasses turn it into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-element schema with the (unformatted) ``vectorKey``."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` formatted with ``kwargs``.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The extendedness column as stored in ``data``.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)
class StarSelector(ExtendednessSelector):
    """Select rows with ``0 <= extendedness < extendedness_maximum``."""

    # The doc previously said "inclusive", but the comparison in __call__ is
    # a strict ``<``; the text now describes what the code does.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, exclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            Array that is true where the extendedness is at least 0 and
            strictly below ``extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """Select rows with extendedness strictly above ``extendedness_minimum``."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as resolved.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The result of comparing the extendedness column against
            ``extendedness_minimum`` with ``>``.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)
class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value is exactly 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the element-wise comparison of extendedness with 9.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            Mask that is true where the extendedness equals 9.
        """
        # NOTE(review): 9 is presumably the value used upstream to mark
        # sources without a classification - confirm against the catalog.
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown
class VectorSelector(VectorAction):
    """Return a vector stored in the KeyedData so it can be used directly
    as a selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-element schema with the (unformatted) ``vectorKey``."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Look up and return the configured vector.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the vector.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The stored vector, returned as is.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])
class ThresholdSelector(VectorAction):
    """Build a mask by comparing a column against a threshold.

    The comparison is the function called ``op`` in the standard `operator`
    module, invoked as ``op(column, threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-element schema with ``vectorKey``."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured operator to the column and the threshold.

        Parameters
        ----------
        data : `KeyedData`
            Keyed data containing the column named by ``vectorKey``; the key
            is used verbatim, without formatting.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            Whatever the operator returns for ``(column, threshold)``.
        """
        compare = getattr(operator, self.op)
        result = compare(data[self.vectorKey], self.threshold)
        return cast(Vector, result)
409class BandSelector(VectorAction):
410 """Makes a mask for sources observed in a specified set of bands."""
412 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
413 bands = ListField[str](
414 doc="The bands to select. `None` indicates no band selection applied.",
415 default=[],
416 )
418 def getInputSchema(self) -> KeyedDataSchema:
419 return ((self.vectorKey, Vector),)
421 def __call__(self, data: KeyedData, **kwargs) -> Vector:
422 bands: Optional[tuple[str, ...]]
423 match kwargs:
424 case {"band": band} if not self.bands and self.bands == []:
425 bands = (band,)
426 case {"bands": bands} if not self.bands and self.bands == []:
427 bands = bands
428 case _ if self.bands:
429 bands = tuple(self.bands)
430 case _:
431 bands = None
432 if bands:
433 mask = np.in1d(data[self.vectorKey], bands)
434 else:
435 # No band selection is applied, i.e., select all rows
436 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
437 return cast(Vector, mask)