Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 31%
233 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-20 13:17 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-20 13:17 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public selector actions exported by this module. ``SelectorBase`` is a
# shared base class and is not listed here.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "ParentObjectSelector",
)
42import operator
43from typing import Optional, cast
45import numpy as np
46from lsst.pex.config import Field
47from lsst.pex.config.listField import ListField
49from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
50from ...math import divide
class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info.

    Subclasses may call `_addValueToPlotInfo` to store a value under
    ``plotLabelKey`` in the ``plotInfo`` mapping passed as a keyword argument.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Write ``value`` into ``kwargs["plotInfo"][self.plotLabelKey]``.

        Does nothing when no ``plotInfo`` keyword is supplied or when
        ``plotLabelKey`` is empty.
        """
        if "plotInfo" not in kwargs:
            return
        if not self.plotLabelKey:
            return
        kwargs["plotInfo"][self.plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Names are reported unformatted (placeholders such as "{band}" are
        # left in place); formatting happens in __call__.
        return ((column, Vector) for column in [*self.selectWhenFalse, *self.selectWhenTrue])

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows that pass every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Data holding the flag columns.
        **kwargs
            Values used to format the flag column names (e.g. ``band``).

        Returns
        -------
        result : `Vector`
            Boolean mask that is True where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            If both ``selectWhenFalse`` and ``selectWhenTrue`` are empty.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")
        combined: Optional[Vector] = None
        # "False" flags are compared against 0 first, then "True" flags against 1.
        for wanted, flagNames in ((0, self.selectWhenFalse), (1, self.selectWhenTrue)):
            for flagName in flagNames:  # type: ignore
                passes = np.array(data[flagName.format(**kwargs)] == wanted)
                if combined is None:
                    combined = passes
                else:
                    combined &= passes  # type: ignore
        # The guard above guarantees at least one flag, so combined is set.
        return cast(Vector, combined)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots.

    Bands are chosen as follows. If the ``bands`` config is empty, a ``band``
    (or failing that ``bands``) keyword argument supplied at call time is
    used. Otherwise the configured ``bands`` are used. If neither is
    available, a single empty band string is used. The per-band masks are
    combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag columns to their ``_target``-suffixed variants."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        bands: tuple[str, ...]
        noConfiguredBands = not self.bands and self.bands == []
        if noConfiguredBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif noConfiguredBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for currentBand in bands:
            perBand = super().__call__(data, **(kwargs | dict(band=currentBand)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Flag selector applied before matching.

    By default it only requires ``detect_isPrimary`` to be True, which is
    meant to remove duplicates. No quality flags are cut on.
    """

    def setDefaults(self):
        # No "must be False" flags; only the primary-detection flag is used.
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read anywhere in this class; confirm
    # whether callers rely on it before removing.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Use the ``_target``-suffixed variants of the "must be False" flags."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        Only one set of flags is evaluated, so the parent's mask is returned
        directly; there is nothing to AND it with.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # ``np.inf`` is used because the ``np.Inf`` alias was removed in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default is the most negative finite float (the next value above -inf).
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Data holding the ``vectorKey`` column.

        Returns
        -------
        result : `Vector`
            Mask that is True where ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)
