Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%
229 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 14:05 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-05 14:05 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Names exported on a star-import of this module.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "ParentObjectSelector",
)
41import operator
42from typing import Optional, cast
44import numpy as np
45from lsst.pex.config import Field
46from lsst.pex.config.listField import ListField
48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
49from ...math import divide
class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is stored when no ``plotInfo`` entry was passed in
        ``kwargs`` or when ``plotLabelKey`` is empty or unset.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments forwarded from the caller; only the
            ``plotInfo`` entry is used.
        """
        if "plotInfo" not in kwargs or not self.plotLabelKey:
            return
        kwargs["plotInfo"][self.plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a ``(column, Vector)`` pair for every configured flag."""
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on; it is indexed with each flag name
            after the name has been formatted with ``kwargs``.
        **kwargs
            Values substituted into the flag name templates.

        Returns
        -------
        result : `Vector`
            A mask that is set where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if both flag lists are empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column template with the value it must equal, keeping
        # the "false" columns ahead of the "true" ones.
        requirements = [(name, 0) for name in self.selectWhenFalse]  # type: ignore
        requirements += [(name, 1) for name in self.selectWhenTrue]

        selection: Optional[Vector] = None
        for name, wanted in requirements:
            current = np.array(data[name.format(**kwargs)] == wanted)
            if selection is None:
                selection = current
            else:
                selection &= current  # type: ignore

        # The check at the top guarantees at least one column was processed.
        return cast(Vector, selection)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector whose default flag columns are templated on the band.

    The band(s) come from the ``bands`` config field when it is set,
    otherwise from the ``band`` or ``bands`` keyword argument of the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag columns to their ``_target``-suffixed names."""
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag selection in each band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            May carry ``band`` or ``bands``; these are used only when the
            configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all bands used.
        """
        # Call-time bands are honoured only when the configured list is both
        # falsy and equal to an empty list.
        configuredEmpty = not self.bands and self.bands == []
        selectedBands: tuple[str, ...]
        if configuredEmpty and "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif configuredEmpty and "bands" in kwargs:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            # No band information at all: format the templates with "".
            selectedBands = ("",)

        selection: Optional[Vector] = None
        for name in selectedBands:
            flagged = super().__call__(data, **{**kwargs, "band": name})
            if selection is None:
                selection = flagged
            else:
                selection &= flagged  # type: ignore
        return cast(Vector, selection)

    def setDefaults(self):
        """Set the default flag columns."""
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): this field is not read by any method of this class.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the ``selectWhenFalse`` columns to their ``_target``
        suffixed names.
        """
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent flag selector.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single pass and nothing to accumulate, so the parent
        # result is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default ``selectWhenFalse`` columns."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which NumPy 2.0
    # removed.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the float just above -inf, so -inf values are
    # excluded by default.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(vectorKey, Vector)`` schema entry."""
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on; indexed with ``vectorKey``.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)
235class SnSelector(SelectorBase):
236 """Selects points that have S/N > threshold in the given flux type."""
238 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
239 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
240 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
241 uncertaintySuffix = Field[str](
242 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
243 )
244 bands = ListField[str](
245 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
246 default=[],
247 )
249 def getInputSchema(self) -> KeyedDataSchema:
250 fluxCol = self.fluxType
251 fluxInd = fluxCol.find("lux") + len("lux")
252 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
253 yield fluxCol, Vector
254 yield errCol, Vector
256 def __call__(self, data: KeyedData, **kwargs) -> Vector:
257 """Makes a mask of objects that have S/N greater than
258 self.threshold in self.fluxType
260 Parameters
261 ----------
262 data : `KeyedData`
263 The data to perform the selection on.
265 Returns
266 -------
267 result : `Vector`
268 A mask of the objects that satisfy the given
269 S/N cut.
270 """
272 self._addValueToPlotInfo(self.threshold, **kwargs)
273 mask: Optional[Vector] = None
274 bands: tuple[str, ...]
275 match kwargs:
276 case {"band": band} if not self.bands and self.bands == []:
277 bands = (band,)
278 case {"bands": bands} if not self.bands and self.bands == []:
279 bands = bands
280 case _ if self.bands:
281 bands = tuple(self.bands)
282 case _:
283 bands = ("",)
284 for band in bands:
285 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
286 fluxInd = fluxCol.find("lux") + len("lux")
287 errCol = (
288 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
289 )
290 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
291 temp = (vec > self.threshold) & (vec < self.maxSN)
292 if mask is not None:
293 mask &= temp # type: ignore
294 else:
295 mask = temp
297 # It should not be possible for mask to be a None now
298 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag selection in each band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            May carry ``band`` or ``bands``; these are used only when the
            configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all bands used.
        """
        # Call-time bands are honoured only when the configured list is both
        # falsy and equal to an empty list.
        configuredEmpty = not self.bands and self.bands == []
        selectedBands: tuple[str, ...]
        if configuredEmpty and "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif configuredEmpty and "bands" in kwargs:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            # No band information at all: format the templates with "".
            selectedBands = ("",)

        selection: Optional[Vector] = None
        for name in selectedBands:
            flagged = super().__call__(data, **{**kwargs, "band": name})
            if selection is None:
                selection = flagged
            else:
                selection &= flagged  # type: ignore
        return cast(Vector, selection)

    def setDefaults(self):
        """Set the default flag columns."""
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent flag selector.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single pass and nothing to accumulate, so the parent
        # result is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent flag selector.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single pass and nothing to accumulate, so the parent
        # result is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default ``selectWhenFalse`` columns."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-element schema for the extendedness column."""
        schemaEntry = (self.vectorKey, Vector)
        return (schemaEntry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness column itself.

        Parameters
        ----------
        data : `KeyedData`
            The data to read from.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The column found under the formatted key, unmodified.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    # The comparison in __call__ is strict, so the maximum itself is excluded.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness < maximum``.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            A mask of the rows whose extendedness is non-negative and
            strictly below ``extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with extendedness above the minimum.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            A mask of the rows whose extendedness is strictly greater than
            ``extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The result of comparing the extendedness column with 9.
        """
        values = super().__call__(data, **kwargs)
        unclassified = values == 9
        return unclassified
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-element schema for the configured column."""
        schemaEntry = (self.vectorKey, Vector)
        return (schemaEntry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            The data to read from.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The column found under the formatted key, unmodified.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])
