Coverage for python / lsst / analysis / tools / actions / vector / selectors.py: 29%

394 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-25 08:55 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public selector actions exported by this module, in their historical order.
# NOTE(review): PatchSelector, ParentObjectSelector and ChildObjectSelector are
# defined in this module but are not listed here; confirm whether that is
# intentional before relying on ``from ... import *``.
__all__ = (
    "SelectorBase",
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SetSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "FiniteSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "InjectedClassSelector",
    "InjectedGalaxySelector",
    "InjectedObjectSelector",
    "InjectedStarSelector",
    "MatchedObjectSelector",
    "ReferenceGalaxySelector",
    "ReferenceObjectSelector",
    "ReferenceStarSelector",
)

53 

54import operator 

55from typing import cast 

56 

57import numpy as np 

58 

59from lsst.pex.config import Field 

60from lsst.pex.config.listField import ListField 

61 

62from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

63from ...math import divide, fluxToMag 

64 

65 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a label in ``plotInfo``."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` when a plotInfo is given.

        Parameters
        ----------
        value
            The value to store.
        plotLabelKey : `str`, optional
            Explicit key to store under; when `None` the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Call keyword arguments; nothing happens unless they contain
            ``plotInfo``.

        Raises
        ------
        RuntimeError
            If ``plotInfo`` is present but no key is available from either
            the argument or the configuration.
        """
        if "plotInfo" not in kwargs:
            return
        if plotLabelKey is not None:
            targetKey = plotLabelKey
        elif self.plotLabelKey:
            targetKey = self.plotLabelKey
        else:
            raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
        kwargs["plotInfo"][targetKey] = value

79 

80 

class FlagSelector(SelectorBase):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Columns required to be False come first, then those required True.
        for key in self.selectWhenFalse:
            yield key, Vector
        for key in self.selectWhenTrue:
            yield key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select rows satisfying every configured flag condition.

        Parameters
        ----------
        data : `KeyedData`
            Input columns; each flag name is formatted with ``kwargs`` before
            lookup, so names may contain placeholders such as ``{band}``.

        Returns
        -------
        result : `Vector`
            Boolean mask that is True where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            If neither ``selectWhenFalse`` nor ``selectWhenTrue`` is set.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair each flag template with the value it must equal to be kept.
        conditions = [(flag, 0) for flag in self.selectWhenFalse]  # type: ignore
        conditions += [(flag, 1) for flag in self.selectWhenTrue]

        combined: Vector | None = None
        for flag, wanted in conditions:
            flagMask = np.array(data[flag.format(**kwargs)] == wanted)
            if combined is None:
                combined = flagMask
            else:
                combined &= flagMask  # type: ignore
        # The guard above ensures at least one condition, so never None here.
        return cast(Vector, combined)

135 

136 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots.

    Flag names may contain a ``{band}`` placeholder. When ``bands`` is not
    configured, the band(s) are taken from the ``band`` or ``bands`` keyword
    argument of the call; the per-band masks are combined with logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Without configured bands the keys are reported unformatted.
        if not self.bands:
            yield from super().getInputSchema()
            return
        for key, dtype in super().getInputSchema():
            if "{band}" not in key:
                yield key, dtype
                continue
            for band in self.bands:
                yield key.format(band=band), dtype

    def refMatchContext(self):
        """Replace the flag columns with their ``_target``-suffixed names.

        Unlike `setDefaults`, ``sky_object`` is not included.
        """
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "coord_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every resolved band and AND the results.

        Band resolution order: ``band`` kwarg, then ``bands`` kwarg (both
        only when no bands are configured), then the configured ``bands``,
        and finally a single empty band ``""``.
        """
        bands: tuple[str, ...]
        if "band" in kwargs and not self.bands and self.bands == []:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands and self.bands == []:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Vector | None = None
        for band in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "coord_flag",
            "sky_object",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

195 

196 

class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Flag selector applied before matching.

    Only ``detect_isPrimary`` is required, which removes duplicates without
    applying any quality cuts.
    """

    def setDefaults(self):
        # No quality flags: only the primary-detection requirement remains.
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]

206 

207 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read anywhere in this class; confirm
    # whether callers rely on it or it can be removed.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Replace the quality flags with their ``_target``-suffixed names."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # A single pass of the parent flag selector; nothing to combine.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]

244 

245 

class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    maximum = Field[float](doc="The maximum value (exclusive)", default=np.inf)
    minimum = Field[float](doc="The minimum value (inclusive)", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``minimum <= value < maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Input columns; ``vectorKey`` is looked up verbatim (it is not
            formatted with ``kwargs``).

        Returns
        -------
        result : `Vector`
            Boolean mask of the rows inside the half-open range.
        """
        column = cast(Vector, data[self.vectorKey])
        atOrAboveMinimum = column >= self.minimum
        belowMaximum = column < self.maximum
        return cast(Vector, atOrAboveMinimum & belowMaximum)

