Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 31%

233 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-22 11:06 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Names exported by ``from ... import *``. ParentObjectSelector is defined in
# this module alongside the other public selectors, so it is exported as well.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "MatchingFlagSelector",
    "ParentObjectSelector",
)

41 

42import operator 

43from typing import Optional, cast 

44 

45import numpy as np 

46from lsst.pex.config import Field 

47from lsst.pex.config.listField import ListField 

48 

49from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

50from ...math import divide 

51 

52 

class SelectorBase(VectorAction):
    """Base for selectors that can record a value in the caller's plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is written when the call received no ``plotInfo`` keyword or
        when ``plotLabelKey`` is empty or unset.
        """
        label = self.plotLabelKey
        if not label or "plotInfo" not in kwargs:
            return
        kwargs["plotInfo"][label] = value

61 

62 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column, Vector)`` pairs for every configured flag column.

        The ``selectWhenFalse`` columns come first, followed by the
        ``selectWhenTrue`` columns; names are reported as configured, without
        any keyword formatting applied.
        """
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask of the rows that pass every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Values substituted into each configured column name with
            `str.format` before the lookup in ``data``.

        Returns
        -------
        result : `Vector`
            A mask that is set where every ``selectWhenFalse`` column equals
            0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue``
            names any column.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        mask: Optional[Vector] = None
        # Each entry pairs a list of column templates with the value a row
        # must hold in those columns to be selected.
        requirements = ((self.selectWhenFalse, 0), (self.selectWhenTrue, 1))
        for templates, wanted in requirements:
            for template in templates:  # type: ignore
                current = np.array(data[template.format(**kwargs)] == wanted)
                if mask is None:
                    mask = current
                else:
                    mask &= current  # type: ignore

        # The check at the top guarantees at least one column was processed.
        return cast(Vector, mask)

117 

118 

class CoaddPlotFlagSelector(FlagSelector):
    """Apply the configured flag cuts once per band and combine the masks.

    With the default (empty) ``bands`` setting the band, or bands, are taken
    from the keyword arguments of the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag columns to their ``_target`` suffixed names."""
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent flag selection over all bands.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; they are only consulted when the
            configured ``bands`` is an empty list. ``band`` wins over
            ``bands`` if both are present.

        Returns
        -------
        result : `Vector`
            Combined mask over every band that was evaluated.
        """
        # Caller-supplied bands are honoured only for an empty configured list.
        deferToCaller = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if deferToCaller and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToCaller and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format the columns with an empty band.
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            perBand = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default per-band and band-independent flag columns."""
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]

169 

170 

class MatchingFlagSelector(CoaddPlotFlagSelector):
    """Default flag selector applied before matching.

    It trims the sources to remove duplicates only; no quality cuts are made.
    """

    def setDefaults(self):
        """Require ``detect_isPrimary`` and configure no select-when-false flags."""
        self.selectWhenTrue = ["detect_isPrimary"]
        self.selectWhenFalse = []

180 

181 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): catalogSuffix is not read anywhere in this class; the
    # column names below are used exactly as written.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the select-when-false flags to their ``_target`` names."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the parent flag selection.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single selection pass, so there is nothing to combine;
        # the previous accumulate-with-&= branch could never execute.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default visit-level flag columns to select on when False."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

217 

218 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # ``np.inf`` is used instead of the ``np.Inf`` alias, which NumPy 2.0
    # removed.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the float immediately above -inf.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(vectorKey, Vector)`` input required."""
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the column named by ``vectorKey``.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)

245 

246 

