Coverage for python / lsst / analysis / tools / actions / vector / selectors.py: 29%
389 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public API of this module: the names picked up by ``from ... import *``.
# NOTE(review): PatchSelector, ParentObjectSelector and ChildObjectSelector
# are defined in this module but are not listed here, so star-imports will
# not expose them -- confirm whether that omission is intentional.
__all__ = (
    "SelectorBase",
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SetSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "FiniteSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "InjectedClassSelector",
    "InjectedGalaxySelector",
    "InjectedObjectSelector",
    "InjectedStarSelector",
    "MatchedObjectSelector",
    "ReferenceGalaxySelector",
    "ReferenceObjectSelector",
    "ReferenceStarSelector",
)
54import operator
55from typing import Optional, cast
57import numpy as np
58from lsst.pex.config import Field
59from lsst.pex.config.listField import ListField
61from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
62from ...math import divide, fluxToMag
class SelectorBase(VectorAction):
    """Base class for selectors that can describe their cut in ``plotInfo``.

    Subclasses call `_addValueToPlotInfo` to record a human-readable
    description of the selection into the ``plotInfo`` mapping, when one is
    passed through the call keyword arguments.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Record ``value`` in ``kwargs["plotInfo"]`` if a plotInfo is present.

        Parameters
        ----------
        value
            The value to store (typically a label string).
        plotLabelKey : `str`, optional
            Key to store ``value`` under. When `None`, the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Call keyword arguments; only ``plotInfo`` is consulted. If it is
            absent, nothing happens.

        Raises
        ------
        RuntimeError
            Raised if ``plotInfo`` is present but no key is available from
            either ``plotLabelKey`` or ``self.plotLabelKey``.
        """
        if "plotInfo" not in kwargs:
            return
        if plotLabelKey is None:
            # Fall back to the configured key; an empty string means "unset".
            if not self.plotLabelKey:
                raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
            plotLabelKey = self.plotLabelKey
        kwargs["plotInfo"][plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Column names are returned unformatted (they may contain ``{band}``).
        return ((col, Vector) for col in [*self.selectWhenFalse, *self.selectWhenTrue])

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Used to format the flag column name templates (e.g. ``band``).

        Returns
        -------
        result : `Vector`
            A mask that is True where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if neither flag list contains any column.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair every flag template with the value it must equal to be kept;
        # the "False" flags are evaluated before the "True" flags.
        requirements = [(flag, 0) for flag in self.selectWhenFalse] + [
            (flag, 1) for flag in self.selectWhenTrue
        ]

        combined: Optional[Vector] = None
        for template, wanted in requirements:
            passing = np.array(data[template.format(**kwargs)] == wanted)
            if combined is None:
                combined = passing
            else:
                combined &= passing  # type: ignore
        # The emptiness check above guarantees at least one flag was applied.
        return cast(Vector, combined)
136class CoaddPlotFlagSelector(FlagSelector):
137 """This default setting makes it take the band from
138 the kwargs.
139 """
141 bands = ListField[str](
142 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
143 default=[],
144 )
146 def getInputSchema(self) -> KeyedDataSchema:
147 yield from super().getInputSchema()
149 def refMatchContext(self):
150 self.selectWhenFalse = [
151 "{band}_psfFlux_flag_target",
152 "{band}_pixelFlags_saturatedCenter_target",
153 "{band}_extendedness_flag_target",
154 "coord_flag_target",
155 ]
156 self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]
158 def __call__(self, data: KeyedData, **kwargs) -> Vector:
159 result: Optional[Vector] = None
160 bands: tuple[str, ...]
161 match kwargs:
162 case {"band": band} if not self.bands and self.bands == []:
163 bands = (band,)
164 case {"bands": bands} if not self.bands and self.bands == []:
165 bands = bands
166 case _ if self.bands:
167 bands = tuple(self.bands)
168 case _:
169 bands = ("",)
170 for band in bands:
171 temp = super().__call__(data, **(kwargs | dict(band=band)))
172 if result is not None:
173 result &= temp # type: ignore
174 else:
175 result = temp
176 return cast(Vector, result)
178 def setDefaults(self):
179 self.selectWhenFalse = [
180 "{band}_psfFlux_flag",
181 "{band}_pixelFlags_saturatedCenter",
182 "{band}_extendedness_flag",
183 "coord_flag",
184 "sky_object",
185 ]
186 self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class MatchingFlagSelector(CoaddPlotFlagSelector):
    """The default flag selector to apply pre matching.
    The sources are cut down to remove duplicates but
    not on quality.
    """

    def setDefaults(self):
        # Chain to the parent as pex_config requires; the flag lists it sets
        # are replaced immediately below.
        super().setDefaults()
        # Only keep primary detections; no quality flags are applied.
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): this field is not read by any code in this class --
    # confirm whether it is consumed elsewhere.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the excluded flags to their ``_target``-suffixed names.

        ``selectWhenTrue`` is left unchanged.
        """
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all configured flag cuts.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to chain to the base.
        super().setDefaults()
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]
class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    maximum = Field[float](doc="The maximum value (exclusive)", default=np.inf)
    minimum = Field[float](doc="The minimum value (inclusive)", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask selecting ``minimum <= value < maximum``.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the ``vectorKey`` column.

        Returns
        -------
        result : `Vector`
            A mask of the rows with values within the specified range.

        Notes
        -----
        ``vectorKey`` is used verbatim; it is not formatted with ``kwargs``.
        The default ``minimum`` is the most negative finite float, so
        ``-inf`` values are excluded. NaN values fail both comparisons.
        """
        column = cast(Vector, data[self.vectorKey])
        atOrAboveMin = column >= self.minimum
        belowMax = column < self.maximum
        return cast(Vector, atOrAboveMin & belowMax)
