Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%
228 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 13:17 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 13:17 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    # ParentObjectSelector is a public selector defined in this module, so
    # it belongs in the export list alongside the others.
    "ParentObjectSelector",
)
41import operator
42from typing import Optional, cast
44import numpy as np
45from lsst.pex.config import Field
46from lsst.pex.config.listField import ListField
48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction
class SelectorBase(VectorAction):
    """Base class for selectors that can record a configuration value in
    the ``plotInfo`` mapping passed to ``__call__``.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Does nothing when ``plotLabelKey`` is empty (or unset), or when no
        ``plotInfo`` keyword argument was supplied.
        """
        if not self.plotLabelKey or "plotInfo" not in kwargs:
            return
        plotInfo = kwargs["plotInfo"]
        plotInfo[self.plotLabelKey] = value
class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Combine the configured flag cuts into one boolean mask.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the flag columns.
        **kwargs
            Values used to format each flag column name, e.g. ``band`` for
            names containing ``{band}``.

        Returns
        -------
        result : `Vector`
            A mask that is True where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue``
            names any column.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each flag-name template with the value it must have to pass;
        # the "False" flags are applied before the "True" flags.
        requirements = [(name, 0) for name in self.selectWhenFalse]  # type: ignore
        requirements += [(name, 1) for name in self.selectWhenTrue]

        combined: Optional[Vector] = None
        for template, required in requirements:
            passes = np.array(data[template.format(**kwargs)] == required)
            if combined is None:
                combined = passes
            else:
                combined &= passes  # type: ignore
        # The guard above ensures at least one requirement, so combined is set.
        return cast(Vector, combined)
class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots.

    Flag column names may contain ``{band}``. The cut is evaluated once per
    band and the per-band masks are combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Use the ``_target``-suffixed versions of the default flag columns."""
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag mask over every resolved band.

        Bands come from the ``band``/``bands`` keyword arguments only when
        the ``bands`` config field is empty. Otherwise the config field is
        used. If neither is available, a single empty band name is used.
        """
        configBandsUnset = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for bandName in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": bandName})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): ``catalogSuffix`` is not read by this class or by
    # FlagSelector, so setting it currently has no effect on the columns
    # used. Confirm whether it should be applied or deprecated.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Use the ``_target``-suffixed versions of the default flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the flag mask computed by `FlagSelector.__call__`.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the flag columns.
        **kwargs
            Values used to format the flag column names.

        Returns
        -------
        result : `Vector`
            True where every configured flag cut is satisfied.
        """
        # Only a single mask is produced, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]
class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # Use ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0, and these
    # defaults are evaluated at import time.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the most negative finite float, so rows equal to
    # -inf are excluded unless a minimum is configured explicitly.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the ``vectorKey`` column.

        Returns
        -------
        result : `Vector`
            A mask that is True where ``minimum <= value < maximum``. NaN
            values fail both comparisons and are never selected.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)
234class SnSelector(SelectorBase):
235 """Selects points that have S/N > threshold in the given flux type."""
237 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
238 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
239 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
240 uncertaintySuffix = Field[str](
241 doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
242 )
243 bands = ListField[str](
244 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call",
245 default=[],
246 )
248 def getInputSchema(self) -> KeyedDataSchema:
249 fluxCol = self.fluxType
250 fluxInd = fluxCol.find("lux") + len("lux")
251 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:]
252 yield fluxCol, Vector
253 yield errCol, Vector
255 def __call__(self, data: KeyedData, **kwargs) -> Vector:
256 """Makes a mask of objects that have S/N greater than
257 self.threshold in self.fluxType
259 Parameters
260 ----------
261 data : `KeyedData`
262 The data to perform the selection on.
264 Returns
265 -------
266 result : `Vector`
267 A mask of the objects that satisfy the given
268 S/N cut.
269 """
271 self._addValueToPlotInfo(self.threshold, **kwargs)
272 mask: Optional[Vector] = None
273 bands: tuple[str, ...]
274 match kwargs:
275 case {"band": band} if not self.bands and self.bands == []:
276 bands = (band,)
277 case {"bands": bands} if not self.bands and self.bands == []:
278 bands = bands
279 case _ if self.bands:
280 bands = tuple(self.bands)
281 case _:
282 bands = ("",)
283 for band in bands:
284 fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
285 fluxInd = fluxCol.find("lux") + len("lux")
286 errCol = (
287 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:]
288 )
289 vec = cast(Vector, data[fluxCol]) / data[errCol]
290 temp = (vec > self.threshold) & (vec < self.maxSN)
291 if mask is not None:
292 mask &= temp # type: ignore
293 else:
294 mask = temp
296 # It should not be possible for mask to be a None now
297 return np.array(cast(Vector, mask))
class SkyObjectSelector(FlagSelector):
    """Select sky objects, evaluating the flag cut in each resolved band.

    The per-band masks are combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag mask over every resolved band.

        The ``band``/``bands`` keyword arguments are used only when the
        ``bands`` config field is empty. Otherwise the config field is used.
        If neither is available, a single empty band name is used.
        """
        configBandsUnset = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for bandName in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": bandName})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the flag mask computed by `FlagSelector.__call__`.

        With the default configuration this is True where ``sky_source``
        equals 1 and ``pixelFlags_edge`` equals 0.
        """
        # Only a single mask is produced, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]
