Coverage for python / lsst / analysis / tools / actions / vector / selectors.py: 29%
389 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 09:36 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public names exported by a star-import of this module.
# NOTE(review): PatchSelector, ParentObjectSelector and ChildObjectSelector
# are defined in this module but are not listed here, so a star-import does
# not expose them; confirm whether that omission is intentional.
__all__ = (
    "SelectorBase",
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SetSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "FiniteSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "InjectedClassSelector",
    "InjectedGalaxySelector",
    "InjectedObjectSelector",
    "InjectedStarSelector",
    "MatchedObjectSelector",
    "ReferenceGalaxySelector",
    "ReferenceObjectSelector",
    "ReferenceStarSelector",
)
54import operator
55from typing import cast
57import numpy as np
59from lsst.pex.config import Field
60from lsst.pex.config.listField import ListField
62from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
63from ...math import divide, fluxToMag
class SelectorBase(VectorAction):
    """Base class for selectors that can describe their cut in the plot
    info dictionary handed to them at call time.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Record ``value`` in ``kwargs["plotInfo"]`` if plot info was given.

        Parameters
        ----------
        value
            The value to store.
        plotLabelKey : `str`, optional
            Key to store ``value`` under. When `None`, the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Call keyword arguments; only ``plotInfo`` is consulted.

        Raises
        ------
        RuntimeError
            Raised when plot info is present but neither ``plotLabelKey``
            nor ``self.plotLabelKey`` provides a key.
        """
        if "plotInfo" not in kwargs:
            # Nothing to populate.
            return
        if plotLabelKey is None:
            if not self.plotLabelKey:
                raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
            plotLabelKey = self.plotLabelKey
        kwargs["plotInfo"][plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row passes when every column in ``selectWhenFalse`` equals 0 and every
    column in ``selectWhenTrue`` equals 1. Column names may contain format
    placeholders that are filled from the call keyword arguments.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build the combined flag mask.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Values substituted into the column-name templates.

        Returns
        -------
        result : `Vector`
            Logical AND of all flag requirements.

        Raises
        ------
        RuntimeError
            Raised if no flag columns are configured.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair each column template with the value it must take; the
        # "False" flags are checked before the "True" flags.
        requirements = [(flag, 0) for flag in self.selectWhenFalse]
        requirements += [(flag, 1) for flag in self.selectWhenTrue]

        combined: Vector | None = None
        for template, required in requirements:
            passes = np.array(data[template.format(**kwargs)] == required)
            if combined is None:
                combined = passes
            else:
                combined &= passes  # type: ignore
        # The guard above means at least one requirement was applied.
        return cast(Vector, combined)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector with defaults for coadd plots.

    The ``{band}`` placeholder in the flag names is filled from the ``band``
    or ``bands`` call keyword arguments, unless ``bands`` is configured, in
    which case the configured bands are used. The per-band masks are
    combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Point the flags at the ``_target`` columns of a matched catalog."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "coord_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Call-time bands are only honoured when none are configured.
        callerMayChoose = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if callerMayChoose and "band" in kwargs:
            bands = (kwargs["band"],)
        elif callerMayChoose and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Vector | None = None
        for bandName in bands:
            perBand = super().__call__(data, **{**kwargs, "band": bandName})
            combined = perBand if combined is None else combined & perBand
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "coord_flag",
            "sky_object",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Flag selector applied before matching.

    Only ``detect_isPrimary`` is required. This removes duplicate detections
    without applying any quality cuts.
    """

    def setDefaults(self):
        # No quality flags; de-duplication only.
        self.selectWhenTrue = ["detect_isPrimary"]
        self.selectWhenFalse = []
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read anywhere in this class, so
    # setting it does not change which columns are used. Confirm whether it
    # is still needed or should be applied to the flag names.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Point the flags at the ``_target`` columns of a matched catalog."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        Visit catalogs need a single mask, so this delegates directly
        without any per-band combination.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]
class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    maximum = Field[float](doc="The maximum value (exclusive)", default=np.inf)
    # The default is the most negative finite float rather than -inf, so rows
    # whose value is -inf are rejected by default.
    minimum = Field[float](doc="The minimum value (inclusive)", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.
        **kwargs
            Values substituted into ``vectorKey`` (e.g. ``band``), matching
            the key handling of the other selectors in this module.

        Returns
        -------
        result : `Vector`
            A mask that is True where ``minimum <= value < maximum``. NaN
            values fail both comparisons and are therefore rejected.
        """
        values = cast(Vector, data[self.vectorKey.format(**kwargs)])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)