272 

273 

class SetSelector(SelectorBase):
    """Selects rows with any number of column values within a given set.

    For example, given a set of patches (1, 2, 3), and a set of columns
    (index_1, index_2), return all rows with either index_1 or index_2
    in the set (1, 2, 3).

    Notes
    -----
    The values are given as floats for flexibility. Integers above
    the floating point limit (2^53 + 1 = 9,007,199,254,740,993 for 64 bits)
    will not compare exactly with their float representations.
    """

    # Both lists must be non-empty and free of duplicates to validate.
    vectorKeys = ListField[str](
        doc="Keys to select from data",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )
    values = ListField[float](
        doc="The set of acceptable values",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from ((key, Vector) for key in self.vectorKeys)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows where any listed column holds an accepted value.

        Parameters
        ----------
        data : `KeyedData`
            Input columns, looked up by the names in ``vectorKeys``.

        Returns
        -------
        result : `Vector`
            Boolean mask, True where at least one column equals one of
            ``values``.
        """
        # Start all-False, shaped like the first column, and OR in matches.
        mask = np.zeros_like(data[self.vectorKeys[0]], dtype=bool)
        for key in self.vectorKeys:
            column = cast(Vector, data[key])
            for accepted in self.values:
                mask |= column == accepted
        return cast(Vector, mask)

321 

322 

class PatchSelector(SetSelector):
    """Select rows whose ``patch`` column is in the configured set."""

    def setDefaults(self):
        super().setDefaults()
        # Only the key is fixed here; the accepted patch values must be set.
        self.vectorKeys = ["patch"]

329 

330 

class SnSelector(SelectorBase):
    """Selects points that have S/N > threshold in the given flux type."""

    fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
    threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
    maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
    uncertaintySuffix = Field[str](
        doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
    )
    bands = ListField[str](
        doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
        default=[],
    )

    @staticmethod
    def _errorColumn(fluxCol: str, suffix: str) -> str:
        """Insert ``suffix`` immediately after the first ``"lux"`` in ``fluxCol``.

        NOTE(review): if ``"lux"`` is absent, ``find`` returns -1 and the
        suffix lands after the second character.
        """
        cut = fluxCol.find("lux") + len("lux")
        return f"{fluxCol}"[:cut] + f"{suffix}" + f"{fluxCol}"[cut:]

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.fluxType, Vector
        yield self._errorColumn(self.fluxType, self.uncertaintySuffix), Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Makes a mask of objects that have S/N greater than
        self.threshold in self.fluxType

        Parameters
        ----------
        data : `KeyedData`
            The data to perform the selection on.

        Returns
        -------
        result : `Vector`
            A mask of the objects with ``threshold < S/N < maxSN`` in every
            resolved band.
        """
        bands: tuple[str, ...]
        if "band" in kwargs and not self.bands and self.bands == []:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands and self.bands == []:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)
        # Joined before the loop so the label reflects the original bands.
        bandStr = ",".join(bands)

        mask: Vector | None = None
        for band in bands:
            fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
            errCol = self._errorColumn(fluxCol, self.uncertaintySuffix.format(**kwargs))
            signalToNoise = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
            inRange = (signalToNoise > self.threshold) & (signalToNoise < self.maxSN)
            if mask is None:
                mask = inRange
            else:
                mask &= inRange  # type: ignore

        plotLabelStr = f"({bandStr}) > {self.threshold:.1f}"
        # maxSN values of 1e5 or more are treated as "no upper limit".
        if self.maxSN < 1e5:
            plotLabelStr += f" & < {self.maxSN:.1f}"

        if self.plotLabelKey:
            self._addValueToPlotInfo(plotLabelStr, **kwargs)
        else:
            self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs)

        return np.array(cast(Vector, mask))

403 

404 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """AND together the flag selection evaluated in each resolved band."""
        bands: tuple[str, ...]
        if "band" in kwargs and not self.bands and self.bands == []:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands and self.bands == []:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Vector | None = None
        for band in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        super().setDefaults()
        # Keep sky objects away from edge / no-data pixels.
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
            "{band}_pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_object"]

443 

444 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Single evaluation of the parent flag selection.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_source"]

467 

468 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Single evaluation of the parent flag selection.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "pixelFlags_bad",
            "pixelFlags_saturatedCenter",
            "pixelFlags_interpolatedCenter",
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]

494 

495 

