Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 28%

294 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-15 03:47 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public selectors exported by ``from ... import *``.  ParentObjectSelector
# and ChildObjectSelector are public classes in this module as well, so they
# are listed here for consistency with every other selector.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "ParentObjectSelector",
    "ChildObjectSelector",
)

42 

43import operator 

44from typing import Optional, cast 

45 

46import numpy as np 

47from lsst.pex.config import Field 

48from lsst.pex.config.listField import ListField 

49 

50from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

51from ...math import divide, fluxToMag 

52 

53 

class SelectorBase(VectorAction):
    """Common base for selectors that can record a label in ``plotInfo``.

    Subclasses call `_addValueToPlotInfo` to store a description of the
    applied cut in the ``plotInfo`` mapping passed through the call keyword
    arguments.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string",
        optional=True,
        default="",
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` if that mapping is given.

        Parameters
        ----------
        value
            The value to record.
        plotLabelKey : `str`, optional
            Key to store ``value`` under. When `None`, the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Call keyword arguments; only ``plotInfo`` is consulted.

        Raises
        ------
        RuntimeError
            Raised if ``plotInfo`` is present but neither an explicit key nor
            a non-empty configured key is available.
        """
        # Without a plotInfo mapping there is nowhere to record the value.
        if "plotInfo" not in kwargs:
            return
        if plotLabelKey is None:
            # Fall back to the configured key; an empty/None one is unusable.
            if not self.plotLabelKey:
                raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
            plotLabelKey = self.plotLabelKey
        kwargs["plotInfo"][plotLabelKey] = value

67 

68 

class FlagSelector(VectorAction):
    """Select rows using boolean flag columns.

    A row is kept only when every column listed in ``selectWhenFalse``
    equals 0 and every column listed in ``selectWhenTrue`` equals 1.
    Column names may contain `str.format` placeholders (e.g. ``{band}``),
    which are filled from the call keyword arguments.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Column names are reported as templates (placeholders unformatted).
        return ((column, Vector) for column in [*self.selectWhenFalse, *self.selectWhenTrue])

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the flag columns.
        **kwargs
            Values used to format the column-name templates.

        Returns
        -------
        result : `Vector`
            A mask that is `True` where all ``selectWhenFalse`` columns are 0
            and all ``selectWhenTrue`` columns are 1.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue`` lists
            any column.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column template with the value it must equal; the
        # selectWhenFalse columns are evaluated before the selectWhenTrue ones.
        requirements = [(template, 0) for template in self.selectWhenFalse]
        requirements += [(template, 1) for template in self.selectWhenTrue]

        combined: Optional[Vector] = None
        for template, wanted in requirements:
            current = np.array(data[template.format(**kwargs)] == wanted)
            if combined is None:
                combined = current
            else:
                combined &= current  # type: ignore
        # The guard above guarantees at least one requirement, so this is set.
        return cast(Vector, combined)

123 

124 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots, applied band by band.

    Column templates containing ``{band}`` are evaluated once per band and
    the per-band masks are combined with a logical AND. Bands are taken from
    the ``bands`` config field when it is non-empty, otherwise from a
    ``band`` or ``bands`` call keyword argument, otherwise a single empty
    band string is used.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag columns to their ``_target``-suffixed names."""
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag masks over all resolved bands."""
        bands: tuple[str, ...]
        # Call-time bands are only honoured when the config list is empty.
        useCallBands = not self.bands and self.bands == []
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
            "sky_object",
        ]

176 

177 

class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Flag selector applied before matching.

    It removes duplicates by keeping only ``detect_isPrimary`` rows and
    applies no quality-flag cuts.
    """

    def setDefaults(self):
        # Deduplicate only: require the primary flag, reject on nothing else.
        self.selectWhenTrue = ["detect_isPrimary"]
        self.selectWhenFalse = []

187 

188 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read by any method of this class;
    # confirm whether something elsewhere uses it before relying on it.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the rejection flags to their ``_target``-suffixed names."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flag columns.

        The previous implementation seeded an accumulator with `None` and
        then tested it, so the AND branch could never run; the result was
        always just the base-class mask, which is returned directly here.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]

225 

226 

class RangeSelector(SelectorBase):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # Use ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0, which made
    # this class definition fail at import time.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # nextafter(-inf, 0) is the most negative finite double, so -inf values
    # are excluded by the default lower bound.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to perform the selection on; must contain ``vectorKey``.

        Returns
        -------
        result : `Vector`
            `True` where ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)

253 

254 

255class SnSelector(SelectorBase): 

256 """Selects points that have S/N > threshold in the given flux type.""" 

257 

258 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

259 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

260 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

261 uncertaintySuffix = Field[str]( 

262 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

263 ) 

264 bands = ListField[str]( 

265 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

266 default=[], 

267 ) 

268 

269 def getInputSchema(self) -> KeyedDataSchema: 

270 fluxCol = self.fluxType 

271 fluxInd = fluxCol.find("lux") + len("lux") 

272 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:] 

