Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 27%

283 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-07 14:40 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Names exported by ``from <module> import *``.  ``ParentObjectSelector`` is
# listed alongside the other concrete selectors defined in this module so
# that it is exported consistently with them.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "MagSelector",
    "ParentObjectSelector",
)

42 

43import operator 

44from typing import Optional, cast 

45 

46import numpy as np 

47from lsst.pex.config import Field 

48from lsst.pex.config.listField import ListField 

49 

50from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

51from ...math import divide, fluxToMag 

52 

53 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a label in ``plotInfo``."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, plotLabelKey=None, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` if that entry exists.

        Parameters
        ----------
        value
            The value to store.
        plotLabelKey : `str`, optional
            Key to store ``value`` under. When `None`, the configured
            ``self.plotLabelKey`` is used instead.
        **kwargs
            Keyword arguments of the calling action; nothing happens unless
            they contain a ``"plotInfo"`` entry.

        Raises
        ------
        RuntimeError
            Raised if ``"plotInfo"`` is present but neither ``plotLabelKey``
            nor a non-empty ``self.plotLabelKey`` is available.
        """
        if "plotInfo" not in kwargs:
            return

        # An explicitly passed key wins over the configured one.
        if plotLabelKey is not None:
            key = plotLabelKey
        elif self.plotLabelKey:
            key = self.plotLabelKey
        else:
            raise RuntimeError(f"No plotLabelKey provided for value {value}, so can't add to plotInfo")
        kwargs["plotInfo"][key] = value

67 

68 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column, Vector)`` pairs for every configured flag,
        ``selectWhenFalse`` columns first.
        """
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Used to fill format placeholders in the column names.

        Returns
        -------
        result : `Vector`
            Mask that is True where every ``selectWhenFalse`` column equals
            0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if both flag lists are empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each flag column with the value it has to hold for a row to
        # be selected; the False-flags are processed before the True-flags.
        requirements = [(flag, 0) for flag in self.selectWhenFalse]  # type: ignore
        requirements += [(flag, 1) for flag in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for flag, wanted in requirements:
            current = np.array(data[flag.format(**kwargs)] == wanted)
            if mask is None:
                mask = current
            else:
                mask &= current  # type: ignore
        # The emptiness check above guarantees at least one iteration.
        return cast(Vector, mask)

123 

124 

class CoaddPlotFlagSelector(FlagSelector):
    """Apply the flag cuts once per band and combine the per-band masks.

    With the default (empty) ``bands`` the band is taken from the kwargs
    passed to the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag lists to their ``_target``-suffixed variants."""
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag masks evaluated for each band.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are only consulted when
            the configured ``bands`` is an empty list.

        Returns
        -------
        result : `Vector`
            The combined mask.
        """
        bands: tuple[str, ...]
        # kwargs are only consulted when the configured list is empty.
        deferToKwargs = not self.bands and self.bands == []
        if deferToKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format the columns with "".
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
            "sky_object",
        ]

176 

177 

class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Default flag selector applied before matching.

    Only ``detect_isPrimary`` is required, which removes duplicate sources
    without applying any quality cuts.
    """

    def setDefaults(self):
        # Replace the inherited defaults: one required flag, no vetoes.
        self.selectWhenTrue = ["detect_isPrimary"]
        self.selectWhenFalse = []

187 

188 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): this field is not referenced anywhere in this class as
    # shown; confirm whether the suffix is applied elsewhere or was meant to
    # be appended to the flag names here.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch ``selectWhenFalse`` to its ``_target``-suffixed columns."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the visit flags.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # The previous implementation combined this mask with an accumulator
        # that was always None, so delegating directly is equivalent.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
            "sky_source",
        ]

225 

226 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # ``np.inf`` is used instead of the ``np.Inf`` alias, which was removed
    # in NumPy 2.0 and would make this class body fail at import time.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the float next to -inf in the direction of 0,
    # so with the defaults a value of -inf is not selected.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the ``vectorKey`` column.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)

253 

254 

255class SnSelector(SelectorBase): 

256 """Selects points that have S/N > threshold in the given flux type.""" 

257 

258 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

259 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

260 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

261 uncertaintySuffix = Field[str]( 

262 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

263 ) 

264 bands = ListField[str]( 

265 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

266 default=[], 

267 ) 

268 

269 def getInputSchema(self) -> KeyedDataSchema: 

270 fluxCol = self.fluxType 