class SetSelector(SelectorBase):
    """Selects rows with any number of column values within a given set.

    For example, given a set of patches (1, 2, 3), and a set of columns
    (index_1, index_2), return all rows with either index_1 or index_2
    in the set (1, 2, 3).

    Notes
    -----
    The values are given as floats for flexibility. Integers above
    the floating point limit (2^53 + 1 = 9,007,199,254,740,993 for 64 bits)
    will not compare exactly with their float representations.
    """

    vectorKeys = ListField[str](
        doc="Keys to select from data",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )
    values = ListField[float](
        doc="The set of acceptable values",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from ((key, Vector) for key in self.vectorKeys)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Flag rows where any configured column holds an accepted value.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.

        Returns
        -------
        result : `Vector`
            Boolean mask, True where at least one of ``vectorKeys`` holds a
            value equal to one of ``values``.
        """
        accepted = np.zeros_like(data[self.vectorKeys[0]], dtype=bool)
        columns = [cast(Vector, data[column]) for column in self.vectorKeys]
        for columnValues in columns:
            for candidate in self.values:
                accepted |= columnValues == candidate

        return cast(Vector, accepted)
class PatchSelector(SetSelector):
    """Select rows whose ``patch`` column is one of the configured values."""

    def setDefaults(self):
        super().setDefaults()
        # Only the patch column is inspected; the accepted patch numbers
        # still come from ``values``.
        self.vectorKeys = ["patch"]
324class SnSelector(SelectorBase):
325 """Selects points that have S/N > threshold in the given flux type."""
327 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
328 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
329 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
330 uncertaintySuffix = Field[str](
331 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
332 )
333 bands = ListField[str](
334 doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
335 default=[],
336 )
338 def getInputSchema(self) -> KeyedDataSchema:
339 fluxCol = self.fluxType
340 fluxInd = fluxCol.find("lux") + len("lux")
341 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
342 yield fluxCol, Vector
343 yield errCol, Vector
345 def __call__(self, data: KeyedData, **kwargs) -> Vector:
346 """Makes a mask of objects that have S/N greater than
347 self.threshold in self.fluxType
349 Parameters
350 ----------
351 data : `KeyedData`
352 The data to perform the selection on.
354 Returns
355 -------
356 result : `Vector`
357 A mask of the objects that satisfy the given
358 S/N cut.
359 """
360 mask: Vector | None = None
361 bands: tuple[str, ...]
362 match kwargs:
363 case {"band": band} if not self.bands and self.bands == []:
364 bands = (band,)
365 case {"bands": bands} if not self.bands and self.bands == []:
366 bands = bands
367 case _ if self.bands:
368 bands = tuple(self.bands)
369 case _:
370 bands = ("",)
371 bandStr = ",".join(bands)
372 for band in bands:
373 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
374 fluxInd = fluxCol.find("lux") + len("lux")
375 errCol = (
376 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
377 )
378 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
379 temp = (vec > self.threshold) & (vec < self.maxSN)
380 if mask is not None:
381 mask &= temp # type: ignore
382 else:
383 mask = temp
385 plotLabelStr = f"({bandStr}) > {self.threshold:.1f}"
386 if self.maxSN < 1e5:
387 plotLabelStr += f" & < {self.maxSN:.1f}"
389 if self.plotLabelKey == "" or self.plotLabelKey is None:
390 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs)
391 else:
392 self._addValueToPlotInfo(plotLabelStr, **kwargs)
394 # It should not be possible for mask to be a None now
395 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    By default a row is kept when ``sky_object`` is set and neither the
    per-band ``pixelFlags_edge`` nor ``pixelFlags_nodata`` flag is set.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Logical AND of the flag selection over every band in use.

        Configured ``bands`` win. Otherwise ``band`` or ``bands`` from
        ``kwargs`` is used, and failing that a single empty band name.
        """
        callerMayChoose = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if callerMayChoose and "band" in kwargs:
            bands = (kwargs["band"],)
        elif callerMayChoose and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Vector | None = None
        for bandName in bands:
            perBand = super().__call__(data, **{**kwargs, "band": bandName})
            combined = perBand if combined is None else combined & perBand
        return cast(Vector, combined)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
            "{band}_pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    By default a row is kept when ``sky_source`` is set and neither
    ``pixelFlags_edge`` nor ``pixelFlags_nodata`` is set.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags."""
        # A single mask is produced, so there is nothing to combine.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables.

    By default a row is kept only when none of the bad, saturated-center,
    interpolated-center, edge or no-data pixel flags is set.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags."""
        # A single mask is produced, so there is nothing to combine.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "pixelFlags_bad",
            "pixelFlags_saturatedCenter",
            "pixelFlags_interpolatedCenter",
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
class ExtendednessSelector(SelectorBase):
    """A selector that picks between extended and point sources.

    This base class returns the raw extendedness column, with its key
    formatted from the call keyword arguments. Subclasses turn it into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    A row is selected when ``0 <= extendedness < extendedness_maximum``.
    Negative and NaN extendedness values are rejected.
    """

    # The comparison below is strict (<), so the maximum is exclusive; the
    # field doc previously claimed it was inclusive.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, exclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies (resolved sources) using their
    extendedness: rows with extendedness strictly above
    ``extendedness_minimum`` are selected.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.

    Rows whose extendedness equals 9 are selected.
    NOTE(review): 9 is presumably a sentinel for "unclassified" in the
    input catalogs; confirm against the catalog producer.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown
class FiniteSelector(VectorAction):
    """Return a mask of finite values for a vector key.

    NaN and +/-inf entries are False. The key is formatted from the call
    keyword arguments.
    """

    vectorKey = Field[str](doc="Key to make a mask of finite values for.")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        column = data[self.vectorKey.format(**kwargs)]
        isFinite = np.isfinite(column)
        return cast(Vector, isFinite)
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector. The key is formatted from the call keyword arguments.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])
class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold.

    The mask is ``getattr(operator, op)(data[vectorKey], threshold)``, so
    ``op`` should name a comparison function of the `operator` module
    (e.g. ``"lt"``, ``"le"``, ``"eq"``, ``"ne"``, ``"ge"``, ``"gt"``).
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply ``op`` between the column and ``threshold``.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.
        **kwargs
            Values substituted into ``vectorKey`` (e.g. ``band``), matching
            the key handling of the other selectors in this module.

        Returns
        -------
        result : `Vector`
            The result of the comparison.

        Raises
        ------
        AttributeError
            Raised if ``op`` does not name an attribute of `operator`.
        """
        comparison = getattr(operator, self.op)
        values = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, comparison(values, self.threshold))
class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands.

    Configured ``bands`` take precedence. Otherwise ``band`` or ``bands``
    from the call keyword arguments is used. If no bands are available
    (the empty-list default and nothing passed), every row is selected.
    """

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        callerMayChoose = not self.bands and self.bands == []
        wanted: tuple[str, ...] | None
        if callerMayChoose and "band" in kwargs:
            wanted = (kwargs["band"],)
        elif callerMayChoose and "bands" in kwargs:
            wanted = kwargs["bands"]
        elif self.bands:
            wanted = tuple(self.bands)
        else:
            wanted = None

        if not wanted:
            # No band selection is applied, i.e., select all rows
            return cast(Vector, np.full(len(data[self.vectorKey]), True))  # type: ignore
        return cast(Vector, np.isin(data[self.vectorKey], wanted).flatten())
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # FlagSelector treats 0 as False, so requiring both columns to be 0
        # keeps objects with no parentObjectId (i.e. parents, which excludes
        # sub-parents) and drops sky objects.
        self.selectWhenFalse = ["sky_object", "parentObjectId"]
