Coverage for python / lsst / analysis / tools / actions / vector / selectors.py: 29%

389 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:23 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public API of this module. Every public (non-underscore) selector class
# defined below is listed so that star-imports and API docs see all of them.
__all__ = (
    "SelectorBase",
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SetSelector",
    "PatchSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "FiniteSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "ParentObjectSelector",
    "ChildObjectSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "InjectedClassSelector",
    "InjectedGalaxySelector",
    "InjectedObjectSelector",
    "InjectedStarSelector",
    "MatchedObjectSelector",
    "ReferenceGalaxySelector",
    "ReferenceObjectSelector",
    "ReferenceStarSelector",
)

53 

54import operator 

55from typing import Optional, cast 

56 

57import numpy as np 

58from lsst.pex.config import Field 

59from lsst.pex.config.listField import ListField 

60 

61from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

62from ...math import divide, fluxToMag 

63 

64 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a description of their cut
    in a ``plotInfo`` dictionary passed in through the call keyword
    arguments.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` if that entry exists.

        Parameters
        ----------
        value : `object`
            The value (typically a label string) to record.
        plotLabelKey : `str`, optional
            Key under which ``value`` is stored. When `None`, the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Keyword arguments forwarded from the selector call; only the
            ``plotInfo`` entry is consulted.

        Raises
        ------
        RuntimeError
            Raised if ``plotInfo`` is present, ``plotLabelKey`` is `None`,
            and ``self.plotLabelKey`` is empty.
        """
        if "plotInfo" not in kwargs:
            # No plot metadata container supplied; nothing to do.
            return
        if plotLabelKey is None:
            # Fall back to the configured key, which must be non-empty.
            if not self.plotLabelKey:
                raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
            plotLabelKey = self.plotLabelKey
        kwargs["plotInfo"][plotLabelKey] = value

78 

79 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Both lists contribute columns; names may still contain format
        # placeholders such as ``{band}``.
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Values substituted into the column names with `str.format`.

        Returns
        -------
        result : `Vector`
            A mask that is True only where every column in
            ``selectWhenFalse`` equals 0 and every column in
            ``selectWhenTrue`` equals 1.

        Raises
        ------
        RuntimeError
            Raised if both flag lists are empty.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair each column template with the value a passing row must have;
        # the "False" columns are processed before the "True" columns.
        requirements = [(name, 0) for name in self.selectWhenFalse]
        requirements += [(name, 1) for name in self.selectWhenTrue]

        combined: Optional[Vector] = None
        for template, wanted in requirements:
            passes = np.array(data[template.format(**kwargs)] == wanted)
            if combined is None:
                combined = passes
            else:
                combined &= passes  # type: ignore
        # The guard above guarantees at least one column was processed.
        return cast(Vector, combined)

134 

135 

136class CoaddPlotFlagSelector(FlagSelector): 

137 """This default setting makes it take the band from 

138 the kwargs. 

139 """ 

140 

141 bands = ListField[str]( 

142 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs", 

143 default=[], 

144 ) 

145 

146 def getInputSchema(self) -> KeyedDataSchema: 

147 yield from super().getInputSchema() 

148 

149 def refMatchContext(self): 

150 self.selectWhenFalse = [ 

151 "{band}_psfFlux_flag_target", 

152 "{band}_pixelFlags_saturatedCenter_target", 

153 "{band}_extendedness_flag_target", 

154 "coord_flag_target", 

155 ] 

156 self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"] 

157 

158 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

159 result: Optional[Vector] = None 

160 bands: tuple[str, ...] 

161 match kwargs: 

162 case {"band": band} if not self.bands and self.bands == []: 

163 bands = (band,) 

164 case {"bands": bands} if not self.bands and self.bands == []: 

165 bands = bands 

166 case _ if self.bands: 

167 bands = tuple(self.bands) 

168 case _: 

169 bands = ("",) 

170 for band in bands: 

171 temp = super().__call__(data, **(kwargs | dict(band=band))) 

172 if result is not None: 

173 result &= temp # type: ignore 

174 else: 

175 result = temp 

176 return cast(Vector, result) 

177 

178 def setDefaults(self): 

179 self.selectWhenFalse = [ 

180 "{band}_psfFlux_flag", 

181 "{band}_pixelFlags_saturatedCenter", 

182 "{band}_extendedness_flag", 

183 "coord_flag", 

184 "sky_object", 

185 ] 

186 self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"] 

187 

188 

class MatchingFlagSelector(CoaddPlotFlagSelector):
    """The default flag selector to apply pre matching.
    The sources are cut down to remove duplicates but
    not on quality.
    """

    def setDefaults(self):
        # Chain to the base implementation as pex_config requires, then
        # replace the parent's quality cuts with a primary-only selection.
        super().setDefaults()
        self.selectWhenFalse = []
        self.selectWhenTrue = ["detect_isPrimary"]

198 

199 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not applied anywhere in this class;
    # confirm whether it is still needed.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag columns to their ``_target``-suffixed variants."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded to `FlagSelector.__call__` for column-name formatting.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all configured flag cuts.
        """
        # A single flag pass is made, so no accumulation is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # pex_config requires overriding setDefaults to chain to the base.
        super().setDefaults()
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]

236 

237 

class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    maximum = Field[float](doc="The maximum value (exclusive)", default=np.inf)
    # The default minimum is the most negative finite float, so -inf values
    # fail the inclusive lower bound.
    minimum = Field[float](doc="The minimum value (inclusive)", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select rows whose value lies in ``[minimum, maximum)``.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on; ``vectorKey`` is used verbatim.

        Returns
        -------
        result : `Vector`
            Mask that is True where ``minimum <= value < maximum``.
        """
        column = cast(Vector, data[self.vectorKey])
        atOrAboveMinimum = column >= self.minimum
        belowMaximum = column < self.maximum
        return cast(Vector, atOrAboveMinimum & belowMaximum)

264 

265 

class SetSelector(SelectorBase):
    """Selects rows with any number of column values within a given set.

    For example, given a set of patches (1, 2, 3), and a set of columns
    (index_1, index_2), return all rows with either index_1 or index_2
    in the set (1, 2, 3).

    Notes
    -----
    The values are given as floats for flexibility. Integers above
    the floating point limit (2^53 + 1 = 9,007,199,254,740,993 for 64 bits)
    will not compare exactly with their float representations.
    """

    # Both lists must be non-empty and free of duplicates.
    vectorKeys = ListField[str](
        doc="Keys to select from data",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )
    values = ListField[float](
        doc="The set of acceptable values",
        default=[],
        listCheck=lambda x: (len(x) > 0) & (len(x) == len(set(x))),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from ((key, Vector) for key in self.vectorKeys)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select rows where any configured column holds an accepted value.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.

        Returns
        -------
        result : `Vector`
            Mask that is True where at least one column in ``vectorKeys``
            equals at least one entry of ``values``.
        """
        # Start with nothing selected, shaped like the first column.
        anyMatch = np.zeros_like(data[self.vectorKeys[0]], dtype=bool)
        for columnName in self.vectorKeys:
            column = cast(Vector, data[columnName])
            for accepted in self.values:
                anyMatch |= column == accepted
        return cast(Vector, anyMatch)

313 

314 

class PatchSelector(SetSelector):
    """Select rows whose ``patch`` column holds one of the configured
    ``values``.
    """

    def setDefaults(self):
        super().setDefaults()
        # Only the ``patch`` column is inspected; the accepted patch
        # identifiers still have to be supplied through ``values``.
        self.vectorKeys = ["patch"]

321 

322 

323class SnSelector(SelectorBase): 

324 """Selects points that have S/N > threshold in the given flux type.""" 

325 

326 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

327 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

328 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

329 uncertaintySuffix = Field[str]( 

330 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

331 ) 

332 bands = ListField[str]( 

333 doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call", 

334 default=[], 

335 ) 

336 

337 def getInputSchema(self) -> KeyedDataSchema: 

338 fluxCol = self.fluxType 

339 fluxInd = fluxCol.find("lux") + len("lux") 

340 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:] 

341 yield fluxCol, Vector 

342 yield errCol, Vector 

343 

344 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

345 """Makes a mask of objects that have S/N greater than 

346 self.threshold in self.fluxType 

347 

348 Parameters 

349 ---------- 

350 data : `KeyedData` 

351 The data to perform the selection on. 

352 

353 Returns 

354 ------- 

355 result : `Vector` 

356 A mask of the objects that satisfy the given 

357 S/N cut. 

358 """ 

359 mask: Optional[Vector] = None 

360 bands: tuple[str, ...] 

361 match kwargs: 

362 case {"band": band} if not self.bands and self.bands == []: 

363 bands = (band,) 

364 case {"bands": bands} if not self.bands and self.bands == []: 

365 bands = bands 

366 case _ if self.bands: 

367 bands = tuple(self.bands) 

368 case _: 

369 bands = ("",) 

370 bandStr = ",".join(bands) 

371 for band in bands: 

372 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

373 fluxInd = fluxCol.find("lux") + len("lux") 

374 errCol = ( 

375 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:] 

376 ) 

377 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol])) 

378 temp = (vec > self.threshold) & (vec < self.maxSN) 

379 if mask is not None: 

380 mask &= temp # type: ignore 

381 else: 

382 mask = temp 

383 

384 plotLabelStr = "({}) > {:.1f}".format(bandStr, self.threshold) 

385 if self.maxSN < 1e5: 

386 plotLabelStr += " & < {:.1f}".format(self.maxSN) 

387 

388 if self.plotLabelKey == "" or self.plotLabelKey is None: 

389 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs) 

390 else: 

391 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

392 

393 # It should not be possible for mask to be a None now 

394 return np.array(cast(Vector, mask)) 

395 

396 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """AND together the sky-object flag masks of every selected band.

        The bands come from the ``bands`` config field when it is non-empty,
        otherwise from a ``band`` (then ``bands``) keyword argument,
        otherwise a single empty band string is used.
        """
        # Keyword-supplied bands are only honoured when none are configured.
        useCallBands = not self.bands and self.bands == []
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            perBand = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
            "{band}_pixelFlags_nodata",
        ]

435 

436 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the sky-source flags.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all configured flag cuts.
        """
        # Only one flag pass is made, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]
        self.selectWhenTrue = ["sky_source"]

459 

460 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the DIA-source quality flags.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on.
        **kwargs
            Forwarded to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Mask of rows with none of the configured bad-pixel flags set.
        """
        # Only one flag pass is made, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "pixelFlags_bad",
            "pixelFlags_saturatedCenter",
            "pixelFlags_interpolatedCenter",
            "pixelFlags_edge",
            "pixelFlags_nodata",
        ]

486 

487 

class ExtendednessSelector(SelectorBase):
    """A selector that picks between extended and point sources."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the raw extendedness column, with ``vectorKey`` formatted
        using the call keyword arguments (e.g. ``band``).
        """
        return cast(Vector, data[self.vectorKey.format(**kwargs)])

501 

502 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    Objects with ``0 <= extendedness <= extendedness_maximum`` are selected.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of unresolved (star-like) objects.

        Negative extendedness values are rejected. The upper bound is
        inclusive, as the ``extendedness_maximum`` documentation states, so
        that together with `GalaxySelector` (strictly greater than its
        minimum) no object on a shared threshold is dropped by both.
        """
        extendedness = super().__call__(data, **kwargs)
        isStar = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isStar))

515 

516 

class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask that is True where extendedness is strictly above
        ``extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

529 

530 

class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask that is True where the extendedness equals 9."""
        # NOTE(review): 9 appears to be the sentinel for "unclassified";
        # confirm against the producer of the extendedness column.
        unclassifiedValue = 9
        return super().__call__(data, **kwargs) == unclassifiedValue

539 

540 

class FiniteSelector(VectorAction):
    """Return a mask of finite values for a vector key"""

    vectorKey = Field[str](doc="Key to make a mask of finite values for.")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return `numpy.isfinite` of the column named by ``vectorKey``,
        formatted with the call keyword arguments.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, np.isfinite(column))

551 

552 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` (formatted with the call
        keyword arguments) unchanged.
        """
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])

565 

566 

class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold."""

    # ``op`` names a function in the standard-library ``operator`` module,
    # e.g. "lt", "le", "eq", "ne", "ge" or "gt".
    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return ``op(data[vectorKey], threshold)``; ``vectorKey`` is used
        verbatim (not formatted).
        """
        compare = getattr(operator, self.op)
        return cast(Vector, compare(data[self.vectorKey], self.threshold))

580 

581 

class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands."""

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select rows whose band column is in the resolved set of bands.

        The bands come from the ``bands`` config field when it is non-empty,
        otherwise from a ``band`` (then ``bands``) keyword argument. If no
        bands are found, every row is selected.
        """
        # Keyword-supplied bands are only honoured when none are configured.
        useCallBands = not self.bands and self.bands == []
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = None

        bandColumn = data[self.vectorKey]
        if not bands:
            # No band selection is applied, i.e., select all rows
            mask = np.full(len(bandColumn), True)  # type: ignore
        else:
            mask = np.isin(bandColumn, bands).flatten()
        return cast(Vector, mask)

611 

612 

class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects.

    The only flag configured by default is ``sky_object`` (selected when
    False). NOTE(review): no explicit parent/child test is applied here;
    confirm that the input catalog only contains parent objects.
    """

    def setDefaults(self):
        # pex_config requires overriding setDefaults to chain to the base.
        super().setDefaults()
        # This selects all of the parents
        self.selectWhenFalse = [
            "sky_object",
        ]