247class SnSelector(SelectorBase):
248 """Selects points that have S/N > threshold in the given flux type."""
250 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
251 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
252 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
253 uncertaintySuffix = Field[str](
254 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
255 )
256 bands = ListField[str](
257 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
258 default=[],
259 )
261 def getInputSchema(self) -> KeyedDataSchema:
262 fluxCol = self.fluxType
263 fluxInd = fluxCol.find("lux") + len("lux")
264 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
265 yield fluxCol, Vector
266 yield errCol, Vector
268 def __call__(self, data: KeyedData, **kwargs) -> Vector:
269 """Makes a mask of objects that have S/N greater than
270 self.threshold in self.fluxType
272 Parameters
273 ----------
274 data : `KeyedData`
275 The data to perform the selection on.
277 Returns
278 -------
279 result : `Vector`
280 A mask of the objects that satisfy the given
281 S/N cut.
282 """
284 self._addValueToPlotInfo(self.threshold, **kwargs)
285 mask: Optional[Vector] = None
286 bands: tuple[str, ...]
287 match kwargs:
288 case {"band": band} if not self.bands and self.bands == []:
289 bands = (band,)
290 case {"bands": bands} if not self.bands and self.bands == []:
291 bands = bands
292 case _ if self.bands:
293 bands = tuple(self.bands)
294 case _:
295 bands = ("",)
296 for band in bands:
297 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
298 fluxInd = fluxCol.find("lux") + len("lux")
299 errCol = (
300 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
301 )
302 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
303 temp = (vec > self.threshold) & (vec < self.maxSN)
304 if mask is not None:
305 mask &= temp # type: ignore
306 else:
307 mask = temp
309 # It should not be possible for mask to be a None now
310 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The configured ``bands`` (default ``["i"]``) are used unless the list is
    emptied, in which case a ``band`` or ``bands`` keyword argument is used.
    With neither, a single empty band string is used. The per-band masks are
    combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        bands: tuple[str, ...]
        useCallBands = not self.bands and self.bands == []
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        selection: Optional[Vector] = None
        for bandName in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=bandName)))
            if selection is None:
                selection = bandMask
            else:
                selection &= bandMask  # type: ignore
        return cast(Vector, selection)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    By default it requires ``sky_source`` to be True and ``pixelFlags_edge``
    to be False.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        There is only one mask to compute, so it is returned directly.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables.

    By default a source is kept only if none of the listed pixel flags are
    set.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        There is only one mask to compute, so it is returned directly.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources.

    This base class returns the raw extendedness column. Subclasses turn it
    into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # The key is reported unformatted; __call__ formats it with kwargs.
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    A row is selected when ``0 <= extendedness < extendedness_maximum``.
    """

    # The comparison below is strict (``<``), so the maximum is exclusive.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, exclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies using their extendedness values.

    A row is selected when ``extendedness > extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects using their
    extendedness values.

    A row is selected when its extendedness equals 9.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # NOTE(review): 9 appears to be a sentinel extendedness for
        # unclassified objects; confirm against the producing catalog.
        unclassifiedValue = 9
        return super().__call__(data, **kwargs) == unclassifiedValue
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.

    The key is formatted with the call keyword arguments before lookup.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        resolvedKey = self.vectorKey.format(**kwargs)
        selection = data[resolvedKey]
        return cast(Vector, selection)
class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold.

    ``op`` names a function in the standard `operator` module (for example
    ``"lt"`` or ``"ge"``). It is called as ``op(column, threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
481class BandSelector(VectorAction):
482 """Makes a mask for sources observed in a specified set of bands."""
484 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
485 bands = ListField[str](
486 doc="The bands to select. `None` indicates no band selection applied.",
487 default=[],
488 )
490 def getInputSchema(self) -> KeyedDataSchema:
491 return ((self.vectorKey, Vector),)
493 def __call__(self, data: KeyedData, **kwargs) -> Vector:
494 bands: Optional[tuple[str, ...]]
495 match kwargs:
496 case {"band": band} if not self.bands and self.bands == []:
497 bands = (band,)
498 case {"bands": bands} if not self.bands and self.bands == []:
499 bands = bands
500 case _ if self.bands:
501 bands = tuple(self.bands)
502 case _:
503 bands = None
504 if bands:
505 mask = np.in1d(data[self.vectorKey], bands)
506 else:
507 # No band selection is applied, i.e., select all rows
508 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
509 return cast(Vector, mask)
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects.

    By default a row is kept when ``detect_isPatchInner`` is True and both
    ``detect_isDeblendedModelSource`` and ``sky_object`` are False.
    """

    def setDefaults(self):
        # Rejecting deblended model sources leaves the parents; sky objects
        # are rejected as well.
        self.selectWhenFalse = ["detect_isDeblendedModelSource", "sky_object"]
        self.selectWhenTrue = ["detect_isPatchInner"]