273 yield fluxCol, Vector 

274 yield errCol, Vector 

275 

276 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

277 """Makes a mask of objects that have S/N greater than 

278 self.threshold in self.fluxType 

279 

280 Parameters 

281 ---------- 

282 data : `KeyedData` 

283 The data to perform the selection on. 

284 

285 Returns 

286 ------- 

287 result : `Vector` 

288 A mask of the objects that satisfy the given 

289 S/N cut. 

290 """ 

291 mask: Optional[Vector] = None 

292 bands: tuple[str, ...] 

293 match kwargs: 

294 case {"band": band} if not self.bands and self.bands == []: 

295 bands = (band,) 

296 case {"bands": bands} if not self.bands and self.bands == []: 

297 bands = bands 

298 case _ if self.bands: 

299 bands = tuple(self.bands) 

300 case _: 

301 bands = ("",) 

302 bandStr = ",".join(bands) 

303 for band in bands: 

304 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

305 fluxInd = fluxCol.find("lux") + len("lux") 

306 errCol = ( 

307 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:] 

308 ) 

309 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol])) 

310 temp = (vec > self.threshold) & (vec < self.maxSN) 

311 if mask is not None: 

312 mask &= temp # type: ignore 

313 else: 

314 mask = temp 

315 

316 plotLabelStr = "({}) > {:.1f}".format(bandStr, self.threshold) 

317 if self.maxSN < 1e5: 

318 plotLabelStr += " & < {:.1f}".format(self.maxSN) 

319 

320 if self.plotLabelKey == "" or self.plotLabelKey is None: 

321 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs) 

322 else: 

323 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

324 

325 # It should not be possible for mask to be a None now 

326 return np.array(cast(Vector, mask)) 

327 

328 

class SkyObjectSelector(FlagSelector):
    """Select sky objects, requiring the flag cuts to pass in every band.

    With the default flags a row is kept when ``sky_object`` is set and
    ``{band}_pixelFlags_edge`` is unset in each band. Bands are taken from
    the ``bands`` config field (default ``["i"]``) when it is non-empty,
    otherwise from a ``band``/``bands`` call keyword argument, otherwise a
    single empty band string is used.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag masks over all resolved bands."""
        bands: tuple[str, ...]
        configured = self.bands
        # Call-time bands are only honoured when the config list is empty.
        fromCall = not configured and configured == []
        if fromCall and "band" in kwargs:
            bands = (kwargs["band"],)
        elif fromCall and "bands" in kwargs:
            bands = kwargs["bands"]
        elif configured:
            bands = tuple(configured)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = ["{band}_pixelFlags_edge"]

365 

366 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables.

    With the defaults a row is kept when ``sky_source`` is set and
    ``pixelFlags_edge`` is unset.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flag columns.

        The previous implementation seeded an accumulator with `None` and
        then tested it, so its AND branch was unreachable; the base-class
        mask is returned directly instead.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

387 

388 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables.

    A row is kept when none of the configured pixel flags are set.
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flag columns.

        The previous implementation seeded an accumulator with `None` and
        then tested it, so its AND branch was unreachable; the base-class
        mask is returned directly instead.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # These default flag names are correct for AP data products
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

    def drpContext(self):
        # These flag names are correct for DRP data products: rename the
        # "base_PixelFlags_flag" prefix to "pixelFlags".
        newSelectWhenFalse = [
            flag.replace("base_PixelFlags_flag", "pixelFlags") for flag in self.selectWhenFalse
        ]
        self.selectWhenFalse = newSelectWhenFalse

421 

422 

class ExtendednessSelector(VectorAction):
    """Base for selectors that cut on an extendedness column.

    Calling it returns the extendedness column itself (its key formatted
    with the call keyword arguments); subclasses turn that into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric",
        default="{band}_extendedness",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # The key is reported as a template, e.g. "{band}_extendedness".
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        return cast(Vector, data[self.vectorKey.format(**kwargs)])

436 

437 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    Rows are selected when ``0 <= extendedness <= extendedness_maximum``.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # The maximum is documented as inclusive, so compare with ``<=``
        # (previously ``<``). GalaxySelector's minimum is documented as
        # exclusive (``>``), so with matching thresholds a value exactly at
        # the threshold is now classified as a star instead of as neither.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))

450 

451 

class GalaxySelector(ExtendednessSelector):
    """Select resolved objects (galaxies) by extendedness.

    Rows are selected when ``extendedness > extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.",
        default=0.5,
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

464 

465 

class UnknownSelector(ExtendednessSelector):
    """Select objects whose extendedness value marks them as unclassified.

    Rows are selected where the extendedness column equals 9.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown

474 

475 

class VectorSelector(VectorAction):
    """Return a stored boolean column from KeyedData as a selection mask.

    The key may contain `str.format` placeholders filled from the call
    keyword arguments.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        key = self.vectorKey.format(**kwargs)
        selection = data[key]
        return cast(Vector, selection)

488 

489 