621 

622 

class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents.

    A row is selected when its ``vectorKey`` value (``parentSourceId`` by
    default) is strictly positive. The ``minimum`` and ``maximum`` fields
    inherited from `RangeSelector` are not used.
    """

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with a positive ``vectorKey`` value.

        Parameters
        ----------
        data : `KeyedData`
            The data to select on; ``vectorKey`` is used verbatim.

        Returns
        -------
        result : `Vector`
            Mask that is True where the ``vectorKey`` value is > 0.
        """
        values = cast(Vector, data[self.vectorKey])
        return cast(Vector, values > 0)

647 

648 

649class MagSelector(SelectorBase): 

650 """Selects points that have minMag < mag (AB) < maxMag. 

651 

652 The magnitude is based on the given fluxType. 

653 """ 

654 

655 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux") 

656 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6) 

657 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6) 

658 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy") 

659 returnMillimags = Field[bool](doc="Use millimags or not?", default=False) 

660 bands = ListField[str]( 

661 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.", 

662 default=[], 

663 ) 

664 

665 def getInputSchema(self) -> KeyedDataSchema: 

666 fluxCol = self.fluxType 

667 yield fluxCol, Vector 

668 

669 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

670 """Make a mask of that satisfies self.minMag < mag < self.maxMag. 

671 

672 The magnitude is based on the flux in self.fluxType. 

673 

674 Parameters 

675 ---------- 

676 data : `KeyedData` 

677 The data to perform the magnitude selection on. 

678 

679 Returns 

680 ------- 

681 result : `Vector` 

682 A mask of the objects that satisfy the given magnitude cut. 

683 """ 

