Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%
214 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-28 05:16 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-28 05:16 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Names exported by ``from <module> import *``.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)
41import operator
42from typing import Optional, cast
44import numpy as np
45from lsst.pex.config import Field
46from lsst.pex.config.listField import ListField
48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is written when ``plotInfo`` is missing from ``kwargs`` or
        when ``plotLabelKey`` is empty or unset.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments forwarded from the caller; only ``plotInfo``
            is inspected.
        """
        if "plotInfo" not in kwargs or not self.plotLabelKey:
            return
        kwargs["plotInfo"][self.plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA"""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a ``(column name, Vector)`` pair for every configured flag.

        The names are returned exactly as configured, i.e. before any
        ``str.format`` substitution is applied.
        """
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask of the rows that satisfy every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Values substituted into the configured column names with
            ``str.format`` before they are looked up in ``data``.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.

        Notes
        -----
        Columns in ``selectWhenFalse`` must compare equal to 0 and columns
        in ``selectWhenTrue`` must compare equal to 1 for a row to be kept.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair every flag column with the value it has to equal to pass;
        # the "False" flags are evaluated before the "True" flags.
        requirements = [(flag, 0) for flag in self.selectWhenFalse]
        requirements += [(flag, 1) for flag in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for flag, expected in requirements:
            passed = np.array(data[flag.format(**kwargs)] == expected)
            if mask is None:
                mask = passed
            else:
                mask &= passed  # type: ignore

        # The emptiness check above guarantees at least one iteration, so
        # the mask cannot still be None here.
        return cast(Vector, mask)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector with coadd-level defaults, applied once per band.

    With the default (empty) ``bands`` configuration the band is taken
    from the keyword arguments passed to the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries produced by the parent class."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every selected band and AND the masks.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            May contain ``band`` or ``bands``; these are only honoured when
            the configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all selected bands.
        """
        # kwargs are consulted only when the configured list is an empty
        # list; a configured, non-empty list always wins.
        fallBackToKwargs = not self.bands and self.bands == []

        selectedBands: tuple[str, ...]
        if fallBackToKwargs and "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif fallBackToKwargs and "bands" in kwargs:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            # No band information at all: format the column names with an
            # empty band string.
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for band in selectedBands:
            perBand = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns used for coadd plots."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries produced by the parent class."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the given data.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only one mask to compute here, so no accumulation over
        # bands is needed; delegate directly to the parent.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns used for visit-level plots."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which was removed
    # in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default minimum is the float immediately above negative infinity.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(key, Vector)`` entry this action reads."""
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the ``key`` column is read.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)
210class SnSelector(SelectorBase):
211 """Selects points that have S/N > threshold in the given flux type"""
213 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
214 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
215 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
216 uncertaintySuffix = Field[str](
217 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
218 )
219 bands = ListField[str](
220 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
221 default=[],
222 )
224 def getInputSchema(self) -> KeyedDataSchema:
225 yield (fluxCol := self.fluxType), Vector
226 yield f"{fluxCol}{self.uncertaintySuffix}", Vector
228 def __call__(self, data: KeyedData, **kwargs) -> Vector:
229 """Makes a mask of objects that have S/N greater than
230 self.threshold in self.fluxType
231 Parameters
232 ----------
233 data : `KeyedData`
234 Returns
235 -------
236 result : `Vector`
237 A mask of the objects that satisfy the given
238 S/N cut.
239 """
240 self._addValueToPlotInfo(self.threshold, **kwargs)
241 mask: Optional[Vector] = None
242 bands: tuple[str, ...]
243 match kwargs:
244 case {"band": band} if not self.bands and self.bands == []:
245 bands = (band,)
246 case {"bands": bands} if not self.bands and self.bands == []:
247 bands = bands
248 case _ if self.bands:
249 bands = tuple(self.bands)
250 case _:
251 bands = ("",)
252 for band in bands:
253 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
254 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}"
255 vec = cast(Vector, data[fluxCol]) / cast(Vector, data[errCol])
256 temp = (vec > self.threshold) & (vec < self.maxSN)
257 if mask is not None:
258 mask &= temp # type: ignore
259 else:
260 mask = temp
262 # It should not be possible for mask to be a None now
263 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)"""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries produced by the parent class."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every selected band and AND the masks.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            May contain ``band`` or ``bands``; these are only honoured when
            the configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all selected bands.
        """
        # kwargs are consulted only when the configured list is an empty
        # list; a configured, non-empty list always wins.
        fallBackToKwargs = not self.bands and self.bands == []

        selectedBands: tuple[str, ...]
        if fallBackToKwargs and "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif fallBackToKwargs and "bands" in kwargs:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            # No band information at all: format the column names with an
            # empty band string.
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for band in selectedBands:
            perBand = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns used to pick out sky objects."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries produced by the parent class."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the given data.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only one mask to compute here, so no accumulation over
        # bands is needed; delegate directly to the parent.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns used to pick out sky sources."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries produced by the parent class."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the given data.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only one mask to compute here, so no accumulation over
        # bands is needed; delegate directly to the parent.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel-flag columns that must all be unset."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """Load the extendedness column; subclasses turn it into a mask."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry, unformatted."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` after formatting it
        with ``kwargs``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness column is read.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            The extendedness values exactly as stored in ``data``.
        """
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """Select rows with ``0 <= extendedness < extendedness_maximum``."""

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows classified as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness column is read.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            Mask that is True where the extendedness is at least 0 and
            strictly below ``extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is strict: a value equal to the maximum is not
        # selected.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """Select rows with ``extendedness > extendedness_minimum``."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows classified as resolved.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness column is read.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            Mask that is True where the extendedness is strictly above
            ``extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)
class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness equals 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the extendedness column is read.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            Mask that is True where the extendedness equals 9.
        """
        # NOTE(review): 9 is presumably the value the upstream catalogs use
        # for sources with no star/galaxy classification - confirm.
        unknownValue = 9
        values = super().__call__(data, **kwargs)
        return values == unknownValue
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry, unformatted."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` after formatting it
        with ``kwargs``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the mask column is read.
        **kwargs
            Values substituted into ``vectorKey`` with ``str.format``.

        Returns
        -------
        result : `Vector`
            The stored vector, returned unchanged.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])
class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the configured column against ``threshold``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the ``vectorKey`` column is read.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            The result of ``operator.<op>(column, threshold)``.
        """
        # ``op`` names a function in the standard ``operator`` module,
        # e.g. "gt" or "le"; it is looked up before the column is read.
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
420class BandSelector(VectorAction):
421 """Makes a mask for sources observed in a specified set of bands."""
423 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
424 bands = ListField[str](
425 doc="The bands to select. `None` indicates no band selection applied.",
426 default=[],
427 )
429 def getInputSchema(self) -> KeyedDataSchema:
430 return ((self.vectorKey, Vector),)
432 def __call__(self, data: KeyedData, **kwargs) -> Vector:
433 bands: Optional[tuple[str, ...]]
434 match kwargs:
435 case {"band": band} if not self.bands and self.bands == []:
436 bands = (band,)
437 case {"bands": bands} if not self.bands and self.bands == []:
438 bands = bands
439 case _ if self.bands:
440 bands = tuple(self.bands)
441 case _:
442 bands = None
443 if bands:
444 mask = np.in1d(data[self.vectorKey], bands)
445 else:
446 # No band selection is applied, i.e., select all rows
447 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
448 return cast(Vector, mask)