Coverage for python / lsst / analysis / tools / actions / vector / selectors.py: 29%

389 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 09:19 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public names exported by ``from ... import *``.
# NOTE(review): PatchSelector, ParentObjectSelector and ChildObjectSelector
# are defined in this module but are not listed here; confirm whether that
# omission is intentional.
__all__ = (
    "SelectorBase",
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SetSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "FiniteSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "InjectedClassSelector",
    "InjectedGalaxySelector",
    "InjectedObjectSelector",
    "InjectedStarSelector",
    "MatchedObjectSelector",
    "ReferenceGalaxySelector",
    "ReferenceObjectSelector",
    "ReferenceStarSelector",
)

53 

54import operator 

55from typing import cast 

56 

57import numpy as np 

58 

59from lsst.pex.config import Field 

60from lsst.pex.config.listField import ListField 

61 

62from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

63from ...math import divide, fluxToMag 

64 

65 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a label in ``plotInfo``."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` when a plotInfo is given.

        Parameters
        ----------
        value
            The value to record.
        plotLabelKey : `str`, optional
            Key to store the value under. When `None`, ``self.plotLabelKey``
            is used instead.
        **kwargs
            Call keyword arguments; only ``plotInfo`` is consulted.

        Raises
        ------
        RuntimeError
            Raised when ``plotInfo`` is present but no key is available from
            either ``plotLabelKey`` or ``self.plotLabelKey``.
        """
        # Nothing to do unless the caller handed us a plotInfo mapping.
        if "plotInfo" not in kwargs:
            return
        key = plotLabelKey
        if key is None:
            if not self.plotLabelKey:
                raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
            key = self.plotLabelKey
        kwargs["plotInfo"][key] = value

79 

80 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is selected only if every column in ``selectWhenFalse`` equals 0
    and every column in ``selectWhenTrue`` equals 1. Column names may
    contain ``{...}`` placeholders that are filled from the call keyword
    arguments.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Used to format the flag column name templates.

        Returns
        -------
        result : `Vector`
            ``True`` for rows where all "False" flags are 0 and all "True"
            flags are 1.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue`` has
            any entries.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair each flag template with the value that marks a row as good;
        # the "False" flags are checked first, then the "True" flags.
        requirements = [(flag, 0) for flag in self.selectWhenFalse] + [
            (flag, 1) for flag in self.selectWhenTrue
        ]

        combined: Vector | None = None
        for flag, wanted in requirements:
            passed = np.array(data[flag.format(**kwargs)] == wanted)
            if combined is None:
                combined = passed
            else:
                combined &= passed  # type: ignore
        # The guard above ensures at least one flag was processed.
        return cast(Vector, combined)

135 

136 

137class CoaddPlotFlagSelector(FlagSelector): 

138 """This default setting makes it take the band from 

139 the kwargs. 

140 """ 

141 

142 bands = ListField[str]( 

143 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs", 

144 default=[], 

145 ) 

146 

147 def getInputSchema(self) -> KeyedDataSchema: 

148 yield from super().getInputSchema() 

149 

150 def refMatchContext(self): 

151 self.selectWhenFalse = [ 

152 "{band}_psfFlux_flag_target", 

153 "{band}_pixelFlags_saturatedCenter_target", 

154 "{band}_extendedness_flag_target", 

155 "coord_flag_target", 

156 ] 

157 self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"] 

158 

159 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

160 result: Vector | None = None 

161 bands: tuple[str, ...] 

162 match kwargs: 

163 case {"band": band} if not self.bands and self.bands == []: 

164 bands = (band,) 

165 case {"bands": bands} if not self.bands and self.bands == []: 

166 bands = bands 

167 case _ if self.bands: 

168 bands = tuple(self.bands) 

169 case _: 

170 bands = ("",) 

171 for band in bands: 

172 temp = super().__call__(data, **(kwargs | dict(band=band))) 

173 if result is not None: 

174 result &= temp # type: ignore 

175 else: 

176 result = temp 

177 return cast(Vector, result) 

178 

179 def setDefaults(self): 

180 self.selectWhenFalse = [ 

181 "{band}_psfFlux_flag", 

182 "{band}_pixelFlags_saturatedCenter", 

183 "{band}_extendedness_flag", 

184 "coord_flag", 

185 "sky_object", 

186 ] 

187 self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"] 

188 

189 

class MatchingFlagSelector(CoaddPlotFlagSelector):
    """The default flag selector to apply pre matching.
    The sources are cut down to remove duplicates but
    not on quality.
    """

    def setDefaults(self):
        # Call the base implementation as pex_config requires; the flag
        # lists it sets are then replaced below.
        super().setDefaults()
        # Only require detect_isPrimary; no quality flags are applied.
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]

199 

200 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read anywhere in this class;
    # confirm whether it is used elsewhere before relying on it.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch to the ``_target``-suffixed flag column names.

        NOTE(review): presumably used when selecting on a reference-matched
        catalog; confirm against callers.
        """
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the FlagSelector mask for the configured visit flags.

        There is no per-band loop here, so the result is exactly the
        FlagSelector mask.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to call the base
        # implementation.
        super().setDefaults()
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]