684 mask: Optional[Vector] = None 

685 bands: tuple[str, ...] 

686 match kwargs: 

687 case {"band": band} if not self.bands and self.bands == []: 

688 bands = (band,) 

689 case {"bands": bands} if not self.bands and self.bands == []: 

690 bands = bands 

691 case _ if self.bands: 

692 bands = tuple(self.bands) 

693 case _: 

694 bands = ("",) 

695 bandStr = ",".join(bands) 

696 for band in bands: 

697 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

698 vec = fluxToMag( 

699 cast(Vector, data[fluxCol]), 

700 flux_unit=self.fluxUnit, 

701 return_millimags=self.returnMillimags, 

702 ) 

703 temp = (vec > self.minMag) & (vec < self.maxMag) 

704 if mask is not None: 

705 mask &= temp # type: ignore 

706 else: 

707 mask = temp 

708 

709 plotLabelStr = "" 

710 if self.maxMag < 100: 

711 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.maxMag) 

712 if self.minMag > -100: 

713 if bandStr in plotLabelStr: 

714 plotLabelStr += " & < {:.1f}".format(self.minMag) 

715 else: 

716 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.minMag) 

717 if self.plotLabelKey == "" or self.plotLabelKey is None: 

718 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs) 

719 else: 

720 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

721 