class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents.

    A row is selected when its ``vectorKey`` value (by default
    ``parentSourceId``) is strictly positive. The ``minimum`` and
    ``maximum`` fields inherited from `RangeSelector` are not used.
    """

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose ``vectorKey`` value is greater than 0.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.

        Returns
        -------
        result : `Vector`
            True where the row's ``vectorKey`` value is positive.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = values > 0

        return cast(Vector, mask)
653class MagSelector(SelectorBase):
654 """Selects points that have minMag < mag (AB) < maxMag.
656 The magnitude is based on the given fluxType.
657 """
659 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux")
660 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6)
661 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6)
662 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy")
663 returnMillimags = Field[bool](doc="Use millimags or not?", default=False)
664 bands = ListField[str](
665 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.",
666 default=[],
667 )
669 def getInputSchema(self) -> KeyedDataSchema:
670 fluxCol = self.fluxType
671 yield fluxCol, Vector
673 def __call__(self, data: KeyedData, **kwargs) -> Vector:
674 """Make a mask of that satisfies self.minMag < mag < self.maxMag.
676 The magnitude is based on the flux in self.fluxType.
678 Parameters
679 ----------
680 data : `KeyedData`
681 The data to perform the magnitude selection on.
683 Returns
684 -------
685 result : `Vector`
686 A mask of the objects that satisfy the given magnitude cut.
687 """
688 mask: Vector | None = None
689 bands: tuple[str, ...]
690 match kwargs:
691 case {"band": band} if not self.bands and self.bands == []:
692 bands = (band,)
693 case {"bands": bands} if not self.bands and self.bands == []:
694 bands = bands
695 case _ if self.bands:
696 bands = tuple(self.bands)
697 case _:
698 bands = ("",)
699 bandStr = ",".join(bands)
700 for band in bands:
701 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
702 vec = fluxToMag(
703 cast(Vector, data[fluxCol]),
704 flux_unit=self.fluxUnit,
705 return_millimags=self.returnMillimags,
706 )
707 temp = (vec > self.minMag) & (vec < self.maxMag)
708 if mask is not None:
709 mask &= temp # type: ignore
710 else:
711 mask = temp
713 plotLabelStr = ""
714 if self.maxMag < 100:
715 plotLabelStr += f"({bandStr}) < {self.maxMag:.1f}"
716 if self.minMag > -100:
717 if bandStr in plotLabelStr:
718 plotLabelStr += f" & < {self.minMag:.1f}"
719 else:
720 plotLabelStr += f"({bandStr}) < {self.minMag:.1f}"
721 if self.plotLabelKey == "" or self.plotLabelKey is None:
722 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs)
723 else:
724 self._addValueToPlotInfo(plotLabelStr, **kwargs)
726 # It should not be possible for mask to be a None now
727 return np.array(cast(Vector, mask))
class InjectedObjectSelector(SelectorBase):
    """A selector for injected objects.

    Rows whose ``vectorKey`` column (formatted from the call keyword
    arguments) equals 1 are selected.
    """

    vectorKey = Field[str](doc="Key to select from data", default="ref_injected_isPrimary")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        isInjected = data[self.vectorKey.format(**kwargs)] == 1
        return cast(Vector, isInjected)
class InjectedClassSelector(InjectedObjectSelector):
    """A selector for injected objects of a given class.

    On top of `InjectedObjectSelector`, a row must have a False per-band
    injection flag (when ``key_injection_flag`` is set). Its ``key_class``
    value must also equal ``value_compare``, or differ from it when
    ``value_is_equal`` is False.
    """

    key_class = Field[str](
        doc="Key for the field indicating the class of the object",
        default="ref_source_type",
    )
    key_injection_flag = Field[str](
        doc="Key for the field indicating that the object was not injected (per band)",
        default="ref_{band}_injection_flag",
    )
    name_class = Field[str](
        doc="Name of the class of objects",
    )
    value_compare = Field[str](
        doc="Value of the key_class field to compare against (e.g. DeltaFunction for stars)",
        default="DeltaFunction",
    )
    value_is_equal = Field[bool](
        doc="Whether the value must equal value_compare to be of this class",
        default=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of injected objects belonging to this class.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.
        **kwargs
            Values used to format ``vectorKey`` and ``key_injection_flag``
            (typically ``band``). If ``plotLabelKey`` is set, ``plotInfo``
            receives a label.

        Returns
        -------
        result : `Vector`
            The combined class selection mask.
        """
        result = super().__call__(data, **kwargs)
        if self.key_injection_flag:
            # The flag is True for objects that were NOT injected, so keep
            # rows where it is False. Formatting with all kwargs (rather than
            # indexing kwargs["band"]) only requires a band when the key
            # actually contains a {band} placeholder.
            flagKey = self.key_injection_flag.format(**kwargs)
            result &= data[flagKey] == False  # noqa: E712
        values = data[self.key_class]
        result &= (values == self.value_compare) if self.value_is_equal else (values != self.value_compare)
        if self.plotLabelKey:
            self._addValueToPlotInfo(f"injected {self.name_class}", **kwargs)
        return result

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()
        yield self.key_class, Vector
        if self.key_injection_flag:
            yield self.key_injection_flag, Vector