class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-element schema for the configured column."""
        schemaEntry = (self.vectorKey, Vector)
        return (schemaEntry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the configured column against the threshold.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on; indexed with ``vectorKey`` as-is.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            The result of ``operator.<op>(column, threshold)``.
        """
        # ``op`` names a function in the standard ``operator`` module.
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
469class BandSelector(VectorAction):
470 """Makes a mask for sources observed in a specified set of bands."""
472 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
473 bands = ListField[str](
474 doc="The bands to select. `None` indicates no band selection applied.",
475 default=[],
476 )
478 def getInputSchema(self) -> KeyedDataSchema:
479 return ((self.vectorKey, Vector),)
481 def __call__(self, data: KeyedData, **kwargs) -> Vector:
482 bands: Optional[tuple[str, ...]]
483 match kwargs:
484 case {"band": band} if not self.bands and self.bands == []:
485 bands = (band,)
486 case {"bands": bands} if not self.bands and self.bands == []:
487 bands = bands
488 case _ if self.bands:
489 bands = tuple(self.bands)
490 case _:
491 bands = None
492 if bands:
493 mask = np.in1d(data[self.vectorKey], bands)
494 else:
495 # No band selection is applied, i.e., select all rows
496 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
497 return cast(Vector, mask)
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        """Set the default flag columns."""
        self.selectWhenTrue = ["detect_isPatchInner"]
        # This selects all of the parents
        self.selectWhenFalse = [
            "detect_isDeblendedModelSource",
            "sky_object",
        ]