Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%

214 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-28 05:16 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ( 

24 "FlagSelector", 

25 "CoaddPlotFlagSelector", 

26 "RangeSelector", 

27 "SnSelector", 

28 "ExtendednessSelector", 

29 "SkyObjectSelector", 

30 "SkySourceSelector", 

31 "GoodDiaSourceSelector", 

32 "StarSelector", 

33 "GalaxySelector", 

34 "UnknownSelector", 

35 "VectorSelector", 

36 "VisitPlotFlagSelector", 

37 "ThresholdSelector", 

38 "BandSelector", 

39) 

40 

41import operator 

42from typing import Optional, cast 

43 

44import numpy as np 

45from lsst.pex.config import Field 

46from lsst.pex.config.listField import ListField 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

49 

50 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing is written when ``plotLabelKey`` is empty (or otherwise
        falsy) or when no ``plotInfo`` keyword argument was supplied.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments of the calling action; only ``plotInfo`` is
            consulted, and it is modified in place.
        """
        if not self.plotLabelKey:
            return
        if "plotInfo" not in kwargs:
            return
        kwargs["plotInfo"][self.plotLabelKey] = value

59 

60 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is selected when every column named in ``selectWhenFalse`` equals
    0 and every column named in ``selectWhenTrue`` equals 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column name, Vector)`` pairs for all configured flags.

        The names are returned exactly as configured, i.e. without any
        template substitution applied.
        """
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select on the given flags.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the flag columns.
        **kwargs
            Values substituted into the configured column names with
            `str.format` before they are looked up in ``data``.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        combined: Optional[Vector] = None
        # Each entry pairs the value a flag must have with the flags that
        # must have it; the per-flag masks are ANDed together.
        requirements = ((0, self.selectWhenFalse), (1, self.selectWhenTrue))
        for wanted, templates in requirements:
            for template in templates:  # type: ignore
                current = np.array(data[template.format(**kwargs)] == wanted)
                if combined is None:
                    combined = current
                else:
                    combined &= current  # type: ignore
        # The check at the top guarantees at least one flag was processed.
        return cast(Vector, combined)

111 

112 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd plots, applied once per band.

    With the default (empty) ``bands`` setting the band is taken from the
    keyword arguments passed to the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent class unchanged."""
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every selected band and AND the masks.

        The bands are chosen as follows: if ``self.bands`` is an empty list,
        a ``band`` keyword (or, failing that, a ``bands`` keyword) supplies
        them; otherwise a non-empty ``self.bands`` is used; otherwise the
        single band ``""`` is used.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the flag columns.
        **kwargs
            Forwarded to the parent call with ``band`` overridden for each
            band in turn.

        Returns
        -------
        result : `Vector`
            The combined mask over all bands.
        """
        notConfigured = not self.bands and self.bands == []
        chosen: tuple[str, ...]
        if notConfigured and "band" in kwargs:
            chosen = (kwargs["band"],)
        elif notConfigured and "bands" in kwargs:
            chosen = kwargs["bands"]
        elif self.bands:
            chosen = tuple(self.bands)
        else:
            chosen = ("",)

        combined: Optional[Vector] = None
        for name in chosen:
            perBand = super().__call__(data, **(kwargs | {"band": name}))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns used for coadd plots."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = [
            "detect_isPatchInner",
            "detect_isDeblendedSource",
        ]

153 

154 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent class unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the visit flags.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the flag columns.
        **kwargs
            Forwarded unchanged to the parent call.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only ever one mask here, so there is nothing to combine
        # it with; the parent's result is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns used for visit-level plots."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

180 

181 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which was removed
    # in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default minimum is the first representable float above -infinity.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(key, Vector)`` pair this action reads."""
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the column named by ``key``.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with values within the specified range,
            i.e. ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)

208 

209 

210class SnSelector(SelectorBase): 

211 """Selects points that have S/N > threshold in the given flux type""" 

212 

213 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

214 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

215 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

216 uncertaintySuffix = Field[str]( 

217 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

218 ) 

219 bands = ListField[str]( 

220 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

221 default=[], 

222 ) 

223 

224 def getInputSchema(self) -> KeyedDataSchema: 

225 yield (fluxCol := self.fluxType), Vector 

226 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

227 

228 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

229 """Makes a mask of objects that have S/N greater than 

230 self.threshold in self.fluxType 

231 Parameters 

232 ---------- 

233 data : `KeyedData` 

234 Returns 

235 ------- 

236 result : `Vector` 

237 A mask of the objects that satisfy the given 

238 S/N cut. 

239 """ 

240 self._addValueToPlotInfo(self.threshold, **kwargs) 

241 mask: Optional[Vector] = None 

242 bands: tuple[str, ...] 

243 match kwargs: 

244 case {"band": band} if not self.bands and self.bands == []: 

245 bands = (band,) 

246 case {"bands": bands} if not self.bands and self.bands == []: 

247 bands = bands 

248 case _ if self.bands: 