271 fluxInd = fluxCol.find("lux") + len("lux") 

272 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:] 

273 yield fluxCol, Vector 

274 yield errCol, Vector 

275 

276 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

277 """Makes a mask of objects that have S/N greater than 

278 self.threshold in self.fluxType 

279 

280 Parameters 

281 ---------- 

282 data : `KeyedData` 

283 The data to perform the selection on. 

284 

285 Returns 

286 ------- 

287 result : `Vector` 

288 A mask of the objects that satisfy the given 

289 S/N cut. 

290 """ 

291 mask: Optional[Vector] = None 

292 bands: tuple[str, ...] 

293 match kwargs: 

294 case {"band": band} if not self.bands and self.bands == []: 

295 bands = (band,) 

296 case {"bands": bands} if not self.bands and self.bands == []: 

297 bands = bands 

298 case _ if self.bands: 

299 bands = tuple(self.bands) 

300 case _: 

301 bands = ("",) 

302 bandStr = ",".join(bands) 

303 for band in bands: 

304 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

305 fluxInd = fluxCol.find("lux") + len("lux") 

306 errCol = ( 

307 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:] 

308 ) 

309 vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol])) 

310 temp = (vec > self.threshold) & (vec < self.maxSN) 

311 if mask is not None: 

312 mask &= temp # type: ignore 

313 else: 

314 mask = temp 

315 

316 plotLabelStr = "({}) > {:.1f}".format(bandStr, self.threshold) 

317 if self.maxSN < 1e5: 

318 plotLabelStr += " & < {:.1f}".format(self.maxSN) 

319 

320 if self.plotLabelKey == "" or self.plotLabelKey is None: 

321 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="S/N", **kwargs) 

322 else: 

323 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

324 

325 # It should not be possible for mask to be a None now 

326 return np.array(cast(Vector, mask)) 

327 

328 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag masks evaluated for each band.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are only consulted when
            the configured ``bands`` is an empty list.

        Returns
        -------
        result : `Vector`
            The combined mask.
        """
        bands: tuple[str, ...]
        # kwargs are only consulted when the configured list is empty.
        deferToKwargs = not self.bands and self.bands == []
        if deferToKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format the columns with "".
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = ["{band}_pixelFlags_edge"]

365 

366 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the sky-source
        flags.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # The previous implementation combined this mask with an accumulator
        # that was always None, so delegating directly is equivalent.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

387 

388 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the DIA source
        pixel flags.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # The previous implementation combined this mask with an accumulator
        # that was always None, so delegating directly is equivalent.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

413 

414 

class ExtendednessSelector(VectorAction):
    """Base for selectors that classify rows by an extendedness column.

    Calling this class directly returns the raw extendedness vector;
    subclasses turn it into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Look up the extendedness column named by ``vectorKey`` after
        filling its format placeholders from ``kwargs``.
        """
        column = self.vectorKey.format(**kwargs)
        values = data[column]
        return cast(Vector, values)

428 

429 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.

    Rows are selected when ``0 <= extendedness <= extendedness_maximum``.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows classified as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.

        Returns
        -------
        result : `Vector`
            Boolean array, True where the extendedness lies in
            ``[0, extendedness_maximum]``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on the field; this
        # complements GalaxySelector, whose minimum is exclusive.
        isStar = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isStar))

442 

443 

class GalaxySelector(ExtendednessSelector):
    """Selector for resolved objects, i.e. rows whose extendedness is
    strictly greater than ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows above the extendedness minimum."""
        values = super().__call__(data, **kwargs)
        isGalaxy = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isGalaxy)

456 

457 

class UnknownSelector(ExtendednessSelector):
    """Selector for unclassified objects, i.e. rows whose extendedness
    value equals 9.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows whose extendedness equals 9."""
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown

466 

467 

class VectorSelector(VectorAction):
    """Use an existing column of the KeyedData as the selection mask."""

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` after filling its
        format placeholders from ``kwargs``.
        """
        column = self.vectorKey.format(**kwargs)
        return cast(Vector, data[column])

480 

481 

class ThresholdSelector(VectorAction):
    """Build a mask by comparing a column against a threshold.

    The comparison is the function named ``op`` in the standard `operator`
    module, called as ``op(column, threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the result of applying the configured operator to the
        ``vectorKey`` column and ``threshold``.
        """
        compare = getattr(operator, self.op)
        result = compare(data[self.vectorKey], self.threshold)
        return cast(Vector, result)