237 

238 

class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    maximum = Field[float](doc="The maximum value (exclusive)", default=np.inf)
    # The default minimum is the most negative finite float, so -inf values
    # are excluded by default.
    minimum = Field[float](doc="The minimum value (inclusive)", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.
        **kwargs
            Used to fill ``{...}`` placeholders (e.g. ``{band}``) in
            ``vectorKey``, consistent with the other selectors here.

        Returns
        -------
        result : `Vector`
            ``True`` where ``minimum <= value < maximum``; NaN values never
            satisfy the comparisons and so are not selected.
        """
        values = cast(Vector, data[self.vectorKey.format(**kwargs)])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)

265 

266 

class SetSelector(SelectorBase):
    """Selects rows where any of several columns takes a value from a set.

    For example, given the values (1, 2, 3) and the columns
    (index_1, index_2), every row with either index_1 or index_2 in
    (1, 2, 3) is selected.

    Notes
    -----
    The values are given as floats for flexibility. Integers with magnitude
    greater than 2**53 (9,007,199,254,740,992) are not all exactly
    representable as 64-bit floats, so they may not compare exactly with
    their float representations.
    """

    vectorKeys = ListField[str](
        doc="Keys to select from data",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )
    values = ListField[float](
        doc="The set of acceptable values",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        for key in self.vectorKeys:
            yield key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with a value in ``values`` in any column.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.

        Returns
        -------
        result : `Vector`
            ``True`` where at least one of ``vectorKeys`` holds one of
            ``values``.
        """
        # The mask is shaped like the first column.
        selected = np.zeros_like(data[self.vectorKeys[0]], dtype=bool)
        columns = (cast(Vector, data[key]) for key in self.vectorKeys)
        for column in columns:
            for candidate in self.values:
                selected |= column == candidate

        return cast(Vector, selected)

314 

315 

class PatchSelector(SetSelector):
    """Select rows whose ``patch`` column holds one of ``values``."""

    def setDefaults(self):
        super().setDefaults()
        # Only the "patch" column is compared against ``values``.
        self.vectorKeys = ["patch"]

322 

323 

324class SnSelector(SelectorBase): 

325 """Selects points that have S/N > threshold in the given flux type.""" 

326 

327 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

328 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

329 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

330 uncertaintySuffix = Field[str]( 

331 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

332 ) 

333 bands = ListField[str]( 

334 doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call", 

335 default=[], 

336 ) 

337 

338 def getInputSchema(self) -> KeyedDataSchema: 

339 fluxCol = self.fluxType 

340 fluxInd = fluxCol.find("lux") + len("lux") 

341 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:] 

342 yield fluxCol, Vector 

343 yield errCol, Vector 

344 

345 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

346 """Makes a mask of objects that have S/N greater than 

347 self.threshold in self.fluxType 

348 

349 Parameters 

350 ---------- 

351 data : `KeyedData` 

352 The data to perform the selection on. 

353 

354 Returns 

355 ------- 

356 result : `Vector` 

357 A mask of the objects that satisfy the given 

358 S/N cut. 

359 """ 

360 mask: Vector | None = None 

361 bands: tuple[str, ...] 

362 match kwargs: 

363 case {"band": band} if not self.bands and self.bands == []: 

364 bands = (band,) 

365 case {"bands": bands} if not self.bands and self.bands == []: 

