Coverage for python / lsst / analysis / tools / actions / vector / selectors.py: 29%
394 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:27 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:27 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public API of this module. Every non-underscore selector class defined
# below is listed here so that star-imports and generated API docs see it.
__all__ = (
    "SelectorBase",
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SetSelector",
    "PatchSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "FiniteSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "InjectedClassSelector",
    "InjectedGalaxySelector",
    "InjectedObjectSelector",
    "InjectedStarSelector",
    "MatchedObjectSelector",
    "ParentObjectSelector",
    "ChildObjectSelector",
    "ReferenceGalaxySelector",
    "ReferenceObjectSelector",
    "ReferenceStarSelector",
)
54import operator
55from typing import cast
57import numpy as np
59from lsst.pex.config import Field
60from lsst.pex.config.listField import ListField
62from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
63from ...math import divide, fluxToMag
class SelectorBase(VectorAction):
    """Common base class for the selector actions in this module.

    Adds a ``plotLabelKey`` config field and a helper that records a
    description of the selection in an optional ``plotInfo`` mapping.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` when that key is present.

        Parameters
        ----------
        value
            The value to record, typically a label string.
        plotLabelKey : `str`, optional
            Key to store under. When `None`, ``self.plotLabelKey`` is used.
        **kwargs
            Call keyword arguments; only ``plotInfo`` is consulted.

        Raises
        ------
        RuntimeError
            Raised if ``plotInfo`` was supplied but no key is available from
            either ``plotLabelKey`` or ``self.plotLabelKey``.
        """
        if "plotInfo" not in kwargs:
            # Nothing to annotate; silently skip.
            return
        if plotLabelKey is None:
            if not self.plotLabelKey:
                raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
            plotLabelKey = self.plotLabelKey
        kwargs["plotInfo"][plotLabelKey] = value
class FlagSelector(SelectorBase):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # "False" flags are reported first, then "True" flags; keys may still
        # contain format templates such as ``{band}``.
        for key in self.selectWhenFalse:
            yield key, Vector
        for key in self.selectWhenTrue:
            yield key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows passing every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Table containing the flag columns.
        **kwargs
            Values used to fill templates (e.g. ``{band}``) in flag names.

        Returns
        -------
        result : `Vector`
            Boolean mask, True where every ``selectWhenFalse`` column equals
            0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if neither flag list has any entries.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair each flag template with the value it must equal.
        requirements = [(name, 0) for name in self.selectWhenFalse]
        requirements += [(name, 1) for name in self.selectWhenTrue]

        combined: Vector | None = None
        for template, required in requirements:
            passing = np.array(data[template.format(**kwargs)] == required)
            if combined is None:
                combined = passing
            else:
                combined &= passing  # type: ignore
        # At least one requirement exists (checked above), so combined is set.
        return cast(Vector, combined)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots, evaluated per band.

    When no ``bands`` are configured, the band(s) come from the ``band`` or
    ``bands`` keyword argument of the call. Configured ``bands`` are used in
    preference to those keyword arguments. The per-band masks are ANDed.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        configuredBands = self.bands
        for key, dtype in super().getInputSchema():
            if configuredBands and "{band}" in key:
                # Expand band-templated keys into one entry per band.
                for band in configuredBands:
                    yield key.format(band=band), dtype
            else:
                yield key, dtype

    def refMatchContext(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "coord_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Call-supplied bands are only honoured when none are configured.
        useCallBands = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information: run once with an empty band string.
            bands = ("",)

        mask: Vector | None = None
        for band in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=band)))
            if mask is None:
                mask = bandMask
            else:
                mask &= bandMask  # type: ignore
        return cast(Vector, mask)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "coord_flag",
            "sky_object",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
class MatchingFlagSelector(CoaddPlotFlagSelector):
    """The default flag selector to apply pre matching.

    The sources are cut down to remove duplicates (only ``detect_isPrimary``
    is required) but not on quality.
    """

    def setDefaults(self):
        # Replace, rather than extend, the parent's quality-flag lists.
        self.selectWhenTrue = ["detect_isPrimary"]
        self.selectWhenFalse = []
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read anywhere in this class, so
    # setting it currently has no effect on the keys used. Confirm whether it
    # is meant to be applied before relying on it.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask of rows passing all configured visit-level flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Table containing the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Boolean mask from a single `FlagSelector` evaluation.
        """
        # The default flag names have no band template, so one pass is enough
        # and there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]
class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    maximum = Field[float](doc="The maximum value (exclusive)", default=np.inf)
    # The default minimum is the most negative finite float, so -inf values
    # are rejected even when no minimum is configured.
    minimum = Field[float](doc="The minimum value (inclusive)", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``minimum <= value < maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Table containing ``vectorKey``.

        Returns
        -------
        result : `Vector`
            Boolean mask; NaN values compare False and are never selected.
        """
        column = cast(Vector, data[self.vectorKey])
        atOrAboveMin = column >= self.minimum
        belowMax = column < self.maximum
        return cast(Vector, atOrAboveMin & belowMax)
