Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 28%
294 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-23 04:51 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public selectors exported by ``from ... import *``.
# NOTE(review): ``SelectorBase``, ``ParentObjectSelector`` and
# ``ChildObjectSelector`` are defined in this module but not listed here;
# confirm whether that is intentional.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
)
43import operator
44from typing import Optional, cast
46import numpy as np
47from lsst.pex.config import Field
48from lsst.pex.config.listField import ListField
50from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
51from ...math import divide, fluxToMag
class SelectorBase(VectorAction):
    """Base class for selectors that may record a label in ``plotInfo``."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` if a plotInfo is supplied.

        Parameters
        ----------
        value
            The value (typically a label string) to store.
        plotLabelKey : `str`, optional
            Explicit key to store under; when `None` the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Call keyword arguments; nothing is stored unless ``plotInfo``
            is among them.

        Raises
        ------
        RuntimeError
            Raised if ``plotInfo`` is supplied but no key is available from
            either the argument or the configuration.
        """
        if "plotInfo" not in kwargs:
            return
        if plotLabelKey is not None:
            targetKey = plotLabelKey
        elif self.plotLabelKey:
            targetKey = self.plotLabelKey
        else:
            raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
        kwargs["plotInfo"][targetKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Keys are yielded unformatted (placeholders such as "{band}" intact).
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """AND together the flag requirements into a single mask.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.
        **kwargs
            Values used to ``str.format`` each flag column name.

        Returns
        -------
        result : `Vector`
            Mask that is True where every ``selectWhenFalse`` column equals 0
            and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if neither flag list contains any column.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column name with the value it must equal to be kept;
        # the False-flags are processed before the True-flags.
        requirements = [(name, 0) for name in self.selectWhenFalse] + [
            (name, 1) for name in self.selectWhenTrue
        ]

        combined: Optional[Vector] = None
        for name, required in requirements:  # type: ignore
            passes = np.array(data[name.format(**kwargs)] == required)
            if combined is None:
                combined = passes
            else:
                combined &= passes  # type: ignore
        # The guard above ensures at least one requirement, so not None.
        return cast(Vector, combined)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector whose ``{band}`` placeholders are filled per band.

    The bands come from the ``bands`` config when set, otherwise from the
    ``band``/``bands`` call kwargs; the per-band masks are ANDed together.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        # Same flags as the defaults, but on the "_target" columns of a
        # reference-matched table.
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        bands: tuple[str, ...]
        configBandsUnset = not self.bands and self.bands == []
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format "{band}" as empty string.
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
            "sky_object",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Flag selector applied before matching.

    Only ``detect_isPrimary`` is required (to drop duplicates); no quality
    flags are applied.
    """

    def setDefaults(self):
        # No quality cuts: clear the inherited "select when False" list.
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): ``catalogSuffix`` is declared but not read anywhere in
    # this class; confirm whether callers rely on it.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        # Only the "select when False" list is switched to "_target" columns.
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # A single visit-level pass; no per-band iteration is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]
class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # Use ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0, which made
    # this class body fail at import time.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # ``nextafter(-inf, 0)`` is the most negative finite float, so -inf values
    # are excluded by default while every finite value passes.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from; ``data[self.vectorKey]`` is compared.

        Returns
        -------
        result : `Vector`
            A mask that is True where ``minimum <= value < maximum``. NaN
            values compare False and are therefore excluded.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)
        return cast(Vector, mask)