366 bands = bands 

367 case _ if self.bands: 

368 bands = tuple(self.bands) 

369 case _: 

370 bands = ("",) 

371 bandStr = ",".join(bands) 

372 for band in bands: 

373 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

374 fluxInd = fluxCol.find("lux") + len("lux") 

375 errCol = ( 

376 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:] 

377 ) 

378 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol])) 

379 temp = (vec > self.threshold) & (vec < self.maxSN) 

380 if mask is not None: 

381 mask &= temp # type: ignore 

382 else: 

383 mask = temp 

384 

385 plotLabelStr = f"({bandStr}) > {self.threshold:.1f}" 

386 if self.maxSN < 1e5: 

387 plotLabelStr += f" & < {self.maxSN:.1f}" 

388 

389 if self.plotLabelKey == "" or self.plotLabelKey is None: 

390 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs) 

391 else: 

392 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

393 

394 # It should not be possible for mask to be a None now 

395 return np.array(cast(Vector, mask)) 

396 

397 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The band(s) come from the ``bands`` config field if set, otherwise from
    the ``band``/``bands`` call keyword argument; per-band masks are ANDed.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        bands: tuple[str, ...]
        configBandsUnset = not self.bands and self.bands == []
        if configBandsUnset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif configBandsUnset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information: "{band}" placeholders become "".
            bands = ("",)

        combined: Vector | None = None
        for band in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
            "{band}_pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_object"]

436 

437 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    A row is selected when ``sky_source`` is set and the edge/nodata pixel
    flags are unset.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # There is no per-band loop, so the FlagSelector mask is returned
        # directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_source"]

460 

461 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables.

    A row is selected when none of the pixel flags in ``selectWhenFalse``
    are set.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # There is no per-band loop, so the FlagSelector mask is returned
        # directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "pixelFlags_bad",
            "pixelFlags_saturatedCenter",
            "pixelFlags_interpolatedCenter",
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]

487 

488 

class ExtendednessSelector(SelectorBase):
    """Return the extendedness column; subclasses turn it into a mask."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = [(self.vectorKey, Vector)]
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Fill placeholders such as "{band}" from the call kwargs.
        return cast(Vector, data[self.vectorKey.format(**kwargs)])

502 

503 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    Objects with ``0 <= extendedness <= extendedness_maximum`` are selected,
    which is the complement (for non-negative values) of GalaxySelector's
    ``extendedness > extendedness_minimum`` when the thresholds match.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as the field documentation states.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))

516 

517 

class GalaxySelector(ExtendednessSelector):
    """Select resolved objects: extendedness strictly above the minimum."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

530 

531 

class UnknownSelector(ExtendednessSelector):
    """Select objects whose extendedness equals 9 (unclassified)."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # NOTE(review): 9 appears to be a sentinel for "unclassified";
        # confirm against the producer of the extendedness column.
        unknownValue = 9
        values = super().__call__(data, **kwargs)
        return values == unknownValue

540 

541 

class FiniteSelector(VectorAction):
    """Return a mask that is ``True`` where a column's values are finite."""

    vectorKey = Field[str](doc="Key to make a mask of finite values for.")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        finite = np.isfinite(data[key])
        return cast(Vector, finite)

552 

553 

class VectorSelector(VectorAction):
    """Return a boolean column from the data unchanged, as a selection mask."""

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        selection = data[key]
        return cast(Vector, selection)

566 

567 