class InjectedGalaxySelector(InjectedClassSelector):
    """A selector for injected galaxies.

    Any injected object whose class is not ``value_compare`` (by default
    ``DeltaFunction``) is treated as a galaxy.
    """

    def setDefaults(self):
        # Assumes not star == galaxy - if there are injected AGN or other
        # object classes, this will need to be updated
        self.value_is_equal = False
        self.name_class = "galaxy"
class InjectedStarSelector(InjectedClassSelector):
    """A selector for injected stars, i.e. injected objects whose class
    equals ``value_compare`` (by default ``DeltaFunction``).
    """

    def setDefaults(self):
        # Only the label changes; the class comparison keeps its defaults.
        self.name_class = "star"
class MatchedObjectSelector(RangeSelector):
    """A selector that selects matched objects with finite distances."""

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "match_distance"
        # Keep 0 <= distance; the inherited exclusive maximum of inf rejects
        # infinite distances, and NaN fails both comparisons.
        self.minimum = 0
class ReferenceGalaxySelector(ThresholdSelector):
    """A selector that selects galaxies from a catalog with a
    boolean column identifying unresolved sources.

    Rows where ``refcat_is_pointsource == 0`` are selected.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selection = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return selection
        self._addValueToPlotInfo("reference galaxies", **kwargs)
        return selection

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 0
        self.plotLabelKey = "Selection: Galaxies"
class ReferenceObjectSelector(RangeSelector):
    """A selector that selects all objects from a catalog with a
    boolean column identifying unresolved sources.

    Every row whose ``refcat_is_pointsource`` value is >= 0 is kept. This
    rejects negative and NaN values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selection = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return selection
        self._addValueToPlotInfo("reference objects", **kwargs)
        return selection

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.minimum = 0
class ReferenceStarSelector(ThresholdSelector):
    """A selector that selects stars from a catalog with a
    boolean column identifying unresolved sources.

    Rows where ``refcat_is_pointsource == 1`` are selected.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selection = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return selection
        self._addValueToPlotInfo("reference stars", **kwargs)
        return selection

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 1
        self.plotLabelKey = "Selection: Stars"