255class SnSelector(SelectorBase):
256 """Selects points that have S/N > threshold in the given flux type."""
258 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
259 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
260 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
261 uncertaintySuffix = Field[str](
262 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
263 )
264 bands = ListField[str](
265 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
266 default=[],
267 )
269 def getInputSchema(self) -> KeyedDataSchema:
270 fluxCol = self.fluxType
271 fluxInd = fluxCol.find("lux") + len("lux")
272 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
273 yield fluxCol, Vector
274 yield errCol, Vector
276 def __call__(self, data: KeyedData, **kwargs) -> Vector:
277 """Makes a mask of objects that have S/N greater than
278 self.threshold in self.fluxType
280 Parameters
281 ----------
282 data : `KeyedData`
283 The data to perform the selection on.
285 Returns
286 -------
287 result : `Vector`
288 A mask of the objects that satisfy the given
289 S/N cut.
290 """
291 mask: Optional[Vector] = None
292 bands: tuple[str, ...]
293 match kwargs:
294 case {"band": band} if not self.bands and self.bands == []:
295 bands = (band,)
296 case {"bands": bands} if not self.bands and self.bands == []:
297 bands = bands
298 case _ if self.bands:
299 bands = tuple(self.bands)
300 case _:
301 bands = ("",)
302 bandStr = ",".join(bands)
303 for band in bands:
304 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
305 fluxInd = fluxCol.find("lux") + len("lux")
306 errCol = (
307 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
308 )
309 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
310 temp = (vec > self.threshold) & (vec < self.maxSN)
311 if mask is not None:
312 mask &= temp # type: ignore
313 else:
314 mask = temp
316 plotLabelStr = "({}) > {:.1f}".format(bandStr, self.threshold)
317 if self.maxSN < 1e5:
318 plotLabelStr += " & < {:.1f}".format(self.maxSN)
320 if self.plotLabelKey == "" or self.plotLabelKey is None:
321 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs)
322 else:
323 self._addValueToPlotInfo(plotLabelStr, **kwargs)
325 # It should not be possible for mask to be a None now
326 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects (``sky_object`` True, edge flag False) per band.

    The per-band masks are ANDed together.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        bands: tuple[str, ...]
        configBandsUnset = not self.bands and self.bands == []
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = ["{band}_pixelFlags_edge"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources (``sky_source`` True, edge flag False) from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Single pass through the base flag logic; no band iteration.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenTrue = ["sky_source"]
        self.selectWhenFalse = ["pixelFlags_edge"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Single pass through the base flag logic; no band iteration.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # Pixel-flag column names as spelled in AP data products; rows with
        # any of these set are rejected.
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

    def drpContext(self):
        # DRP data products spell the same flags with a "pixelFlags" prefix.
        self.selectWhenFalse = [
            name.replace("base_PixelFlags_flag", "pixelFlags") for name in self.selectWhenFalse
        ]
class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources.

    On its own it returns the raw extendedness vector; subclasses turn it
    into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # The key may contain placeholders such as "{band}" filled from kwargs.
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    A row is selected when ``0 <= extendedness < extendedness_maximum``.
    """

    # The comparison below is strict (``<``); the doc previously claimed the
    # maximum was inclusive, contradicting the code.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, not inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # Negative values and NaN both fail ``>= 0`` and are not selected.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values (``extendedness > extendedness_minimum``).
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # NOTE(review): 9 appears to be the sentinel extendedness value for
        # "unclassified"; confirm against the producing pipeline.
        return super().__call__(data, **kwargs) == 9
class VectorSelector(VectorAction):
    """Return a boolean vector stored in KeyedData, for use as a selector."""

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # The key may contain placeholders filled from kwargs.
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])
class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold.

    ``op`` names a function in the standard `operator` module (for example
    ``"lt"`` or ``"ge"``), applied as ``op(data[vectorKey], threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Note: unlike VectorSelector, vectorKey is used verbatim (not formatted).
        compare = getattr(operator, self.op)
        return cast(Vector, compare(data[self.vectorKey], self.threshold))
505class BandSelector(VectorAction):
506 """Makes a mask for sources observed in a specified set of bands."""
508 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
509 bands = ListField[str](
510 doc="The bands to select. `None` indicates no band selection applied.",
511 default=[],
512 )
514 def getInputSchema(self) -> KeyedDataSchema:
515 return ((self.vectorKey, Vector),)
517 def __call__(self, data: KeyedData, **kwargs) -> Vector:
518 bands: Optional[tuple[str, ...]]
519 match kwargs:
520 case {"band": band} if not self.bands and self.bands == []:
521 bands = (band,)
522 case {"bands": bands} if not self.bands and self.bands == []:
523 bands = bands
524 case _ if self.bands:
525 bands = tuple(self.bands)
526 case _:
527 bands = None
528 if bands:
529 mask = np.in1d(data[self.vectorKey], bands)
530 else:
531 # No band selection is applied, i.e., select all rows
532 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
533 return cast(Vector, mask)
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # Keep patch-inner rows that are neither deblended model sources nor
        # sky objects (i.e. the parents).
        self.selectWhenTrue = ["detect_isPatchInner"]
        self.selectWhenFalse = ["detect_isDeblendedModelSource", "sky_object"]
class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents"""

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose ``vectorKey`` value is strictly positive.

        The ``minimum``/``maximum`` fields inherited from `RangeSelector` are
        not consulted.

        Parameters
        ----------
        data : `KeyedData`
            Data containing ``self.vectorKey``.

        Returns
        -------
        result : `Vector`
            True where ``data[self.vectorKey] > 0``.
        """
        # Presumably a parent id of 0 marks "no parent"; verify upstream.
        parentIds = cast(Vector, data[self.vectorKey])
        return cast(Vector, parentIds > 0)
574class MagSelector(SelectorBase):
575 """Selects points that have minMag < mag (AB) < maxMag.
577 The magnitude is based on the given fluxType.
578 """
580 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux")
581 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6)
582 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6)
583 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy")
584 returnMillimags = Field[bool](doc="Use millimags or not?", default=False)
585 bands = ListField[str](
586 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.",
587 default=[],
588 )
590 def getInputSchema(self) -> KeyedDataSchema:
591 fluxCol = self.fluxType
592 yield fluxCol, Vector
594 def __call__(self, data: KeyedData, **kwargs) -> Vector:
595 """Make a mask of that satisfies self.minMag < mag < self.maxMag.
597 The magnitude is based on the flux in self.fluxType.
599 Parameters
600 ----------
601 data : `KeyedData`
602 The data to perform the magnitude selection on.
604 Returns
605 -------
606 result : `Vector`
607 A mask of the objects that satisfy the given magnitude cut.
608 """
609 mask: Optional[Vector] = None
610 bands: tuple[str, ...]
611 match kwargs:
612 case {"band": band} if not self.bands and self.bands == []:
613 bands = (band,)
614 case {"bands": bands} if not self.bands and self.bands == []:
615 bands = bands
616 case _ if self.bands:
617 bands = tuple(self.bands)
618 case _:
619 bands = ("",)
620 bandStr = ",".join(bands)
621 for band in bands:
622 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
623 vec = fluxToMag(
624 cast(Vector, data[fluxCol]),
625 flux_unit=self.fluxUnit,
626 return_millimags=self.returnMillimags,
627 )
628 temp = (vec > self.minMag) & (vec < self.maxMag)
629 if mask is not None:
630 mask &= temp # type: ignore
631 else:
632 mask = temp
634 plotLabelStr = ""
635 if self.maxMag < 100:
636 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.maxMag)
637 if self.minMag > -100:
638 if bandStr in plotLabelStr:
639 plotLabelStr += " & < {:.1f}".format(self.minMag)
640 else:
641 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.minMag)
642 if self.plotLabelKey == "" or self.plotLabelKey is None:
643 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs)
644 else:
645 self._addValueToPlotInfo(plotLabelStr, **kwargs)
647 # It should not be possible for mask to be a None now
648 return np.array(cast(Vector, mask))