class SetSelector(SelectorBase):
    """Selects rows with any number of column values within a given set.

    For example, given a set of patches (1, 2, 3), and a set of columns
    (index_1, index_2), return all rows with either index_1 or index_2
    in the set (1, 2, 3).

    Notes
    -----
    The values are given as floats for flexibility. Integers above
    the floating point limit (2^53 + 1 = 9,007,199,254,740,993 for 64 bits)
    will not compare exactly with their float representations.
    """

    # Both lists must be non-empty and free of duplicates.
    vectorKeys = ListField[str](
        doc="Keys to select from data",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )
    values = ListField[float](
        doc="The set of acceptable values",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        for key in self.vectorKeys:
            yield key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows where any listed column holds a listed value.

        Parameters
        ----------
        data : `KeyedData`
            Table containing every column in ``vectorKeys``.

        Returns
        -------
        result : `Vector`
            Boolean mask, the OR over all (column, value) equality tests.
        """
        columns = [cast(Vector, data[key]) for key in self.vectorKeys]
        # The mask is shaped like the first column.
        mask = np.zeros_like(columns[0], dtype=bool)
        for column in columns:
            for candidate in self.values:
                mask |= column == candidate
        return cast(Vector, mask)
class PatchSelector(SetSelector):
    """Select rows within a set of patches.

    The ``patch`` column is compared against the configured ``values``.
    """

    def setDefaults(self):
        super().setDefaults()
        # Only the patch column is inspected by default.
        self.vectorKeys = ["patch"]
331class SnSelector(SelectorBase):
332 """Selects points that have S/N > threshold in the given flux type."""
334 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
335 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
336 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
337 uncertaintySuffix = Field[str](
338 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
339 )
340 bands = ListField[str](
341 doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
342 default=[],
343 )
345 def getInputSchema(self) -> KeyedDataSchema:
346 fluxCol = self.fluxType
347 fluxInd = fluxCol.find("lux") + len("lux")
348 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
349 yield fluxCol, Vector
350 yield errCol, Vector
352 def __call__(self, data: KeyedData, **kwargs) -> Vector:
353 """Makes a mask of objects that have S/N greater than
354 self.threshold in self.fluxType
356 Parameters
357 ----------
358 data : `KeyedData`
359 The data to perform the selection on.
361 Returns
362 -------
363 result : `Vector`
364 A mask of the objects that satisfy the given
365 S/N cut.
366 """
367 mask: Vector | None = None
368 bands: tuple[str, ...]
369 match kwargs:
370 case {"band": band} if not self.bands and self.bands == []:
371 bands = (band,)
372 case {"bands": bands} if not self.bands and self.bands == []:
373 bands = bands
374 case _ if self.bands:
375 bands = tuple(self.bands)
376 case _:
377 bands = ("",)
378 bandStr = ",".join(bands)
379 for band in bands:
380 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
381 fluxInd = fluxCol.find("lux") + len("lux")
382 errCol = (
383 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
384 )
385 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
386 temp = (vec > self.threshold) & (vec < self.maxSN)
387 if mask is not None:
388 mask &= temp # type: ignore
389 else:
390 mask = temp
392 plotLabelStr = f"({bandStr}) > {self.threshold:.1f}"
393 if self.maxSN < 1e5:
394 plotLabelStr += f" & < {self.maxSN:.1f}"
396 if self.plotLabelKey == "" or self.plotLabelKey is None:
397 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs)
398 else:
399 self._addValueToPlotInfo(plotLabelStr, **kwargs)
401 # It should not be possible for mask to be a None now
402 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The flag cuts are evaluated once per band and the per-band masks are
    ANDed together.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Call-supplied bands are only honoured when none are configured.
        useCallBands = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Vector | None = None
        for band in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
            "{band}_pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_object"]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask of sky sources passing the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Table containing the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Boolean mask from a single `FlagSelector` evaluation.
        """
        # A single flag evaluation is all that is needed; nothing to combine.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask of DIA sources passing all configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Table containing the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Boolean mask from a single `FlagSelector` evaluation.
        """
        # A single flag evaluation is all that is needed; nothing to combine.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "pixelFlags_bad",
            "pixelFlags_saturatedCenter",
            "pixelFlags_interpolatedCenter",
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
class ExtendednessSelector(SelectorBase):
    """A selector that picks between extended and point sources.

    On its own it returns the raw extendedness column; subclasses turn that
    column into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Fill templates such as {band} from the call's keyword arguments.
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    Selects rows with ``0 <= extendedness <= extendedness_maximum``. The
    upper bound is inclusive, as the field documents. With equal thresholds,
    this selection and `GalaxySelector`'s strict
    ``extendedness > extendedness_minimum`` therefore cover every
    non-negative value exactly once.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # Negative values are excluded. The maximum is inclusive (<=) to honour
        # the documented contract; previously a value exactly at the threshold
        # was neither a star nor a galaxy.
        isStar = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isStar))