class ExtendednessSelector(SelectorBase):
    """A selector that picks between extended and point sources.

    On its own it returns the raw extendedness column; subclasses turn that
    into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # vectorKey may contain placeholders (e.g. {band}) filled from kwargs.
        return cast(Vector, data[self.vectorKey.format(**kwargs)])

509 

510 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness <= extendedness_maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Input columns containing the (formatted) extendedness key.

        Returns
        -------
        result : `Vector`
            Boolean mask of unresolved sources. NaN and negative
            extendedness values are never selected.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as the field documents. With equal
        # thresholds this pairs with GalaxySelector's strict
        # ``extendedness > extendedness_minimum``, so a value exactly at the
        # threshold is classified rather than falling into neither class.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))

523 

524 

class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Strictly greater than the minimum; NaN compares False and is dropped.
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)

537 

538 

class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # NOTE(review): 9 appears to be a sentinel for "unclassified"
        # extendedness; confirm against the producing pipeline.
        values = super().__call__(data, **kwargs)
        return values == 9

547 

548 

class FiniteSelector(VectorAction):
    """Return a mask of finite values for a vector key"""

    vectorKey = Field[str](doc="Key to make a mask of finite values for.")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # True where the (kwargs-formatted) column is neither NaN nor +/-inf.
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, np.isfinite(column))

559 

560 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # The column is returned as-is; no conversion to bool is done here.
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])

573 

574 

class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # ``op`` names a function in the stdlib ``operator`` module, e.g.
        # "eq", "lt", "ge"; it is applied as ``op(column, threshold)``.
        compare = getattr(operator, self.op)
        return cast(Vector, compare(data[self.vectorKey], self.threshold))

588 

589 

class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands."""

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return True for rows whose band column is in the resolved bands.

        When no bands can be resolved every row is selected.
        """
        bands: tuple[str, ...] | None
        if "band" in kwargs and not self.bands and self.bands == []:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands and self.bands == []:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = None

        if not bands:
            # No band selection is applied, i.e., select all rows
            mask = np.full(len(data[self.vectorKey]), True)  # type: ignore
        else:
            mask = np.isin(data[self.vectorKey], bands).flatten()
        return cast(Vector, mask)

619 

620 

class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # FlagSelector keeps rows where each selectWhenFalse column equals 0,
        # so requiring parentObjectId == 0 keeps only top-level parents
        # (sub-parents carry a non-zero parentObjectId).
        self.selectWhenFalse = [
            "sky_object",
            "parentObjectId",
        ]

632 

633 

class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents"""

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose parent id is strictly positive.

        The inherited ``minimum``/``maximum`` fields are not consulted.

        Parameters
        ----------
        data : `KeyedData`
            Input columns; ``vectorKey`` is looked up verbatim.

        Returns
        -------
        result : `Vector`
            Boolean mask, True where ``data[vectorKey] > 0``.
        """
        parentIds = cast(Vector, data[self.vectorKey])
        return cast(Vector, parentIds > 0)

658 

659 

660class MagSelector(SelectorBase): 

661 """Selects points that have minMag < mag (AB) < maxMag. 

662 

663 The magnitude is based on the given fluxType. 

664 """ 

665 

666 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux") 

667 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6) 

668 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6) 

669 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy") 

670 returnMillimags = Field[bool](doc="Use millimags or not?", default=False) 

671 bands = ListField[str]( 

672 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.", 

673 default=[], 

674 ) 

675 

676 def getInputSchema(self) -> KeyedDataSchema: 

677 fluxCol = self.fluxType 

678 yield fluxCol, Vector 

679 

680 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

681 """Make a mask of that satisfies self.minMag < mag < self.maxMag. 

682 

683 The magnitude is based on the flux in self.fluxType. 

684 

685 Parameters 

686 ---------- 

687 data : `KeyedData` 

688 The data to perform the magnitude selection on. 

689 

690 Returns 

691 ------- 

692 result : `Vector` 

693 A mask of the objects that satisfy the given magnitude cut. 