class SetSelector(SelectorBase):
    """Selects rows with any number of column values within a given set.

    For example, given a set of patches (1, 2, 3), and a set of columns
    (index_1, index_2), return all rows with either index_1 or index_2
    in the set (1, 2, 3).

    Notes
    -----
    The values are given as floats for flexibility. Integers above
    the floating point limit (2^53 + 1 = 9,007,199,254,740,993 for 64 bits)
    will not compare exactly with their float representations.
    """

    vectorKeys = ListField[str](
        doc="Keys to select from data",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )
    values = ListField[float](
        doc="The set of acceptable values",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from ((key, Vector) for key in self.vectorKeys)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """OR together equality tests of every key against every value.

        Parameters
        ----------
        data : `KeyedData`
            The data containing all ``vectorKeys`` columns.

        Returns
        -------
        result : `Vector`
            A mask that is True where any of the ``vectorKeys`` columns
            equals any of ``values``.
        """
        # Start from an all-False mask shaped like the first column.
        selected = np.zeros_like(data[self.vectorKeys[0]], dtype=bool)
        for key in self.vectorKeys:
            column = cast(Vector, data[key])
            for candidate in self.values:
                selected |= column == candidate
        return cast(Vector, selected)
class PatchSelector(SetSelector):
    """Select rows within a set of patches.

    Only the column (``patch``) is preconfigured. The accepted patch
    ``values`` still have to be set, because `SetSelector` requires a
    non-empty ``values`` list.
    """

    def setDefaults(self):
        super().setDefaults()
        # Select on the ``patch`` column alone.
        self.vectorKeys = ["patch"]
323class SnSelector(SelectorBase):
324 """Selects points that have S/N > threshold in the given flux type."""
326 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
327 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
328 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
329 uncertaintySuffix = Field[str](
330 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
331 )
332 bands = ListField[str](
333 doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
334 default=[],
335 )
337 def getInputSchema(self) -> KeyedDataSchema:
338 fluxCol = self.fluxType
339 fluxInd = fluxCol.find("lux") + len("lux")
340 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
341 yield fluxCol, Vector
342 yield errCol, Vector
344 def __call__(self, data: KeyedData, **kwargs) -> Vector:
345 """Makes a mask of objects that have S/N greater than
346 self.threshold in self.fluxType
348 Parameters
349 ----------
350 data : `KeyedData`
351 The data to perform the selection on.
353 Returns
354 -------
355 result : `Vector`
356 A mask of the objects that satisfy the given
357 S/N cut.
358 """
359 mask: Optional[Vector] = None
360 bands: tuple[str, ...]
361 match kwargs:
362 case {"band": band} if not self.bands and self.bands == []:
363 bands = (band,)
364 case {"bands": bands} if not self.bands and self.bands == []:
365 bands = bands
366 case _ if self.bands:
367 bands = tuple(self.bands)
368 case _:
369 bands = ("",)
370 bandStr = ",".join(bands)
371 for band in bands:
372 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
373 fluxInd = fluxCol.find("lux") + len("lux")
374 errCol = (
375 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
376 )
377 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
378 temp = (vec > self.threshold) & (vec < self.maxSN)
379 if mask is not None:
380 mask &= temp # type: ignore
381 else:
382 mask = temp
384 plotLabelStr = "({}) > {:.1f}".format(bandStr, self.threshold)
385 if self.maxSN < 1e5:
386 plotLabelStr += " & < {:.1f}".format(self.maxSN)
388 if self.plotLabelKey == "" or self.plotLabelKey is None:
389 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs)
390 else:
391 self._addValueToPlotInfo(plotLabelStr, **kwargs)
393 # It should not be possible for mask to be a None now
394 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    With the default configuration, rows must have ``sky_object`` set and
    the per-band ``pixelFlags_edge``/``pixelFlags_nodata`` flags unset.
    Configured ``bands`` take precedence over ``band``/``bands`` passed to
    the call. The per-band masks are ANDed together.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Call arguments only decide the bands when none are configured.
        fromKwargs = not self.bands and self.bands == []
        if fromKwargs and "band" in kwargs:
            bandList: tuple[str, ...] = (kwargs["band"],)
        elif fromKwargs and "bands" in kwargs:
            bandList = kwargs["bands"]
        elif self.bands:
            bandList = tuple(self.bands)
        else:
            bandList = ("",)

        combined: Optional[Vector] = None
        for bandName in bandList:
            bandMask = super().__call__(data, **{**kwargs, "band": bandName})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
            "{band}_pixelFlags_nodata",
        ]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    With the default configuration, rows must have ``sky_source`` set and
    ``pixelFlags_edge``/``pixelFlags_nodata`` unset.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags."""
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables.

    With the default configuration, a row is kept when none of the listed
    pixel flags is set; no ``selectWhenTrue`` flags are configured here.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags."""
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "pixelFlags_bad",
            "pixelFlags_saturatedCenter",
            "pixelFlags_interpolatedCenter",
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
class ExtendednessSelector(SelectorBase):
    """A selector that picks between extended and point sources.

    By itself this returns the extendedness column unchanged; subclasses
    turn it into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` formatted with ``kwargs``."""
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    Rows are selected when ``0 <= extendedness <= extendedness_maximum``;
    NaN values are never selected.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # The maximum is documented as inclusive. Using ``<=`` also means a
        # value exactly at the threshold is not left out of both this
        # selector and GalaxySelector, which uses a strict ``> minimum``.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.

    Rows are selected when ``extendedness > extendedness_minimum``; NaN
    values compare False and are therefore not selected.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        isResolved = extendedness > self.extendedness_minimum
        return cast(Vector, isResolved)  # type: ignore
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask ``extendedness == 9``.

        NOTE(review): 9 appears to be a sentinel value for unclassified
        objects -- confirm against the producer of the extendedness column.
        """
        extendedness = super().__call__(data, **kwargs)
        isUnclassified = extendedness == 9
        return isUnclassified
class FiniteSelector(VectorAction):
    """Return a mask of finite values for a vector key"""

    vectorKey = Field[str](doc="Key to make a mask of finite values for.")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Mask rows whose value is finite (not NaN or +/-inf).

        ``vectorKey`` is formatted with ``kwargs`` before the lookup.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, np.isfinite(column))
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` (formatted with ``kwargs``) unchanged."""
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])
class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return ``operator.<op>(column, threshold)``.

        ``op`` names a function in the `operator` module (e.g. ``"lt"``,
        ``"ge"``, ``"eq"``). ``vectorKey`` is used verbatim.
        """
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands."""

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Mask rows whose band column value is one of the chosen bands.

        Configured ``bands`` take precedence. Otherwise ``band`` or
        ``bands`` from ``kwargs`` are used. If no bands end up chosen, every
        row is selected.
        """
        fromKwargs = not self.bands and self.bands == []
        chosen: Optional[tuple[str, ...]]
        if fromKwargs and "band" in kwargs:
            chosen = (kwargs["band"],)
        elif fromKwargs and "bands" in kwargs:
            chosen = kwargs["bands"]
        elif self.bands:
            chosen = tuple(self.bands)
        else:
            chosen = None

        if not chosen:
            # No band selection is applied, i.e., select all rows
            return cast(Vector, np.full(len(data[self.vectorKey]), True))  # type: ignore
        return cast(Vector, np.isin(data[self.vectorKey], chosen).flatten())
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to chain to the base.
        super().setDefaults()
        # Only sky objects are excluded here.
        # NOTE(review): nothing in this class restricts the selection to
        # parents -- confirm that the input catalog already contains only
        # parent objects.
        self.selectWhenFalse = [
            "sky_object",
        ]
class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents.

    A row is selected when its ``vectorKey`` value (``parentSourceId`` by
    default) is strictly positive. The ``minimum`` and ``maximum`` fields
    inherited from `RangeSelector` are not used.
    """

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose ``vectorKey`` value is greater than 0.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the ``vectorKey`` column.

        Returns
        -------
        result : `Vector`
            True where the parent identifier is positive.
        """
        parentIds = cast(Vector, data[self.vectorKey])
        return cast(Vector, parentIds > 0)
649class MagSelector(SelectorBase):
650 """Selects points that have minMag < mag (AB) < maxMag.
652 The magnitude is based on the given fluxType.
653 """
655 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux")
656 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6)
657 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6)
658 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy")
659 returnMillimags = Field[bool](doc="Use millimags or not?", default=False)
660 bands = ListField[str](
661 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.",
662 default=[],
663 )
665 def getInputSchema(self) -> KeyedDataSchema:
666 fluxCol = self.fluxType
667 yield fluxCol, Vector
669 def __call__(self, data: KeyedData, **kwargs) -> Vector:
670 """Make a mask of that satisfies self.minMag < mag < self.maxMag.
672 The magnitude is based on the flux in self.fluxType.
674 Parameters
675 ----------
676 data : `KeyedData`
677 The data to perform the magnitude selection on.
679 Returns
680 -------
681 result : `Vector`
682 A mask of the objects that satisfy the given magnitude cut.
683 """
684 mask: Optional[Vector] = None
685 bands: tuple[str, ...]
686 match kwargs:
687 case {"band": band} if not self.bands and self.bands == []:
688 bands = (band,)
689 case {"bands": bands} if not self.bands and self.bands == []:
690 bands = bands
691 case _ if self.bands:
692 bands = tuple(self.bands)
693 case _:
694 bands = ("",)
695 bandStr = ",".join(bands)
696 for band in bands:
697 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
698 vec = fluxToMag(
699 cast(Vector, data[fluxCol]),
700 flux_unit=self.fluxUnit,
701 return_millimags=self.returnMillimags,
702 )
703 temp = (vec > self.minMag) & (vec < self.maxMag)
704 if mask is not None:
705 mask &= temp # type: ignore
706 else:
707 mask = temp
709 plotLabelStr = ""
710 if self.maxMag < 100:
711 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.maxMag)
712 if self.minMag > -100:
713 if bandStr in plotLabelStr:
714 plotLabelStr += " & < {:.1f}".format(self.minMag)
715 else:
716 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.minMag)
717 if self.plotLabelKey == "" or self.plotLabelKey is None:
718 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs)
719 else:
720 self._addValueToPlotInfo(plotLabelStr, **kwargs)
722 # It should not be possible for mask to be a None now
723 return np.array(cast(Vector, mask))
class InjectedObjectSelector(SelectorBase):
    """A selector for injected objects.

    Selects rows where the ``vectorKey`` column (formatted with the call
    keyword arguments) equals 1.
    """

    vectorKey = Field[str](doc="Key to select from data", default="ref_injected_isPrimary")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column == 1)
class InjectedClassSelector(InjectedObjectSelector):
    """A selector for injected objects of a given class."""

    key_class = Field[str](
        doc="Key for the field indicating the class of the object",
        default="ref_source_type",
    )
    key_injection_flag = Field[str](
        doc="Key for the field indicating that the object was not injected (per band)",
        default="ref_{band}_injection_flag",
    )
    name_class = Field[str](
        doc="Name of the class of objects",
    )
    value_compare = Field[str](
        doc="Value of the key_class field to compare against (the default identifies stars)",
        default="DeltaFunction",
    )
    value_is_equal = Field[bool](
        doc="Whether the value must equal value_compare to be of this class",
        default=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select injected objects belonging to the configured class.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Used to format ``vectorKey`` and ``key_injection_flag`` (e.g.
            ``band``); ``plotInfo`` receives a label when ``plotLabelKey``
            is set.

        Returns
        -------
        result : `Vector`
            Mask of injected rows whose injection flag is False and whose
            ``key_class`` value does (or, if ``value_is_equal`` is False,
            does not) equal ``value_compare``.

        Raises
        ------
        KeyError
            Raised if a key template contains a placeholder (such as
            ``{band}``) that is not supplied in ``kwargs``.
        """
        result = super().__call__(data, **kwargs)
        if self.key_injection_flag:
            # Format with all call kwargs, like the other selectors; a
            # template without ``{band}`` then works without a band.
            result &= data[self.key_injection_flag.format(**kwargs)] == False  # noqa: E712
        values = data[self.key_class]
        result &= (values == self.value_compare) if self.value_is_equal else (values != self.value_compare)
        if self.plotLabelKey:
            self._addValueToPlotInfo(f"injected {self.name_class}", **kwargs)
        return result

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()
        yield self.key_class, Vector
        if self.key_injection_flag:
            yield self.key_injection_flag, Vector
