Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 27%
283 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-06 12:42 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public API of this module. Every public selector class defined below is
# listed here so that star-imports and API documentation pick it up.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "ParentObjectSelector",
)
43import operator
44from typing import Optional, cast
46import numpy as np
47from lsst.pex.config import Field
48from lsst.pex.config.listField import ListField
50from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
51from ...math import divide, fluxToMag
class SelectorBase(VectorAction):
    """Base class for selectors that can record a label in ``plotInfo``."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` if that entry exists.

        Parameters
        ----------
        value
            The value to store.
        plotLabelKey : `str`, optional
            Key to store the value under. When given (not `None`) it wins
            over the ``plotLabelKey`` config field.
        **kwargs
            Call keyword arguments; only ``plotInfo`` is consulted.

        Raises
        ------
        RuntimeError
            Raised when ``plotInfo`` is present but no key is available from
            either the argument or the config field.
        """
        # Nothing to do when the caller did not pass a plotInfo mapping.
        if "plotInfo" not in kwargs:
            return
        if plotLabelKey is not None:
            key = plotLabelKey
        elif self.plotLabelKey:
            key = self.plotLabelKey
        else:
            raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
        kwargs["plotInfo"][key] = value
class FlagSelector(VectorAction):
    """Select rows by requiring some flag columns to be 0 and others to be 1.

    Column names may contain ``str.format`` placeholders, which are filled
    from the call keyword arguments.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Combine all configured flag requirements into one mask.

        Parameters
        ----------
        data : `KeyedData`
            Table holding the flag columns.
        **kwargs
            Values used to format the column names.

        Returns
        -------
        result : `Vector`
            Boolean mask that is True where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if both column lists are empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column with the value it must equal; False columns first.
        requirements = [(name, 0) for name in self.selectWhenFalse]
        requirements += [(name, 1) for name in self.selectWhenTrue]

        combined: Optional[Vector] = None
        for name, wanted in requirements:
            passes = np.array(data[name.format(**kwargs)] == wanted)
            if combined is None:
                combined = passes
            else:
                combined &= passes  # type: ignore
        # At least one requirement exists (checked above), so this is set.
        return cast(Vector, combined)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots with ``{band}`` expansion.

    The band(s) substituted into the flag column names come from the
    ``bands`` config field when it is non-empty. Otherwise they come from a
    ``band`` (or, failing that, ``bands``) keyword argument. If neither is
    available, ``{band}`` is replaced with an empty string. The per-band
    masks are AND-ed together.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch to the ``_target``-suffixed versions of the flag columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Call-time bands are only honoured when the config list is empty.
        configBandsUnset = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for bandName in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=bandName)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
            "sky_object",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Flag selector applied before matching.

    It keeps only rows with ``detect_isPrimary`` set, which removes
    duplicates without applying any quality cuts.
    """

    def setDefaults(self):
        # No quality flags: only require the primary-detection flag.
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read by any method of this class;
    # confirm whether it is consumed elsewhere or can be wired in/removed.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch to the ``_target``-suffixed versions of the flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by `FlagSelector.__call__`.

        Parameters
        ----------
        data : `KeyedData`
            Table holding the flag columns.
        **kwargs
            Values used to format the column names.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all configured flag cuts.
        """
        # Previously the result was AND-ed with a variable that was always
        # None, so the combining branch was unreachable; delegate directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # ``np.inf`` rather than the ``np.Inf`` alias, which NumPy 2.0 removed.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The most negative finite float, so that -inf values are excluded.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Table holding the ``vectorKey`` column.

        Returns
        -------
        result : `Vector`
            Mask that is True where ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)
255class SnSelector(SelectorBase):
256 """Selects points that have S/N > threshold in the given flux type."""
258 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
259 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
260 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
261 uncertaintySuffix = Field[str](
262 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
263 )
264 bands = ListField[str](
265 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
266 default=[],
267 )
269 def getInputSchema(self) -> KeyedDataSchema:
270 fluxCol = self.fluxType
271 fluxInd = fluxCol.find("lux") + len("lux")
272 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
273 yield fluxCol, Vector
274 yield errCol, Vector
276 def __call__(self, data: KeyedData, **kwargs) -> Vector:
277 """Makes a mask of objects that have S/N greater than
278 self.threshold in self.fluxType
280 Parameters
281 ----------
282 data : `KeyedData`
283 The data to perform the selection on.
285 Returns
286 -------
287 result : `Vector`
288 A mask of the objects that satisfy the given
289 S/N cut.
290 """
291 mask: Optional[Vector] = None
292 bands: tuple[str, ...]
293 match kwargs:
294 case {"band": band} if not self.bands and self.bands == []:
295 bands = (band,)
296 case {"bands": bands} if not self.bands and self.bands == []:
297 bands = bands
298 case _ if self.bands:
299 bands = tuple(self.bands)
300 case _:
301 bands = ("",)
302 bandStr = ",".join(bands)
303 for band in bands:
304 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
305 fluxInd = fluxCol.find("lux") + len("lux")
306 errCol = (
307 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
308 )
309 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
310 temp = (vec > self.threshold) & (vec < self.maxSN)
311 if mask is not None:
312 mask &= temp # type: ignore
313 else:
314 mask = temp
316 plotLabelStr = "({}) > {:.1f}".format(bandStr, self.threshold)
317 if self.maxSN < 1e5:
318 plotLabelStr += " & < {:.1f}".format(self.maxSN)
320 if self.plotLabelKey == "" or self.plotLabelKey is None:
321 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs)
322 else:
323 self._addValueToPlotInfo(plotLabelStr, **kwargs)
325 # It should not be possible for mask to be a None now
326 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Select sky objects, requiring the edge flag to be unset in each band.

    Bands come from the ``bands`` config field when it is non-empty.
    Otherwise they come from a ``band`` (or ``bands``) keyword argument;
    failing both, an empty band string is used.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        configBandsUnset = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for bandName in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=bandName)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    A row is selected when ``sky_source`` is set and ``pixelFlags_edge`` is
    not.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by `FlagSelector.__call__`."""
        # The former AND-with-None accumulator could never combine anything
        # (its non-None branch was unreachable); delegate directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables.

    A row is selected when none of the listed ``base_PixelFlags`` flags are
    set.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by `FlagSelector.__call__`."""
        # The former AND-with-None accumulator could never combine anything
        # (its non-None branch was unreachable); delegate directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """Return the raw extendedness column; subclasses turn it into a mask."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # The key may contain placeholders such as ``{band}``.
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    A row is selected when ``0 <= extendedness < extendedness_maximum``.
    """

    # The comparison below is strict (``<``), so the field doc now says
    # "not inclusive" to match the behavior. ``Field[float]`` already sets
    # the dtype, so the redundant ``dtype=float`` argument is dropped.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """Select resolved objects: those whose extendedness is strictly greater
    than ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        isResolved = super().__call__(data, **kwargs) > self.extendedness_minimum
        return cast(Vector, isResolved)  # type: ignore