class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.

    Rows with ``extendedness > extendedness_minimum`` (strict) are selected.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isGalaxy = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isGalaxy)
class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.

    Rows whose extendedness equals 9 are selected. This is presumably the
    sentinel for "unclassified"; confirm against the producing catalog.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        unclassifiedValue = 9
        return super().__call__(data, **kwargs) == unclassifiedValue
class FiniteSelector(VectorAction):
    """Return a mask of finite values for a vector key.

    NaN and +/-inf are marked False (via `numpy.isfinite`).
    """

    vectorKey = Field[str](doc="Key to make a mask of finite values for.")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        finite = np.isfinite(data[key])
        return cast(Vector, finite)
class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.

    The column is returned as-is; no conversion to bool is performed.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])
class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold.

    ``op`` names a function in the `operator` module (e.g. ``"lt"``,
    ``"eq"``), called as ``op(column, threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))
class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands.

    Configured ``bands`` are used in preference to the call's ``band`` or
    ``bands`` keyword arguments. If no bands are known at all, every row is
    selected.
    """

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        useCallBands = not self.bands and self.bands == []
        bands: tuple[str, ...] | None
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = None

        column = data[self.vectorKey]
        if not bands:
            # No band selection is applied, i.e., select all rows
            return cast(Vector, np.full(len(column), True))  # type: ignore
        return cast(Vector, np.isin(column, bands).flatten())
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # Both columns must equal 0 (FlagSelector treats "False" as == 0):
        # sky_object == 0 drops sky objects, and parentObjectId == 0 keeps
        # only parents, excluding subParents.
        self.selectWhenFalse = [
            "sky_object",
            "parentObjectId",
        ]
class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents.

    Rows are selected where ``vectorKey`` (``parentSourceId`` by default) is
    strictly positive. The ``minimum``/``maximum`` fields inherited from
    `RangeSelector` are not used by this selector.
    """

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose parent ID is positive.

        Parameters
        ----------
        data : `KeyedData`
            Table containing ``vectorKey``.

        Returns
        -------
        result : `Vector`
            Boolean mask, True where ``data[vectorKey] > 0``. A parent ID of
            0 presumably means "no parent"; confirm against the catalog
            schema.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = values > 0

        return cast(Vector, mask)
660class MagSelector(SelectorBase):
661 """Selects points that have minMag < mag (AB) < maxMag.
663 The magnitude is based on the given fluxType.
664 """
666 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux")
667 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6)
668 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6)
669 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy")
670 returnMillimags = Field[bool](doc="Use millimags or not?", default=False)
671 bands = ListField[str](
672 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.",
673 default=[],
674 )
676 def getInputSchema(self) -> KeyedDataSchema:
677 fluxCol = self.fluxType
678 yield fluxCol, Vector
680 def __call__(self, data: KeyedData, **kwargs) -> Vector:
681 """Make a mask of that satisfies self.minMag < mag < self.maxMag.
683 The magnitude is based on the flux in self.fluxType.
685 Parameters
686 ----------
687 data : `KeyedData`
688 The data to perform the magnitude selection on.
690 Returns
691 -------
692 result : `Vector`
693 A mask of the objects that satisfy the given magnitude cut.
694 """
695 mask: Vector | None = None
696 bands: tuple[str, ...]
697 match kwargs:
698 case {"band": band} if not self.bands and self.bands == []:
699 bands = (band,)
700 case {"bands": bands} if not self.bands and self.bands == []:
701 bands = bands
702 case _ if self.bands:
703 bands = tuple(self.bands)
704 case _:
705 bands = ("",)
706 bandStr = ",".join(bands)
707 for band in bands:
708 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
709 vec = fluxToMag(
710 cast(Vector, data[fluxCol]),
711 flux_unit=self.fluxUnit,
712 return_millimags=self.returnMillimags,
713 )
714 temp = (vec > self.minMag) & (vec < self.maxMag)
715 if mask is not None:
716 mask &= temp # type: ignore
717 else:
718 mask = temp
720 plotLabelStr = ""
721 if self.maxMag < 100:
722 plotLabelStr += f"({bandStr}) < {self.maxMag:.1f}"
723 if self.minMag > -100:
724 if bandStr in plotLabelStr:
725 plotLabelStr += f" & < {self.minMag:.1f}"
726 else:
727 plotLabelStr += f"({bandStr}) < {self.minMag:.1f}"
728 if self.plotLabelKey == "" or self.plotLabelKey is None:
729 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs)
730 else:
731 self._addValueToPlotInfo(plotLabelStr, **kwargs)
733 # It should not be possible for mask to be a None now
734 return np.array(cast(Vector, mask))
class InjectedObjectSelector(SelectorBase):
    """A selector for injected objects.

    Rows are selected where ``vectorKey`` equals 1.
    """

    vectorKey = Field[str](doc="Key to select from data", default="ref_injected_isPrimary")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column == 1)