class ThresholdSelector(SelectorBase):
    """Return a mask corresponding to an applied threshold.

    ``op`` names a function in the standard-library `operator` module
    (e.g. ``"lt"`` or ``"ge"``); the mask is
    ``op(data[vectorKey], threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        comparison = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, comparison(column, self.threshold))

503 

504 

505class BandSelector(VectorAction): 

506 """Makes a mask for sources observed in a specified set of bands.""" 

507 

508 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

509 bands = ListField[str]( 

510 doc="The bands to select. `None` indicates no band selection applied.", 

511 default=[], 

512 ) 

513 

514 def getInputSchema(self) -> KeyedDataSchema: 

515 return ((self.vectorKey, Vector),) 

516 

517 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

518 bands: Optional[tuple[str, ...]] 

519 match kwargs: 

520 case {"band": band} if not self.bands and self.bands == []: 

521 bands = (band,) 

522 case {"bands": bands} if not self.bands and self.bands == []: 

523 bands = bands 

524 case _ if self.bands: 

525 bands = tuple(self.bands) 

526 case _: 

527 bands = None 

528 if bands: 

529 mask = np.in1d(data[self.vectorKey], bands) 

530 else: 

531 # No band selection is applied, i.e., select all rows 

532 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

533 return cast(Vector, mask) 

534 

535 

class ParentObjectSelector(FlagSelector):
    """Select parent objects in the inner patch region that are not sky objects.

    With the defaults a row is kept when ``detect_isPatchInner`` is set and
    both ``detect_isDeblendedModelSource`` and ``sky_object`` are unset.
    """

    def setDefaults(self):
        self.selectWhenTrue = ["detect_isPatchInner"]
        # Rejecting deblended model sources is what leaves only the parents.
        self.selectWhenFalse = ["detect_isDeblendedModelSource", "sky_object"]

546 

547 

class ChildObjectSelector(RangeSelector):
    """Select only children from deblended parents.

    A row is selected when its ``vectorKey`` column (by default
    ``parentSourceId``) is strictly positive. The ``minimum``/``maximum``
    fields inherited from `RangeSelector` are not used by this selector.
    """

    vectorKey = Field[str](doc="Key to select from data", default="parentSourceId")

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose ``vectorKey`` value is positive.

        Parameters
        ----------
        data : `KeyedData`
            The data to perform the selection on; must contain ``vectorKey``.

        Returns
        -------
        result : `Vector`
            `True` where ``data[vectorKey] > 0``.
        """
        values = cast(Vector, data[self.vectorKey])
        # Presumably a parent ID of 0 means "no parent" — TODO confirm.
        mask = values > 0

        return cast(Vector, mask)

572 

573 

574class MagSelector(SelectorBase): 

575 """Selects points that have minMag < mag (AB) < maxMag. 

576 

577 The magnitude is based on the given fluxType. 

578 """ 

579 

580 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux") 

581 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6) 

582 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6) 

583 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy") 

584 returnMillimags = Field[bool](doc="Use millimags or not?", default=False) 

585 bands = ListField[str]( 

586 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.", 

587 default=[], 

588 ) 

589 

590 def getInputSchema(self) -> KeyedDataSchema: 

591 fluxCol = self.fluxType 

592 yield fluxCol, Vector 

593 

594 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

595 """Make a mask of that satisfies self.minMag < mag < self.maxMag. 

596 

597 The magnitude is based on the flux in self.fluxType. 

598 

599 Parameters 

600 ---------- 

601 data : `KeyedData` 

602 The data to perform the magnitude selection on. 

603 

604 Returns 

605 ------- 

606 result : `Vector` 

607 A mask of the objects that satisfy the given magnitude cut. 

608 """ 

609 mask: Optional[Vector] = None 

610 bands: tuple[str, ...] 

611 match kwargs: 

612 case {"band": band} if not self.bands and self.bands == []: 

613 bands = (band,) 

614 case {"bands": bands} if not self.bands and self.bands == []: 

615 bands = bands 

616 case _ if self.bands: 

617 bands = tuple(self.bands) 

618 case _: 

619 bands = ("",) 

620 bandStr = ",".join(bands) 

621 for band in bands: 

622 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

623 vec = fluxToMag( 

624 cast(Vector, data[fluxCol]), 

625 flux_unit=self.fluxUnit, 

626 return_millimags=self.returnMillimags, 

627 ) 

628 temp = (vec > self.minMag) & (vec < self.maxMag) 

629 if mask is not None: 

630 mask &= temp # type: ignore 

631 else: 

632 mask = temp 

633 

634 plotLabelStr = "" 

635 if self.maxMag < 100: 

636 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.maxMag) 

637 if self.minMag > -100: 

638 if bandStr in plotLabelStr: 

639 plotLabelStr += " & < {:.1f}".format(self.minMag) 

640 else: 

641 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.minMag) 

642 if self.plotLabelKey == "" or self.plotLabelKey is None: 

643 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs) 

644 else: 

645 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

646 

647 # It should not be possible for mask to be a None now 

648 return np.array(cast(Vector, mask))