Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%

214 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-05 04:42 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ( 

24 "FlagSelector", 

25 "CoaddPlotFlagSelector", 

26 "RangeSelector", 

27 "SnSelector", 

28 "ExtendednessSelector", 

29 "SkyObjectSelector", 

30 "SkySourceSelector", 

31 "GoodDiaSourceSelector", 

32 "StarSelector", 

33 "GalaxySelector", 

34 "UnknownSelector", 

35 "VectorSelector", 

36 "VisitPlotFlagSelector", 

37 "ThresholdSelector", 

38 "BandSelector", 

39) 

40 

41import operator 

42from typing import Optional, cast 

43 

44import numpy as np 

45from lsst.pex.config import Field 

46from lsst.pex.config.listField import ListField 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

49 

50 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info.

    Subclasses call `_addValueToPlotInfo` to store a value (for example a
    threshold) in the ``plotInfo`` mapping passed through ``kwargs``.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is written when ``plotInfo`` was not supplied or when
        ``plotLabelKey`` is empty or unset.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments of the calling action; only ``plotInfo`` is
            inspected.
        """
        labelKey = self.plotLabelKey
        if not labelKey or "plotInfo" not in kwargs:
            return
        kwargs["plotInfo"][labelKey] = value

59 

60 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is selected when every column in ``selectWhenFalse`` equals 0 and
    every column in ``selectWhenTrue`` equals 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column, Vector)`` pairs for every configured flag column."""
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns; each configured name is
            formatted with ``kwargs`` before being looked up.
        **kwargs
            Values substituted into the column-name templates.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        mask: Optional[Vector] = None
        # Each entry pairs a list of column templates with the value the
        # column must hold for a row to be kept.
        cuts = ((self.selectWhenFalse, 0), (self.selectWhenTrue, 1))
        for columnTemplates, wanted in cuts:
            for template in columnTemplates:  # type: ignore
                passed = np.array(data[template.format(**kwargs)] == wanted)
                if mask is None:
                    mask = passed
                else:
                    mask &= passed  # type: ignore

        # The check at the top guarantees at least one column was processed.
        return cast(Vector, mask)

115 

116 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector with per-band column templates.

    With the default (empty) ``bands`` configuration the band is taken from
    the call's keyword arguments.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent class's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every relevant band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are only consulted when
            the configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            Mask of rows passing the flag cuts in all bands.
        """
        # Call-time bands are honoured only when the configured list is an
        # empty list; a non-empty configuration always wins.
        useCallBands = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format templates with an empty band.
            bands = ("",)

        combined: Optional[Vector] = None
        for currentBand in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=currentBand)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Install the default flag columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

158 

159 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent class's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all flag cuts.
        """
        # There is no per-band loop here, so there is nothing to accumulate;
        # the parent's mask is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Install the default flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

185 

186 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which was removed
    # in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the finite float closest to -inf.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(key, Vector)`` entry this action reads."""
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the column named by ``key``.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)

213 

214 

215class SnSelector(SelectorBase): 

216 """Selects points that have S/N > threshold in the given flux type.""" 

217 

218 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

219 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

220 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

221 uncertaintySuffix = Field[str]( 

222 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

223 ) 

224 bands = ListField[str]( 

225 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

226 default=[], 

227 ) 

228 

229 def getInputSchema(self) -> KeyedDataSchema: 

230 yield (fluxCol := self.fluxType), Vector 

231 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

232 

233 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

234 """Makes a mask of objects that have S/N greater than 

235 self.threshold in self.fluxType 

236 

237 Parameters 

238 ---------- 

239 data : `KeyedData` 

240 

241 Returns 

242 ------- 

243 result : `Vector` 

244 A mask of the objects that satisfy the given 

245 S/N cut. 

246 """ 

247 

248 self._addValueToPlotInfo(self.threshold, **kwargs) 

249 mask: Optional[Vector] = None 

250 bands: tuple[str, ...] 

251 match kwargs: 

252 case {"band": band} if not self.bands and self.bands == []: 

253 bands = (band,) 

254 case {"bands": bands} if not self.bands and self.bands == []: 

255 bands = bands 

256 case _ if self.bands: 