class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value equals 9, treated here as
    "unclassified".
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # NOTE(review): 9 is presumably a sentinel written by the upstream
        # classifier for unclassified objects; confirm against the producer.
        return super().__call__(data, **kwargs) == 9
class VectorSelector(VectorAction):
    """Use an existing boolean column of the input data directly as the
    selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])
class ThresholdSelector(VectorAction):
    """Compare a column against a threshold using a function from the
    `operator` module (e.g. ``"lt"``, ``"ge"``), named by ``op``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        compare = getattr(operator, self.op)
        return cast(Vector, compare(data[self.vectorKey], self.threshold))
497class BandSelector(VectorAction):
498 """Makes a mask for sources observed in a specified set of bands."""
500 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
501 bands = ListField[str](
502 doc="The bands to select. `None` indicates no band selection applied.",
503 default=[],
504 )
506 def getInputSchema(self) -> KeyedDataSchema:
507 return ((self.vectorKey, Vector),)
509 def __call__(self, data: KeyedData, **kwargs) -> Vector:
510 bands: Optional[tuple[str, ...]]
511 match kwargs:
512 case {"band": band} if not self.bands and self.bands == []:
513 bands = (band,)
514 case {"bands": bands} if not self.bands and self.bands == []:
515 bands = bands
516 case _ if self.bands:
517 bands = tuple(self.bands)
518 case _:
519 bands = None
520 if bands:
521 mask = np.in1d(data[self.vectorKey], bands)
522 else:
523 # No band selection is applied, i.e., select all rows
524 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
525 return cast(Vector, mask)
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # Parents: reject deblended model sources and sky objects, and keep
        # only rows with detect_isPatchInner set.
        self.selectWhenFalse = ["detect_isDeblendedModelSource", "sky_object"]
        self.selectWhenTrue = ["detect_isPatchInner"]
540class MagSelector(SelectorBase):
541 """Selects points that have minMag < mag (AB) < maxMag.
543 The magnitude is based on the given fluxType.
544 """
546 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux")
547 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6)
548 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6)
549 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy")
550 returnMillimags = Field[bool](doc="Use millimags or not?", default=False)
551 bands = ListField[str](
552 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.",
553 default=[],
554 )
556 def getInputSchema(self) -> KeyedDataSchema:
557 fluxCol = self.fluxType
558 yield fluxCol, Vector
560 def __call__(self, data: KeyedData, **kwargs) -> Vector:
561 """Make a mask of that satisfies self.minMag < mag < self.maxMag.
563 The magnitude is based on the flux in self.fluxType.
565 Parameters
566 ----------
567 data : `KeyedData`
568 The data to perform the magnitude selection on.
570 Returns
571 -------
572 result : `Vector`
573 A mask of the objects that satisfy the given magnitude cut.
574 """
575 mask: Optional[Vector] = None
576 bands: tuple[str, ...]
577 match kwargs:
578 case {"band": band} if not self.bands and self.bands == []:
579 bands = (band,)
580 case {"bands": bands} if not self.bands and self.bands == []:
581 bands = bands
582 case _ if self.bands:
583 bands = tuple(self.bands)
584 case _:
585 bands = ("",)
586 bandStr = ",".join(bands)
587 for band in bands:
588 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
589 vec = fluxToMag(
590 cast(Vector, data[fluxCol]),
591 flux_unit=self.fluxUnit,
592 return_millimags=self.returnMillimags,
593 )
594 temp = (vec > self.minMag) & (vec < self.maxMag)
595 if mask is not None:
596 mask &= temp # type: ignore
597 else:
598 mask = temp
600 plotLabelStr = ""
601 if self.maxMag < 100:
602 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.maxMag)
603 if self.minMag > -100:
604 if bandStr in plotLabelStr:
605 plotLabelStr += " & < {:.1f}".format(self.minMag)
606 else:
607 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.minMag)
608 if self.plotLabelKey == "" or self.plotLabelKey is None:
609 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs)
610 else:
611 self._addValueToPlotInfo(plotLabelStr, **kwargs)
613 # It should not be possible for mask to be a None now
614 return np.array(cast(Vector, mask))