Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%

214 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-23 09:43 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Names exported by a star-import of this module. ``SelectorBase`` is
# defined in this module but is intentionally not listed here.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)

40 

41import operator 

42from typing import Optional, cast 

43 

44import numpy as np 

45from lsst.pex.config import Field 

46from lsst.pex.config.listField import ListField 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

49 

50 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info.

    The value is written into the ``plotInfo`` mapping received through the
    call keyword arguments, under the key configured in ``plotLabelKey``.
    """

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is written when ``kwargs`` has no ``"plotInfo"`` entry or when
        ``plotLabelKey`` is empty.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments of the calling action; only ``"plotInfo"`` is
            looked at.
        """
        if "plotInfo" not in kwargs or not self.plotLabelKey:
            return
        plotInfo = kwargs["plotInfo"]
        plotInfo[self.plotLabelKey] = value

59 

60 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is selected when every column in ``selectWhenFalse`` equals 0 and
    every column in ``selectWhenTrue`` equals 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column name, Vector)`` pairs for every configured flag.

        The names are returned as configured, i.e. without any placeholder
        substitution applied.
        """
        flagColumns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in flagColumns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the flag columns from.
        **kwargs
            Values substituted into the configured column names with
            `str.format` before each lookup in ``data``.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy all of the flag cuts.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue``
            contains any column name.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair every flag column with the value it must have to be selected.
        requirements = [(name, 0) for name in self.selectWhenFalse]  # type: ignore
        requirements += [(name, 1) for name in self.selectWhenTrue]

        mask: Optional[Vector] = None
        for name, wanted in requirements:
            current = np.array(data[name.format(**kwargs)] == wanted)
            if mask is None:
                mask = current
            else:
                mask &= current  # type: ignore

        # The guard at the top guarantees at least one requirement was applied.
        return cast(Vector, mask)

115 

116 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector with default flags for coadd-level plots.

    The flag cuts of the parent class are evaluated once per band, with the
    band substituted into the flag column names, and the per-band masks are
    combined with a logical AND. When ``bands`` is left empty the band(s) are
    taken from the call keyword arguments.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector unchanged."""
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every selected band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the flag columns from.
        **kwargs
            May contain ``band`` (a single band) or ``bands`` (a sequence of
            bands); these are only used when the ``bands`` configuration is
            an empty list. ``band`` is checked before ``bands``.

        Returns
        -------
        result : `Vector`
            A mask of the objects passing the flag cuts in all selected bands.
        """
        # True only when the configuration holds an (equal-to) empty list.
        unconfigured = not self.bands and self.bands == []

        selectedBands: tuple[str, ...]
        if "band" in kwargs and unconfigured:
            selectedBands = (kwargs["band"],)
        elif "bands" in kwargs and unconfigured:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            # No band information at all: substitute an empty band name.
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for name in selectedBands:
            bandMask = super().__call__(data, **{**kwargs, "band": name})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default coadd flag columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

158 

159 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the flag columns from.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the flag cuts.
        """
        # There is a single evaluation and nothing to accumulate over, so the
        # parent's mask is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default visit-level flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

185 

186 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which NumPy 2.0
    # removed; the values are identical.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The default minimum is the float next to -inf in the direction of zero,
    # so values of -inf are rejected by default.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(key, Vector)`` entry this action reads."""
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the ``key`` column from.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)

213 

214 

215class SnSelector(SelectorBase): 

216 """Selects points that have S/N > threshold in the given flux type.""" 

217 

218 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

219 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

220 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

221 uncertaintySuffix = Field[str]( 

222 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

223 ) 

224 bands = ListField[str]( 

225 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

226 default=[], 

227 ) 

228 

229 def getInputSchema(self) -> KeyedDataSchema: 

230 yield (fluxCol := self.fluxType), Vector 

231 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

232 

233 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

234 """Makes a mask of objects that have S/N greater than 

235 self.threshold in self.fluxType 

236 

237 Parameters 

238 ---------- 

239 data : `KeyedData` 

240 The data to perform the selection on. 

241 

242 Returns 

243 ------- 

244 result : `Vector` 

245 A mask of the objects that satisfy the given 

246 S/N cut. 

247 """ 

248 

249 self._addValueToPlotInfo(self.threshold, **kwargs) 

250 mask: Optional[Vector] = None 

251 bands: tuple[str, ...] 

252 match kwargs: 

253 case {"band": band} if not self.bands and self.bands == []: 

254 bands = (band,) 

255 case {"bands": bands} if not self.bands and self.bands == []: 

