Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%
214 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 03:18 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 03:18 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public names exported by this module (what ``from ... import *`` exposes).
# ``SelectorBase`` is intentionally not listed here.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)
41import operator
42from typing import Optional, cast
44import numpy as np
45from lsst.pex.config import Field
46from lsst.pex.config.listField import ListField
48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is written when ``plotLabelKey`` is empty or when no
        ``plotInfo`` keyword argument was supplied.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments of the calling action; only ``plotInfo`` is
            consulted, and it is modified in place.
        """
        if not self.plotLabelKey or "plotInfo" not in kwargs:
            return
        kwargs["plotInfo"][self.plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # The column list is built eagerly; the (name, type) pairs are
        # produced lazily from it.
        return ((name, Vector) for name in [*self.selectWhenFalse, *self.selectWhenTrue])

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Used to format the configured column names (``str.format``).

        Returns
        -------
        result : `Vector`
            A mask that is the logical AND of every ``selectWhenFalse``
            column compared to 0 and every ``selectWhenTrue`` column
            compared to 1.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column template with the value it must equal, keeping
        # the "when False" columns ahead of the "when True" ones.
        requirements = [(template, 0) for template in self.selectWhenFalse]  # type: ignore
        requirements += [(template, 1) for template in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for template, expected in requirements:
            current = np.array(data[template.format(**kwargs)] == expected)
            if mask is None:
                mask = current
            else:
                mask &= current  # type: ignore
        # The emptiness check above guarantees at least one iteration, so
        # the mask cannot be None here.
        return cast(Vector, mask)
class CoaddPlotFlagSelector(FlagSelector):
    """Apply the configured flag cuts once per band and AND the results.

    The bands come from the ``bands`` config field when it is non-empty,
    otherwise from the ``band`` or ``bands`` keyword passed to the call,
    otherwise a single empty band string is used.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent's flag mask over every band.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            May contain ``band`` or ``bands``; everything is forwarded to
            the parent call with ``band`` overridden per iteration.

        Returns
        -------
        result : `Vector`
            The combined mask.
        """
        noConfiguredBands = not self.bands and self.bands == []
        selected: tuple[str, ...]
        if noConfiguredBands and "band" in kwargs:
            selected = (kwargs["band"],)
        elif noConfiguredBands and "bands" in kwargs:
            selected = kwargs["bands"]
        elif self.bands:
            selected = tuple(self.bands)
        else:
            selected = ("",)

        combined: Optional[Vector] = None
        for name in selected:
            perBand = super().__call__(data, **(kwargs | dict(band=name)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the flag cuts.
        """
        # There is only a single mask to compute, so the accumulate-and-AND
        # scaffolding (whose AND branch could never run) is not needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used instead of the ``np.Inf`` alias, which was removed
    # in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the finite float closest to -inf.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on; ``data[self.key]`` is compared.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)