722 # It should not be possible for mask to be a None now 

723 return np.array(cast(Vector, mask)) 

724 

725 

class InjectedObjectSelector(SelectorBase):
    """A selector for injected objects."""

    vectorKey = Field[str](doc="Key to select from data", default="ref_injected_isPrimary")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask that is True where the ``vectorKey`` column
        (formatted with the call keyword arguments) equals 1.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column == 1)

738 

739 

class InjectedClassSelector(InjectedObjectSelector):
    """A selector for injected objects of a given class."""

    key_class = Field[str](
        doc="Key for the field indicating the class of the object",
        default="ref_source_type",
    )
    key_injection_flag = Field[str](
        doc="Key for the field indicating that the object was not injected (per band)",
        default="ref_{band}_injection_flag",
    )
    name_class = Field[str](
        doc="Name of the class of objects",
    )
    value_compare = Field[str](
        doc="Value of the type_key field for objects that are stars",
        default="DeltaFunction",
    )
    value_is_equal = Field[bool](
        doc="Whether the value must equal value_compare to be of this class",
        default=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()
        yield self.key_class, Vector
        if self.key_injection_flag:
            yield self.key_injection_flag, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select injected objects of the configured class.

        Starts from the `InjectedObjectSelector` mask. If
        ``key_injection_flag`` is set, rows whose per-band injection flag is
        set are dropped (this requires a ``band`` keyword argument). Rows are
        then kept where ``key_class`` equals (or, if ``value_is_equal`` is
        False, differs from) ``value_compare``.
        """
        selected = super().__call__(data, **kwargs)
        if self.key_injection_flag:
            flagColumn = self.key_injection_flag.format(band=kwargs["band"])
            selected &= data[flagColumn] == False  # noqa: E712
        classValues = data[self.key_class]
        if self.value_is_equal:
            selected &= classValues == self.value_compare
        else:
            selected &= classValues != self.value_compare
        if self.plotLabelKey:
            self._addValueToPlotInfo(f"injected {self.name_class}", **kwargs)
        return selected

778 

779 

class InjectedGalaxySelector(InjectedClassSelector):
    """A selector for injected galaxies."""

    def setDefaults(self):
        # pex_config requires overriding setDefaults to chain to the base.
        super().setDefaults()
        self.name_class = "galaxy"
        # Assumes not star == galaxy - if there are injected AGN or other
        # object classes, this will need to be updated
        self.value_is_equal = False

788 

789 

class InjectedStarSelector(InjectedClassSelector):
    """A selector for injected stars."""

    def setDefaults(self):
        # pex_config requires overriding setDefaults to chain to the base.
        super().setDefaults()
        self.name_class = "star"

795 

796 

class MatchedObjectSelector(RangeSelector):
    """A selector that selects matched objects with finite distances."""

    def setDefaults(self):
        super().setDefaults()
        # Keep rows with a match distance in [0, inf).
        self.vectorKey = "match_distance"
        self.minimum = 0

804 

805 

class ReferenceGalaxySelector(ThresholdSelector):
    """A selector that selects galaxies from a catalog with a
    boolean column identifying unresolved sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        isGalaxy = super().__call__(data=data, **kwargs)
        if self.plotLabelKey:
            self._addValueToPlotInfo("reference galaxies", **kwargs)
        return isGalaxy

    def setDefaults(self):
        super().setDefaults()
        # Galaxies are the rows where the point-source column equals 0.
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 0
        self.plotLabelKey = "Selection: Galaxies"

823 

824 

class ReferenceObjectSelector(RangeSelector):
    """A selector that selects all objects from a catalog with a
    boolean column identifying unresolved sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        isObject = super().__call__(data=data, **kwargs)
        if self.plotLabelKey:
            self._addValueToPlotInfo("reference objects", **kwargs)
        return isObject

    def setDefaults(self):
        super().setDefaults()
        # Any value in [0, inf) of the point-source column is accepted.
        self.vectorKey = "refcat_is_pointsource"
        self.minimum = 0

840 

841 

class ReferenceStarSelector(ThresholdSelector):
    """A selector that selects stars from a catalog with a
    boolean column identifying unresolved sources.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        isStar = super().__call__(data=data, **kwargs)
        if self.plotLabelKey:
            self._addValueToPlotInfo("reference stars", **kwargs)
        return isStar

    def setDefaults(self):
        super().setDefaults()
        # Stars are the rows where the point-source column equals 1.
        self.vectorKey = "refcat_is_pointsource"
        self.op = "eq"
        self.threshold = 1
        self.plotLabelKey = "Selection: Stars"

858 self.vectorKey = "refcat_is_pointsource"