256 bands = bands 

257 case _ if self.bands: 

258 bands = tuple(self.bands) 

259 case _: 

260 bands = ("",) 

261 for band in bands: 

262 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

263 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

264 vec = cast(Vector, data[fluxCol]) / cast(Vector, data[errCol]) 

265 temp = (vec > self.threshold) & (vec < self.maxSN) 

266 if mask is not None: 

267 mask &= temp # type: ignore 

268 else: 

269 mask = temp 

270 

271 # It should not be possible for mask to be a None now 

272 return np.array(cast(Vector, mask)) 

273 

274 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The flag cuts of the parent class are evaluated once per band, with the
    band substituted into the flag column names, and the per-band masks are
    combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector unchanged."""
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every selected band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the flag columns from.
        **kwargs
            May contain ``band`` (a single band) or ``bands`` (a sequence of
            bands); these are only used when the ``bands`` configuration is
            an empty list. ``band`` is checked before ``bands``.

        Returns
        -------
        result : `Vector`
            A mask of the objects passing the flag cuts in all selected bands.
        """
        # True only when the configuration holds an (equal-to) empty list.
        unconfigured = not self.bands and self.bands == []

        selectedBands: tuple[str, ...]
        if "band" in kwargs and unconfigured:
            selectedBands = (kwargs["band"],)
        elif "bands" in kwargs and unconfigured:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            # No band information at all: substitute an empty band name.
            selectedBands = ("",)

        combined: Optional[Vector] = None
        for name in selectedBands:
            bandMask = super().__call__(data, **{**kwargs, "band": name})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default sky-object flag columns."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]

311 

312 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the flag columns from.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the flag cuts.
        """
        # There is a single evaluation and nothing to accumulate over, so the
        # parent's mask is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default sky-source flag columns."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

333 

334 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the flag columns from.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the flag cuts.
        """
        # There is a single evaluation and nothing to accumulate over, so the
        # parent's mask is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel-flag columns that must all be unset."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

359 

360 

class ExtendednessSelector(VectorAction):
    """Base for selectors that pick between extended and point sources.

    Calling this class directly returns the extendedness column itself;
    subclasses turn that column into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads.

        The key is returned as configured, without placeholder substitution.
        """
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness column.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the column from.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The column stored under the formatted key.
        """
        return cast(Vector, data[self.vectorKey.format(**kwargs)])

374 

375 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness <= extendedness_maximum``.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the extendedness column from.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            A mask of the rows classified as unresolved.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on the configuration
        # field; previously the comparison was strict and a value exactly
        # equal to the maximum was rejected.
        isUnresolved = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isUnresolved))

388 

389 

class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``extendedness > extendedness_minimum``.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the extendedness column from.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            A mask of the rows classified as resolved.
        """
        isResolved = super().__call__(data, **kwargs) > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

402 

403 

class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the extendedness column from.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The element-wise comparison of the extendedness column with 9.
        """
        # NOTE(review): 9 is presumably the value marking unclassified
        # objects in the input catalogs -- confirm against the producer.
        unclassifiedValue = 9
        return super().__call__(data, **kwargs) == unclassifiedValue

412 

413 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads.

        The key is returned as configured, without placeholder substitution.
        """
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the vector from.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The stored vector, returned as is.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])

426 

427 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the configured column with the threshold.

        Parameters
        ----------
        data : `KeyedData`
            The data to read the column from; ``vectorKey`` is used as is.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            The result of ``operator.<op>(column, threshold)``.
        """
        # ``op`` names a function of the standard-library ``operator`` module.
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

441 

442 

443class BandSelector(VectorAction): 

444 """Makes a mask for sources observed in a specified set of bands.""" 

445 

446 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

447 bands = ListField[str]( 

448 doc="The bands to select. `None` indicates no band selection applied.", 

449 default=[], 

450 ) 

451 

452 def getInputSchema(self) -> KeyedDataSchema: 

453 return ((self.vectorKey, Vector),) 

454 

455 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

456 bands: Optional[tuple[str, ...]] 

457 match kwargs: 

458 case {"band": band} if not self.bands and self.bands == []: 

459 bands = (band,) 

460 case {"bands": bands} if not self.bands and self.bands == []: 

461 bands = bands 

462 case _ if self.bands: 

463 bands = tuple(self.bands) 

464 case _: 

465 bands = None 

466 if bands: 

467 mask = np.in1d(data[self.vectorKey], bands) 

468 else: 

469 # No band selection is applied, i.e., select all rows 

470 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

471 return cast(Vector, mask)