694 """ 

695 mask: Vector | None = None 

696 bands: tuple[str, ...] 

697 match kwargs: 

698 case {"band": band} if not self.bands and self.bands == []: 

699 bands = (band,) 

700 case {"bands": bands} if not self.bands and self.bands == []: 

701 bands = bands 

702 case _ if self.bands: 

703 bands = tuple(self.bands) 

704 case _: 

705 bands = ("",) 

706 bandStr = ",".join(bands) 

707 for band in bands: 

708 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

709 vec = fluxToMag( 

710 cast(Vector, data[fluxCol]), 

711 flux_unit=self.fluxUnit, 

712 return_millimags=self.returnMillimags, 

713 ) 

714 temp = (vec > self.minMag) & (vec < self.maxMag) 

715 if mask is not None: 

716 mask &= temp # type: ignore 

717 else: 

718 mask = temp 

719 

720 plotLabelStr = "" 

721 if self.maxMag < 100: 

722 plotLabelStr += f"({bandStr}) < {self.maxMag:.1f}" 

723 if self.minMag > -100: 

724 if bandStr in plotLabelStr: 

725 plotLabelStr += f" & < {self.minMag:.1f}" 

726 else: 

727 plotLabelStr += f"({bandStr}) < {self.minMag:.1f}" 

728 if self.plotLabelKey == "" or self.plotLabelKey is None: 

729 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs) 

730 else: 

731 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

732 

733 # It should not be possible for mask to be a None now 

734 return np.array(cast(Vector, mask)) 

735 

736 

class InjectedObjectSelector(SelectorBase):
    """A selector for injected objects."""

    vectorKey = Field[str](doc="Key to select from data", default="ref_injected_isPrimary")

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Rows whose (kwargs-formatted) indicator column equals 1.
        return cast(Vector, data[self.vectorKey.format(**kwargs)] == 1)

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

749 

750 

class InjectedClassSelector(InjectedObjectSelector):
    """A selector for injected objects of a given class."""

    key_class = Field[str](
        doc="Key for the field indicating the class of the object",
        default="ref_source_type",
    )
    key_injection_flag = Field[str](
        doc="Key for the field indicating that the object was not injected (per band)",
        default="ref_{band}_injection_flag",
    )
    name_class = Field[str](
        doc="Name of the class of objects",
    )
    value_compare = Field[str](
        doc="Value of the type_key field for objects that are stars",
        default="DeltaFunction",
    )
    value_is_equal = Field[bool](
        doc="Whether the value must equal value_compare to be of this class",
        default=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Injected objects, optionally unflagged in ``band``, of this class.

        Raises
        ------
        KeyError
            If ``key_injection_flag`` is set and no ``band`` keyword is given.
        """
        selected = super().__call__(data, **kwargs)
        if self.key_injection_flag:
            flagKey = self.key_injection_flag.format(band=kwargs["band"])
            selected &= data[flagKey] == False  # noqa: E712
        classValues = data[self.key_class]
        if self.value_is_equal:
            selected &= classValues == self.value_compare
        else:
            selected &= classValues != self.value_compare
        if self.plotLabelKey:
            self._addValueToPlotInfo(f"injected {self.name_class}", **kwargs)
        return selected

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()
        yield self.key_class, Vector
        # The flag key is reported with its {band} placeholder unformatted.
        if self.key_injection_flag:
            yield self.key_injection_flag, Vector

789 

790 

class InjectedGalaxySelector(InjectedClassSelector):
    """A selector for injected galaxies."""

    def setDefaults(self):
        self.name_class = "galaxy"
        # Galaxies are taken as "anything whose class differs from
        # value_compare" (i.e. not a star); if other injected classes such as
        # AGN appear, this needs revisiting.
        self.value_is_equal = False

799 

800 

class InjectedStarSelector(InjectedClassSelector):
    """A selector for injected stars.

    Relies on the inherited ``value_compare``/``value_is_equal`` defaults.
    """

    def setDefaults(self):
        self.name_class = "star"

806 

807 

class MatchedObjectSelector(RangeSelector):
    """A selector that selects matched objects with finite distances.

    Keeps rows with ``0 <= match_distance < inf``; NaN and infinite
    distances fail the range test.
    """

    def setDefaults(self):
        super().setDefaults()
        self.minimum = 0
        self.vectorKey = "match_distance"

815 

816 

class ReferenceGalaxySelector(ThresholdSelector):
    """A selector that selects galaxies from a catalog with a
    boolean column identifying unresolved sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        galaxies = super().__call__(data=data, **kwargs)
        if self.plotLabelKey:
            self._addValueToPlotInfo("reference galaxies", **kwargs)
        return galaxies

    def setDefaults(self):
        super().setDefaults()
        # Galaxies are the rows where the point-source column equals 0.
        self.op = "eq"
        self.threshold = 0
        self.plotLabelKey = "Selection: Galaxies"
        self.vectorKey = "refcat_is_pointsource"

834 

835 

class ReferenceObjectSelector(RangeSelector):
    """A selector that selects all objects from a catalog with a
    boolean column identifying unresolved sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selected = super().__call__(data=data, **kwargs)
        if self.plotLabelKey:
            self._addValueToPlotInfo("reference objects", **kwargs)
        return selected

    def setDefaults(self):
        super().setDefaults()
        # [0, inf) admits both 0 and 1 of the point-source column.
        self.minimum = 0
        self.vectorKey = "refcat_is_pointsource"

851 

852 

class ReferenceStarSelector(ThresholdSelector):
    """A selector that selects stars from a catalog with a
    boolean column identifying unresolved sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        stars = super().__call__(data=data, **kwargs)
        if self.plotLabelKey:
            self._addValueToPlotInfo("reference stars", **kwargs)
        return stars

    def setDefaults(self):
        super().setDefaults()
        # Stars are the rows where the point-source column equals 1.
        self.op = "eq"
        self.plotLabelKey = "Selection: Stars"
        self.threshold = 1
        self.vectorKey = "refcat_is_pointsource"