215class SnSelector(SelectorBase):
216 """Selects points that have S/N > threshold in the given flux type."""
218 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
219 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
220 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
221 uncertaintySuffix = Field[str](
222 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
223 )
224 bands = ListField[str](
225 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
226 default=[],
227 )
229 def getInputSchema(self) -> KeyedDataSchema:
230 yield (fluxCol := self.fluxType), Vector
231 yield f"{fluxCol}{self.uncertaintySuffix}", Vector
233 def __call__(self, data: KeyedData, **kwargs) -> Vector:
234 """Makes a mask of objects that have S/N greater than
235 self.threshold in self.fluxType
237 Parameters
238 ----------
239 data : `KeyedData`
240 The data to perform the selection on.
242 Returns
243 -------
244 result : `Vector`
245 A mask of the objects that satisfy the given
246 S/N cut.
247 """
249 self._addValueToPlotInfo(self.threshold, **kwargs)
250 mask: Optional[Vector] = None
251 bands: tuple[str, ...]
252 match kwargs:
253 case {"band": band} if not self.bands and self.bands == []:
254 bands = (band,)
255 case {"bands": bands} if not self.bands and self.bands == []:
256 bands = bands
257 case _ if self.bands:
258 bands = tuple(self.bands)
259 case _:
260 bands = ("",)
261 for band in bands:
262 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
263 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}"
264 vec = cast(Vector, data[fluxCol]) / cast(Vector, data[errCol])
265 temp = (vec > self.threshold) & (vec < self.maxSN)
266 if mask is not None:
267 mask &= temp # type: ignore
268 else:
269 mask = temp
271 # It should not be possible for mask to be a None now
272 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent's flag mask over every band.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            May contain ``band`` or ``bands``, which are only consulted
            when the ``bands`` config field is empty; everything is
            forwarded to the parent call with ``band`` overridden.

        Returns
        -------
        result : `Vector`
            The combined mask.
        """
        configIsEmpty = not self.bands and self.bands == []
        bandNames: tuple[str, ...]
        if configIsEmpty and "band" in kwargs:
            bandNames = (kwargs["band"],)
        elif configIsEmpty and "bands" in kwargs:
            bandNames = kwargs["bands"]
        elif self.bands:
            bandNames = tuple(self.bands)
        else:
            bandNames = ("",)

        selection: Optional[Vector] = None
        for bandName in bandNames:
            bandMask = super().__call__(data, **(kwargs | dict(band=bandName)))
            if selection is None:
                selection = bandMask
            else:
                selection &= bandMask  # type: ignore
        return cast(Vector, selection)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the flag cuts.
        """
        # Only one mask is ever computed, so it is returned directly; the
        # previous AND-accumulation branch was unreachable.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the flag cuts.
        """
        # Only one mask is ever computed, so it is returned directly; the
        # previous AND-accumulation branch was unreachable.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness column itself.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the column from.
        **kwargs
            Used to format ``vectorKey`` (``str.format``).

        Returns
        -------
        result : `Vector`
            The values stored under the formatted ``vectorKey``; subclasses
            turn these into a mask.
        """
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            A mask of rows with ``0 <= extendedness <= extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on the config field.
        # GalaxySelector uses a strict ``>`` on its minimum, so with equal
        # thresholds a boundary value lands in exactly one of the two.
        isStar = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isStar))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows classified as resolved.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            A mask of rows with ``extendedness > extendedness_minimum``.
        """
        isResolved = super().__call__(data, **kwargs) > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            The element-wise comparison of the extendedness column with 9.
        """
        # NOTE(review): 9 is presumably the sentinel for "unclassified" in
        # the input catalogs; confirm against the table schema.
        return super().__call__(data, **kwargs) == 9
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the vector from.
        **kwargs
            Used to format ``vectorKey`` (``str.format``).

        Returns
        -------
        result : `Vector`
            The stored vector, returned as is.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])
class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare a column against the threshold with the named operator.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the column from; ``vectorKey`` is used
            verbatim (it is not formatted with ``kwargs``).
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            The result of ``operator.<op>(column, threshold)``.
        """
        # ``op`` names a function in the standard ``operator`` module.
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
443class BandSelector(VectorAction):
444 """Makes a mask for sources observed in a specified set of bands."""
446 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
447 bands = ListField[str](
448 doc="The bands to select. `None` indicates no band selection applied.",
449 default=[],
450 )
452 def getInputSchema(self) -> KeyedDataSchema:
453 return ((self.vectorKey, Vector),)
455 def __call__(self, data: KeyedData, **kwargs) -> Vector:
456 bands: Optional[tuple[str, ...]]
457 match kwargs:
458 case {"band": band} if not self.bands and self.bands == []:
459 bands = (band,)
460 case {"bands": bands} if not self.bands and self.bands == []:
461 bands = bands
462 case _ if self.bands:
463 bands = tuple(self.bands)
464 case _:
465 bands = None
466 if bands:
467 mask = np.in1d(data[self.vectorKey], bands)
468 else:
469 # No band selection is applied, i.e., select all rows
470 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
471 return cast(Vector, mask)