class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold.

    ``op`` names a function in Python's `operator` module (for example
    ``"eq"``, ``"lt"`` or ``"ge"``). The mask is
    ``operator.<op>(data[vectorKey], threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the column against ``threshold`` using ``op``.

        ``vectorKey`` is formatted with the call keyword arguments so
        templates such as ``{band}`` work as in the other selectors.
        """
        comparison = getattr(operator, self.op)
        mask = comparison(data[self.vectorKey.format(**kwargs)], self.threshold)
        return cast(Vector, mask)

581 

582 

583class BandSelector(VectorAction): 

584 """Makes a mask for sources observed in a specified set of bands.""" 

585 

586 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

587 bands = ListField[str]( 

588 doc="The bands to select. `None` indicates no band selection applied.", 

589 default=[], 

590 ) 

591 

592 def getInputSchema(self) -> KeyedDataSchema: 

593 return ((self.vectorKey, Vector),) 

594 

595 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

596 bands: tuple[str, ...] | None 

597 match kwargs: 

598 case {"band": band} if not self.bands and self.bands == []: 

599 bands = (band,) 

600 case {"bands": bands} if not self.bands and self.bands == []: 

601 bands = bands 

602 case _ if self.bands: 

603 bands = tuple(self.bands) 

604 case _: 

605 bands = None 

606 if bands: 

607 mask = np.isin(data[self.vectorKey], bands).flatten() 

608 else: 

609 # No band selection is applied, i.e., select all rows 

610 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

611 return cast(Vector, mask) 

612 

613 

class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to call the base
        # implementation.
        super().setDefaults()
        # This selects all of the parents
        # parentObjectId excludes subParents.
        # This works because FlagSelector identifies False as 0.
        self.selectWhenFalse = [
            "sky_object",
            "parentObjectId",
        ]

625 

626 

class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents.

    A row is selected when its parent id column is strictly positive.
    The ``minimum``/``maximum`` fields inherited from RangeSelector are not
    used.
    """

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose parent id is greater than zero.

        Parameters
        ----------
        data : `KeyedData`
            The data to select from.
        **kwargs
            Used to fill ``{...}`` placeholders in ``vectorKey``.

        Returns
        -------
        result : `Vector`
            ``True`` where ``data[vectorKey] > 0``.
        """
        values = cast(Vector, data[self.vectorKey.format(**kwargs)])
        mask = values > 0

        return cast(Vector, mask)

651 

652 

653class MagSelector(SelectorBase): 

654 """Selects points that have minMag < mag (AB) < maxMag. 

655 

656 The magnitude is based on the given fluxType. 

657 """ 

658 

659 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux") 

660 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6) 

661 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6) 

662 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy") 

663 returnMillimags = Field[bool](doc="Use millimags or not?", default=False) 

664 bands = ListField[str]( 

665 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.", 

666 default=[], 

667 ) 

668 

669 def getInputSchema(self) -> KeyedDataSchema: 

670 fluxCol = self.fluxType 

671 yield fluxCol, Vector 

672 

673 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

674 """Make a mask of that satisfies self.minMag < mag < self.maxMag. 

675 

676 The magnitude is based on the flux in self.fluxType. 

677 

678 Parameters 

679 ---------- 

680 data : `KeyedData` 

681 The data to perform the magnitude selection on. 

682 

683 Returns 

684 ------- 

685 result : `Vector` 

686 A mask of the objects that satisfy the given magnitude cut. 

687 """ 

688 mask: Vector | None = None 

689 bands: tuple[str, ...] 

690 match kwargs: 

691 case {"band": band} if not self.bands and self.bands == []: 

692 bands = (band,) 

693 case {"bands": bands} if not self.bands and self.bands == []: 

694 bands = bands 

695 case _ if self.bands: 

696 bands = tuple(self.bands) 

697 case _: 

698 bands = ("",) 

699 bandStr = ",".join(bands) 

700 for band in bands: 

701 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

702 vec = fluxToMag( 

703 cast(Vector, data[fluxCol]), 

704 flux_unit=self.fluxUnit, 

705 return_millimags=self.returnMillimags, 

706 ) 

707 temp = (vec > self.minMag) & (vec < self.maxMag) 

708 if mask is not None: 

709 mask &= temp # type: ignore 

710 else: 

711 mask = temp 

712 

713 plotLabelStr = "" 

714 if self.maxMag < 100: 

715 plotLabelStr += f"({bandStr}) < {self.maxMag:.1f}" 

716 if self.minMag > -100: 

717 if bandStr in plotLabelStr: 

718 plotLabelStr += f" & < {self.minMag:.1f}" 

719 else: 

720 plotLabelStr += f"({bandStr}) < {self.minMag:.1f}" 

721 if self.plotLabelKey == "" or self.plotLabelKey is None: 

722 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs) 

723 else: 

724 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

725 

726 # It should not be possible for mask to be a None now 

727 return np.array(cast(Vector, mask)) 

728 

729 

class InjectedObjectSelector(SelectorBase):
    """Select rows whose ``vectorKey`` column equals 1."""

    vectorKey = Field[str](doc="Key to select from data", default="ref_injected_isPrimary")

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        column = data[self.vectorKey.format(**kwargs)]
        isInjected = column == 1
        return cast(Vector, isInjected)

    def getInputSchema(self) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)

742 

743 

class InjectedClassSelector(InjectedObjectSelector):
    """A selector for injected objects of a given class.

    Starts from the InjectedObjectSelector mask. If ``key_injection_flag``
    is set, rows whose injection flag is set are dropped. Rows are then kept
    where ``key_class`` equals (or, if ``value_is_equal`` is False, differs
    from) ``value_compare``.
    """

    key_class = Field[str](
        doc="Key for the field indicating the class of the object",
        default="ref_source_type",
    )
    key_injection_flag = Field[str](
        doc="Key for the field indicating that the object was not injected (per band)",
        default="ref_{band}_injection_flag",
    )
    name_class = Field[str](
        doc="Name of the class of objects",
    )
    value_compare = Field[str](
        doc="Value of the key_class field to compare against (the value for stars by default)",
        default="DeltaFunction",
    )
    value_is_equal = Field[bool](
        doc="Whether the value must equal value_compare to be of this class",
        default=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask of injected objects of this class.

        Column templates are formatted with the call keyword arguments, so
        ``{band}`` requires ``band`` to be supplied.
        """
        result = super().__call__(data, **kwargs)
        if self.key_injection_flag:
            injectionFlag = data[self.key_injection_flag.format(**kwargs)]
            # Keep only rows whose injection flag is unset.
            result &= injectionFlag == False  # noqa: E712
        values = data[self.key_class.format(**kwargs)]
        result &= (values == self.value_compare) if self.value_is_equal else (values != self.value_compare)
        if self.plotLabelKey:
            self._addValueToPlotInfo(f"injected {self.name_class}", **kwargs)
        return result

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()
        yield self.key_class, Vector
        if self.key_injection_flag:
            yield self.key_injection_flag, Vector

782 

783 

class InjectedGalaxySelector(InjectedClassSelector):
    """A selector for injected galaxies."""

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to call the base
        # implementation.
        super().setDefaults()
        self.name_class = "galaxy"
        # Assumes not star == galaxy - if there are injected AGN or other
        # object classes, this will need to be updated
        self.value_is_equal = False

class InjectedStarSelector(InjectedClassSelector):
    """A selector for injected stars."""

    def setDefaults(self):
        # pex_config requires overrides of setDefaults to call the base
        # implementation.
        super().setDefaults()
        self.name_class = "star"

799 

800 

class MatchedObjectSelector(RangeSelector):
    """Select rows with ``0 <= match_distance < inf`` (finite distances)."""

    def setDefaults(self):
        super().setDefaults()
        self.vectorKey = "match_distance"
        # Together with the inherited exclusive maximum of inf, this keeps
        # finite, non-negative distances; NaN never passes.
        self.minimum = 0

808 

809 

class ReferenceGalaxySelector(ThresholdSelector):
    """Select reference-catalog rows whose point-source flag equals 0."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selected = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return selected
        self._addValueToPlotInfo("reference galaxies", **kwargs)
        return selected

    def setDefaults(self):
        super().setDefaults()
        self.op = "eq"
        self.threshold = 0
        self.plotLabelKey = "Selection: Galaxies"
        self.vectorKey = "refcat_is_pointsource"

827 

828 

class ReferenceObjectSelector(RangeSelector):
    """Select reference-catalog rows whose point-source column is >= 0."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selected = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return selected
        self._addValueToPlotInfo("reference objects", **kwargs)
        return selected

    def setDefaults(self):
        super().setDefaults()
        self.minimum = 0
        self.vectorKey = "refcat_is_pointsource"

844 

845 

class ReferenceStarSelector(ThresholdSelector):
    """Select reference-catalog rows whose point-source flag equals 1."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        selected = super().__call__(data=data, **kwargs)
        if not self.plotLabelKey:
            return selected
        self._addValueToPlotInfo("reference stars", **kwargs)
        return selected

    def setDefaults(self):
        super().setDefaults()
        self.op = "eq"
        self.plotLabelKey = "Selection: Stars"
        self.threshold = 1
        self.vectorKey = "refcat_is_pointsource"