class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the flag mask computed by `FlagSelector.__call__`.

        With the default configuration this is True where every listed
        ``base_PixelFlags_flag_*`` column equals 0.
        """
        # Only a single mask is produced, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]
class ExtendednessSelector(VectorAction):
    """Read the extendedness column that star/galaxy style selectors act on.

    On its own this returns the raw column values rather than a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return ``data`` at ``vectorKey`` after formatting it with ``kwargs``."""
        return cast(Vector, data[self.vectorKey.format(**kwargs)])
class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    A row is selected when ``0 <= extendedness < extendedness_maximum``.
    """

    # The comparison below uses ``<``, so the maximum is exclusive; the doc
    # string now states that instead of claiming it is inclusive.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, exclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return True where extendedness is in ``[0, extendedness_maximum)``."""
        extendedness = super().__call__(data, **kwargs)
        # The lower bound rejects negative values; NaN fails both comparisons
        # and is therefore never selected.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))
class GalaxySelector(ExtendednessSelector):
    """Select resolved objects (galaxies): rows whose extendedness is
    strictly greater than ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return True where extendedness exceeds ``extendedness_minimum``."""
        extendednessValues = super().__call__(data, **kwargs)
        isResolved = extendednessValues > self.extendedness_minimum
        return cast(Vector, isResolved)  # type: ignore
class UnknownSelector(ExtendednessSelector):
    """Select unclassified objects, i.e. rows whose extendedness equals 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return True where the extendedness column equals 9.

        9 is the value this selector treats as "unclassified".
        """
        values = super().__call__(data, **kwargs)
        unclassified = values == 9
        return unclassified
class VectorSelector(VectorAction):
    """Use an existing boolean column of the input data directly as the
    selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column at ``vectorKey`` formatted with ``kwargs``."""
        key = self.vectorKey.format(**kwargs)
        selection = data[key]
        return cast(Vector, selection)
class ThresholdSelector(VectorAction):
    """Compare a column against a fixed threshold using a named operator."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return ``operator.<op>(data[vectorKey], threshold)``.

        ``op`` is looked up by name on the standard-library ``operator``
        module, e.g. ``"lt"`` or ``"ge"``; an unknown name raises
        `AttributeError`.
        """
        compare = getattr(operator, self.op)
        values = data[self.vectorKey]
        return cast(Vector, compare(values, self.threshold))
468class BandSelector(VectorAction):
469 """Makes a mask for sources observed in a specified set of bands."""
471 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
472 bands = ListField[str](
473 doc="The bands to select. `None` indicates no band selection applied.",
474 default=[],
475 )
477 def getInputSchema(self) -> KeyedDataSchema:
478 return ((self.vectorKey, Vector),)
480 def __call__(self, data: KeyedData, **kwargs) -> Vector:
481 bands: Optional[tuple[str, ...]]
482 match kwargs:
483 case {"band": band} if not self.bands and self.bands == []:
484 bands = (band,)
485 case {"bands": bands} if not self.bands and self.bands == []:
486 bands = bands
487 case _ if self.bands:
488 bands = tuple(self.bands)
489 case _:
490 bands = None
491 if bands:
492 mask = np.in1d(data[self.vectorKey], bands)
493 else:
494 # No band selection is applied, i.e., select all rows
495 mask = np.full(len(data[self.vectorKey]), True) # type: ignore
496 return cast(Vector, mask)
class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects.

    With the defaults, a row passes when ``detect_isPatchInner`` is True
    and both ``detect_isDeblendedModelSource`` and ``sky_object`` are False.
    """

    def setDefaults(self):
        # Keep only rows flagged detect_isPatchInner.
        self.selectWhenTrue = ["detect_isPatchInner"]
        # Rejecting deblended model sources leaves the parents; rejecting
        # sky_object removes sky objects.
        self.selectWhenFalse = [
            "detect_isDeblendedModelSource",
            "sky_object",
        ]