class InjectedClassSelector(InjectedObjectSelector):
    """A selector for injected objects of a given class.

    On top of the parent's injected-object mask, requires that the per-band
    injection flag is False (when ``key_injection_flag`` is set) and that
    ``key_class`` equals, or differs from, ``value_compare``.
    """

    key_class = Field[str](
        doc="Key for the field indicating the class of the object",
        default="ref_source_type",
    )
    key_injection_flag = Field[str](
        doc="Key for the field indicating that the object was not injected (per band)",
        default="ref_{band}_injection_flag",
    )
    name_class = Field[str](
        doc="Name of the class of objects",
    )
    value_compare = Field[str](
        doc="Value of the type_key field for objects that are stars",
        default="DeltaFunction",
    )
    value_is_equal = Field[bool](
        doc="Whether the value must equal value_compare to be of this class",
        default=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()
        yield self.key_class, Vector
        if self.key_injection_flag:
            yield self.key_injection_flag, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selected = super().__call__(data, **kwargs)
        if self.key_injection_flag:
            # Requires a "band" keyword argument to resolve the flag column.
            flagKey = self.key_injection_flag.format(band=kwargs["band"])
            selected &= data[flagKey] == False  # noqa: E712
        classValues = data[self.key_class]
        if self.value_is_equal:
            selected &= classValues == self.value_compare
        else:
            selected &= classValues != self.value_compare
        if self.plotLabelKey:
            self._addValueToPlotInfo(f"injected {self.name_class}", **kwargs)
        return selected
class InjectedGalaxySelector(InjectedClassSelector):
    """A selector for injected galaxies (objects whose class is not the star
    class value).
    """

    def setDefaults(self):
        # Assumes not star == galaxy - if there are injected AGN or other
        # object classes, this will need to be updated
        self.value_is_equal = False
        self.name_class = "galaxy"
class InjectedStarSelector(InjectedClassSelector):
    """A selector for injected stars (class value equal to ``value_compare``)."""

    def setDefaults(self):
        # value_compare / value_is_equal keep the parent's field defaults.
        self.name_class = "star"
class MatchedObjectSelector(RangeSelector):
    """A selector that selects matched objects with finite distances."""

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "match_distance"
        # Lower bound 0 (inclusive). The maximum stays at RangeSelector's
        # default of +inf (exclusive), so NaN and infinite distances fail.
        self.minimum = 0
class ReferenceGalaxySelector(ThresholdSelector):
    """A selector that selects galaxies from a catalog with a
    boolean column identifying unresolved sources.

    Galaxies are rows where ``refcat_is_pointsource == 0``.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        mask = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return mask
        self._addValueToPlotInfo("reference galaxies", **kwargs)
        return mask

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 0
        self.plotLabelKey = "Selection: Galaxies"
class ReferenceObjectSelector(RangeSelector):
    """A selector that selects all objects from a catalog with a
    boolean column identifying unresolved sources.

    Any row with ``0 <= refcat_is_pointsource < inf`` is kept.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        mask = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return mask
        self._addValueToPlotInfo("reference objects", **kwargs)
        return mask

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.minimum = 0
class ReferenceStarSelector(ThresholdSelector):
    """A selector that selects stars from a catalog with a
    boolean column identifying unresolved sources.

    Stars are rows where ``refcat_is_pointsource == 1``.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        mask = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return mask
        self._addValueToPlotInfo("reference stars", **kwargs)
        return mask

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 1
        self.plotLabelKey = "Selection: Stars"