257 bands = tuple(self.bands) 

258 case _: 

259 bands = ("",) 

260 for band in bands: 

261 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

262 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

263 vec = cast(Vector, data[fluxCol]) / cast(Vector, data[errCol]) 

264 temp = (vec > self.threshold) & (vec < self.maxSN) 

265 if mask is not None: 

266 mask &= temp # type: ignore 

267 else: 

268 mask = temp 

269 

270 # It should not be possible for mask to be a None now 

271 return np.array(cast(Vector, mask)) 

272 

273 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent class's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the sky-object flag cuts in every relevant band.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are only consulted when
            the configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            Mask of rows passing the flag cuts in all bands.
        """
        # Call-time bands are honoured only when the configured list is an
        # empty list; a non-empty configuration always wins.
        useCallBands = not self.bands and self.bands == []
        bands: tuple[str, ...]
        if useCallBands and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useCallBands and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        combined: Optional[Vector] = None
        for currentBand in bands:
            bandMask = super().__call__(data, **(kwargs | dict(band=currentBand)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Install the default flag columns."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]

310 

311 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent class's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all flag cuts.
        """
        # Only one parent call is made, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Install the default flag columns."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

332 

333 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent class's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            Mask of rows passing all flag cuts.
        """
        # Only one parent call is made, so there is nothing to accumulate.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Install the default flag columns."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

358 

359 

class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources.

    This base class only loads the extendedness column; subclasses turn it
    into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` schema entry."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness column.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the extendedness column.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The column found under the formatted key.
        """
        column = self.vectorKey.format(**kwargs)
        values = data[column]
        return cast(Vector, values)

373 

374 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness <= extendedness_maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the extendedness column.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            Boolean mask of the rows classified as unresolved.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on the field; together
        # with GalaxySelector's strict ``>`` a value exactly at the
        # threshold is classified as a star rather than falling in neither
        # class.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))

387 

388 

class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with extendedness strictly above the minimum.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the extendedness column.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            Result of ``extendedness > extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

401 

402 

class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the extendedness column.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            Result of comparing the extendedness column with 9.
        """
        values = super().__call__(data, **kwargs)
        # NOTE(review): 9 is presumably the sentinel for "unclassified";
        # confirm against the catalog schema.
        isUnknown = values == 9
        return isUnknown

411 

412 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` schema entry."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the mask column.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The stored vector, returned as-is.
        """
        column = self.vectorKey.format(**kwargs)
        selection = data[column]
        return cast(Vector, selection)

425 

426 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` schema entry."""
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare a column against ``threshold`` using the configured operator.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the column to threshold.
        **kwargs
            Values substituted into the ``vectorKey`` template, as the other
            key-based selectors in this module do.

        Returns
        -------
        result : `Vector`
            Result of ``operator.<op>(column, threshold)``.

        Raises
        ------
        AttributeError
            Raised if ``op`` does not name an attribute of the `operator`
            module.
        """
        # ``op`` names a function in the stdlib operator module, e.g. "gt".
        compare = getattr(operator, self.op)
        # Formatting is a no-op for keys without placeholders and lets
        # templated keys such as "{band}_x" resolve.
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, compare(column, self.threshold))

440 

441 

442class BandSelector(VectorAction): 

443 """Makes a mask for sources observed in a specified set of bands.""" 

444 

445 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

446 bands = ListField[str]( 

447 doc="The bands to select. `None` indicates no band selection applied.", 

448 default=[], 

449 ) 

450 

451 def getInputSchema(self) -> KeyedDataSchema: 

452 return ((self.vectorKey, Vector),) 

453 

454 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

455 bands: Optional[tuple[str, ...]] 

456 match kwargs: 

457 case {"band": band} if not self.bands and self.bands == []: 

458 bands = (band,) 

459 case {"bands": bands} if not self.bands and self.bands == []: 

460 bands = bands 

461 case _ if self.bands: 

462 bands = tuple(self.bands) 

463 case _: 

464 bands = None 

465 if bands: 

466 mask = np.in1d(data[self.vectorKey], bands) 

467 else: 

468 # No band selection is applied, i.e., select all rows 

469 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

470 return cast(Vector, mask)