class SnSelector(SelectorBase):
    """Selects points that have S/N > threshold in the given flux type."""

    fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
    threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
    maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
    uncertaintySuffix = Field[str](
        doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
    )
    # The doc was previously two adjacent literals with no separating space,
    # which rendered as "cut in.Takes precedence".
    bands = ListField[str](
        doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the flux column and its uncertainty column.

        The uncertainty column name is ``fluxType`` with ``uncertaintySuffix``
        inserted directly after the first occurrence of ``"lux"``. Names are
        reported as configured, without keyword formatting.
        """
        fluxCol = self.fluxType
        # NOTE(review): if "lux" is absent, str.find returns -1 and the suffix
        # is inserted after the second character -- confirm this is intended.
        fluxInd = fluxCol.find("lux") + len("lux")
        errCol = f"{fluxCol[:fluxInd]}{self.uncertaintySuffix}{fluxCol[fluxInd:]}"
        yield fluxCol, Vector
        yield errCol, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Makes a mask of objects that have S/N greater than
        self.threshold (and less than self.maxSN) in self.fluxType

        Parameters
        ----------
        data : `KeyedData`
            The data to perform the selection on.
        **kwargs
            May carry ``band`` or ``bands``, consulted only when the
            configured ``bands`` is an empty list, and ``plotInfo``, which
            receives the threshold under ``plotLabelKey``. All keywords are
            also available for formatting the column names.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given
            S/N cut in every band evaluated.
        """
        self._addValueToPlotInfo(self.threshold, **kwargs)

        # Caller-supplied bands are honoured only for an empty configured list.
        deferToCaller = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if deferToCaller and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToCaller and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        mask: Optional[Vector] = None
        for band in bands:
            fluxCol = self.fluxType.format(**(kwargs | dict(band=band)))
            fluxInd = fluxCol.find("lux") + len("lux")
            # The suffix is formatted with the caller's kwargs as passed in,
            # i.e. without the per-iteration band override.
            suffix = self.uncertaintySuffix.format(**kwargs)
            errCol = f"{fluxCol[:fluxInd]}{suffix}{fluxCol[fluxInd:]}"
            vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
            temp = (vec > self.threshold) & (vec < self.maxSN)
            if mask is not None:
                mask &= temp  # type: ignore
            else:
                mask = temp

        # It should not be possible for mask to be a None now
        return np.array(cast(Vector, mask))

311 

312 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent flag selection over all bands.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; they are only consulted when the
            configured ``bands`` is an empty list. ``band`` wins over
            ``bands`` if both are present.

        Returns
        -------
        result : `Vector`
            Combined mask over every band that was evaluated.
        """
        # Caller-supplied bands are honoured only for an empty configured list.
        deferToCaller = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if deferToCaller and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToCaller and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            perBand = super().__call__(data, **{**kwargs, "band": band})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Require ``sky_object`` set and the per-band edge flag unset."""
        self.selectWhenTrue = ["sky_object"]
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]

349 

350 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the parent flag selection.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single selection pass, so there is nothing to combine;
        # the previous accumulate-with-&= branch could never execute.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Require ``sky_source`` set and ``pixelFlags_edge`` unset."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

371 

372 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the parent flag selection.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single selection pass, so there is nothing to combine;
        # the previous accumulate-with-&= branch could never execute.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel flags that must all be unset."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

397 

398 

class ExtendednessSelector(VectorAction):
    """Base selector for choosing between extended and point sources.

    On its own it only loads the extendedness column; subclasses turn that
    column into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` input required."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness column from ``data``.

        ``vectorKey`` is formatted with ``kwargs`` before the lookup.
        """
        column = self.vectorKey.format(**kwargs)
        extendedness = data[column]
        return cast(Vector, extendedness)

412 

413 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness marks them as unresolved.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the extendedness column.
        **kwargs
            Used to format ``vectorKey`` before the lookup.

        Returns
        -------
        result : `Vector`
            A mask set where ``0 <= extendedness <= extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as the field documentation states; the
        # comparison was previously strict, which contradicted that contract.
        unresolved = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, unresolved))

426 

427 

class GalaxySelector(ExtendednessSelector):
    """Selector for galaxies, identified by their extendedness values."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask set where extendedness exceeds ``extendedness_minimum``.

        The comparison is strict, so a value equal to the minimum is not
        selected.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)

440 

441 

class UnknownSelector(ExtendednessSelector):
    """Selector for unclassified objects, identified by their extendedness
    values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask set where the extendedness value equals 9."""
        values = super().__call__(data, **kwargs)
        unclassified = values == 9
        return unclassified

450 

451 

class VectorSelector(VectorAction):
    """Return a boolean vector stored in KeyedData so it can serve as a
    selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` input required."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Look up and return the mask column.

        ``vectorKey`` is formatted with ``kwargs`` before the lookup.
        """
        column = self.vectorKey.format(**kwargs)
        return cast(Vector, data[column])

464 

465 

class ThresholdSelector(VectorAction):
    """Build a mask by comparing a column against a threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` input required."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured operator to the column and the threshold.

        ``op`` is resolved as an attribute of the standard `operator` module
        and called as ``op(column, threshold)``.
        """
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

479 

480 

class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands."""

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` input required."""
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows whose band is one of the selected bands.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the column named by ``vectorKey``.
        **kwargs
            May carry ``band`` or ``bands``; they are only consulted when the
            configured ``bands`` is an empty list. ``band`` wins over
            ``bands`` if both are present.

        Returns
        -------
        result : `Vector`
            A mask set where the band column matches a selected band, or an
            all-True mask when no bands are selected.
        """
        # Caller-supplied bands are honoured only for an empty configured list.
        deferToCaller = not self.bands and self.bands == []
        bands: Optional[tuple[str, ...]]
        if deferToCaller and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToCaller and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = None

        if bands:
            # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0;
            # the two agree for one-dimensional input.
            mask = np.isin(data[self.vectorKey], bands)
        else:
            # No band selection is applied, i.e., select all rows
            mask = np.full(len(data[self.vectorKey]), True)  # type: ignore
        return cast(Vector, mask)

510 

511 

class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        """Set the default flags used to pick out parent objects."""
        self.selectWhenTrue = ["detect_isPatchInner"]
        # This selects all of the parents
        self.selectWhenFalse = [
            "detect_isDeblendedModelSource",
            "sky_object",
        ]