249 bands = tuple(self.bands) 

250 case _: 

251 bands = ("",) 

252 for band in bands: 

253 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

254 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

255 vec = cast(Vector, data[fluxCol]) / cast(Vector, data[errCol]) 

256 temp = (vec > self.threshold) & (vec < self.maxSN) 

257 if mask is not None: 

258 mask &= temp # type: ignore 

259 else: 

260 mask = temp 

261 

262 # It should not be possible for mask to be a None now 

263 return np.array(cast(Vector, mask)) 

264 

265 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)"""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent class unchanged."""
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every selected band and AND the masks.

        The bands are chosen as follows: if ``self.bands`` is an empty list,
        a ``band`` keyword (or, failing that, a ``bands`` keyword) supplies
        them; otherwise a non-empty ``self.bands`` is used; otherwise the
        single band ``""`` is used.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the flag columns.
        **kwargs
            Forwarded to the parent call with ``band`` overridden for each
            band in turn.

        Returns
        -------
        result : `Vector`
            The combined mask over all bands.
        """
        notConfigured = not self.bands and self.bands == []
        chosen: tuple[str, ...]
        if notConfigured and "band" in kwargs:
            chosen = (kwargs["band"],)
        elif notConfigured and "bands" in kwargs:
            chosen = kwargs["bands"]
        elif self.bands:
            chosen = tuple(self.bands)
        else:
            chosen = ("",)

        combined: Optional[Vector] = None
        for name in chosen:
            perBand = super().__call__(data, **(kwargs | {"band": name}))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns used to pick out sky objects."""
        self.selectWhenFalse = ["{band}_pixelFlags_edge"]
        self.selectWhenTrue = ["sky_object"]

302 

303 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent class unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the sky-source
        flags.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the flag columns.
        **kwargs
            Forwarded unchanged to the parent call.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only ever one mask here, so there is nothing to combine
        # it with; the parent's result is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns used to pick out sky sources."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

324 

325 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema of the parent class unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by `FlagSelector` for the DIA source
        flags.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the flag columns.
        **kwargs
            Forwarded unchanged to the parent call.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only ever one mask here, so there is nothing to combine
        # it with; the parent's result is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel-flag columns that must all be unset."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

350 

351 

class ExtendednessSelector(VectorAction):
    """Look up the extendedness vector; subclasses turn it into a mask."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` pair this action reads.

        The key is returned as configured, without template substitution.
        """
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` after formatting.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The extendedness values, unmodified.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)

363 

364 

class StarSelector(ExtendednessSelector):
    """Select rows whose extendedness marks them as unresolved."""

    extendedness_maximum = Field[float](
        # The comparison in ``__call__`` is strict, so the bound itself is
        # not selected; the doc states that explicitly.
        doc="Maximum extendedness to qualify as unresolved, not inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness < maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            A mask of the rows classified as unresolved.
        """
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))

373 

374 

class GalaxySelector(ExtendednessSelector):
    """Select rows whose extendedness marks them as resolved."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with extendedness above the minimum.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``extendedness > extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)

383 

384 

class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness equals 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with extendedness equal to 9.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the extendedness column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The element-wise result of comparing the extendedness to 9.
        """
        values = super().__call__(data, **kwargs)
        # NOTE(review): 9 is presumably the sentinel for "could not be
        # classified" - confirm against the catalog schema.
        isUnknown = values == 9
        return isUnknown

389 

390 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` pair this action reads.

        The key is returned as configured, without template substitution.
        """
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column named by ``vectorKey`` after formatting.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the mask column.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The stored vector, unmodified.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        selection = data[resolvedKey]
        return cast(Vector, selection)

403 

404 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` pair this action reads."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the configured column against the threshold.

        ``op`` is looked up by name in the standard `operator` module and
        called as ``op(column, threshold)``.

        Parameters
        ----------
        data : `KeyedData`
            Container holding the column named by ``vectorKey``.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            The result of applying the operator.
        """
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

418 

419 

420class BandSelector(VectorAction): 

421 """Makes a mask for sources observed in a specified set of bands.""" 

422 

423 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

424 bands = ListField[str]( 

425 doc="The bands to select. `None` indicates no band selection applied.", 

426 default=[], 

427 ) 

428 

429 def getInputSchema(self) -> KeyedDataSchema: 

430 return ((self.vectorKey, Vector),) 

431 

432 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

433 bands: Optional[tuple[str, ...]] 

434 match kwargs: 

435 case {"band": band} if not self.bands and self.bands == []: 

436 bands = (band,) 

437 case {"bands": bands} if not self.bands and self.bands == []: 

438 bands = bands 

439 case _ if self.bands: 

440 bands = tuple(self.bands) 

441 case _: 

442 bands = None 

443 if bands: 

444 mask = np.in1d(data[self.vectorKey], bands) 

445 else: 

446 # No band selection is applied, i.e., select all rows 

447 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

448 return cast(Vector, mask)