class InjectedGalaxySelector(InjectedClassSelector):
    """A selector for injected galaxies."""

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to chain to the base.
        super().setDefaults()
        self.name_class = "galaxy"
        # Assumes not star == galaxy - if there are injected AGN or other
        # object classes, this will need to be updated
        self.value_is_equal = False
class InjectedStarSelector(InjectedClassSelector):
    """A selector for injected stars.

    Relies on the inherited ``value_compare``/``value_is_equal`` defaults to
    match the star class value.
    """

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to chain to the base.
        super().setDefaults()
        self.name_class = "star"
class MatchedObjectSelector(RangeSelector):
    """A selector that selects matched objects with finite distances.

    Keeps rows with ``0 <= match_distance < inf``; NaN and infinite
    distances are rejected.
    """

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "match_distance"
        # The inherited maximum (inf, exclusive) removes infinite distances.
        self.minimum = 0
class ReferenceGalaxySelector(ThresholdSelector):
    """A selector that selects galaxies from a catalog with a
    boolean column identifying unresolved sources.

    By default, keeps rows where ``refcat_is_pointsource == 0``.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        mask = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return mask
        self._addValueToPlotInfo("reference galaxies", **kwargs)
        return mask

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 0
        self.plotLabelKey = "Selection: Galaxies"
class ReferenceObjectSelector(RangeSelector):
    """A selector that selects all objects from a catalog with a
    boolean column identifying unresolved sources.

    By default, keeps rows with ``refcat_is_pointsource >= 0``, which
    includes both point and extended sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        mask = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return mask
        self._addValueToPlotInfo("reference objects", **kwargs)
        return mask

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.minimum = 0
class ReferenceStarSelector(ThresholdSelector):
    """A selector that selects stars from a catalog with a
    boolean column identifying unresolved sources.

    By default, keeps rows where ``refcat_is_pointsource == 1``.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        mask = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return mask
        self._addValueToPlotInfo("reference stars", **kwargs)
        return mask

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 1
        self.plotLabelKey = "Selection: Stars"