495 

496 

497class BandSelector(VectorAction): 

498 """Makes a mask for sources observed in a specified set of bands.""" 

499 

500 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

501 bands = ListField[str]( 

502 doc="The bands to select. `None` indicates no band selection applied.", 

503 default=[], 

504 ) 

505 

506 def getInputSchema(self) -> KeyedDataSchema: 

507 return ((self.vectorKey, Vector),) 

508 

509 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

510 bands: Optional[tuple[str, ...]] 

511 match kwargs: 

512 case {"band": band} if not self.bands and self.bands == []: 

513 bands = (band,) 

514 case {"bands": bands} if not self.bands and self.bands == []: 

515 bands = bands 

516 case _ if self.bands: 

517 bands = tuple(self.bands) 

518 case _: 

519 bands = None 

520 if bands: 

521 mask = np.in1d(data[self.vectorKey], bands) 

522 else: 

523 # No band selection is applied, i.e., select all rows 

524 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

525 return cast(Vector, mask) 

526 

527 

class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        self.selectWhenTrue = ["detect_isPatchInner"]
        # Vetoing these two flags is what selects all of the parents and
        # excludes sky objects.
        self.selectWhenFalse = ["detect_isDeblendedModelSource", "sky_object"]

538 

539 

540class MagSelector(SelectorBase): 

541 """Selects points that have minMag < mag (AB) < maxMag. 

542 

543 The magnitude is based on the given fluxType. 

544 """ 

545 

546 fluxType = Field[str](doc="Flux type to calculate the magnitude in.", default="{band}_psfFlux") 

547 minMag = Field[float](doc="Minimum mag to include in the sample.", default=-1e6) 

548 maxMag = Field[float](doc="Maximum mag to include in the sample.", default=1e6) 

549 fluxUnit = Field[str](doc="Astropy unit of flux vector", default="nJy") 

550 returnMillimags = Field[bool](doc="Use millimags or not?", default=False) 

551 bands = ListField[str]( 

552 doc="The band(s) to apply the magnitude cut in. Takes precedence if bands passed to call.", 

553 default=[], 

554 ) 

555 

556 def getInputSchema(self) -> KeyedDataSchema: 

557 fluxCol = self.fluxType 

558 yield fluxCol, Vector 

559 

560 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

561 """Make a mask of that satisfies self.minMag < mag < self.maxMag. 

562 

563 The magnitude is based on the flux in self.fluxType. 

564 

565 Parameters 

566 ---------- 

567 data : `KeyedData` 

568 The data to perform the magnitude selection on. 

569 

570 Returns 

571 ------- 

572 result : `Vector` 

573 A mask of the objects that satisfy the given magnitude cut. 

574 """ 

575 mask: Optional[Vector] = None 

576 bands: tuple[str, ...] 

577 match kwargs: 

578 case {"band": band} if not self.bands and self.bands == []: 

579 bands = (band,) 

580 case {"bands": bands} if not self.bands and self.bands == []: 

581 bands = bands 

582 case _ if self.bands: 

583 bands = tuple(self.bands) 

584 case _: 

585 bands = ("",) 

586 bandStr = ",".join(bands) 

587 for band in bands: 

588 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

589 vec = fluxToMag( 

590 cast(Vector, data[fluxCol]), 

591 flux_unit=self.fluxUnit, 

592 return_millimags=self.returnMillimags, 

593 ) 

594 temp = (vec > self.minMag) & (vec < self.maxMag) 

595 if mask is not None: 

596 mask &= temp # type: ignore 

597 else: 

598 mask = temp 

599 

600 plotLabelStr = "" 

601 if self.maxMag < 100: 

602 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.maxMag) 

603 if self.minMag > -100: 

604 if bandStr in plotLabelStr: 

605 plotLabelStr += " & < {:.1f}".format(self.minMag) 

606 else: 

607 plotLabelStr += "({}) < {:.1f}".format(bandStr, self.minMag) 

608 if self.plotLabelKey == "" or self.plotLabelKey is None: 

609 self._addValueToPlotInfo(plotLabelStr, plotLabelKey="Mag", **kwargs) 

610 else: 

611 self._addValueToPlotInfo(plotLabelStr, **kwargs) 

612 

613 # It should not be possible for mask to be a None now 

614 return np.array(cast(Vector, mask))