Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 29%

208 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-07 03:15 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ( 

24 "FlagSelector", 

25 "CoaddPlotFlagSelector", 

26 "RangeSelector", 

27 "SnSelector", 

28 "ExtendednessSelector", 

29 "SkyObjectSelector", 

30 "SkySourceSelector", 

31 "GoodDiaSourceSelector", 

32 "StarSelector", 

33 "GalaxySelector", 

34 "UnknownSelector", 

35 "VectorSelector", 

36 "VisitPlotFlagSelector", 

37 "ThresholdSelector", 

38 "BandSelector", 

39) 

40 

41import operator 

42from typing import Optional, cast 

43 

44import numpy as np 

45from lsst.pex.config import Field 

46from lsst.pex.config.listField import ListField 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

49 

50 

class FlagSelector(VectorAction):
    """Base selector that turns a set of flag columns into a boolean mask.

    A row passes only if every column listed in ``selectWhenFalse`` compares
    equal to 0 and every column listed in ``selectWhenTrue`` compares equal
    to 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Names are reported exactly as configured; no ``str.format``
        # substitution is applied here.
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build the mask of rows that pass every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name. Each configured flag name is
            passed through ``str.format(**kwargs)`` before the lookup.
        **kwargs
            Values substituted into the flag name templates.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each flag template with the value its column has to equal;
        # the "false" flags are handled first, then the "true" flags.
        cuts = [(name, 0) for name in self.selectWhenFalse]  # type: ignore
        cuts += [(name, 1) for name in self.selectWhenTrue]  # type: ignore

        mask: Optional[Vector] = None
        for name, wanted in cuts:
            passed = np.array(data[name.format(**kwargs)] == wanted)
            if mask is None:
                mask = passed
            else:
                mask &= passed  # type: ignore
        # The guard at the top guarantees at least one cut was applied
        return cast(Vector, mask)

101 

102 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector whose column names are formatted once per band.

    With the default (empty) ``bands`` configuration the band(s) are taken
    from the ``band`` or ``bands`` keyword argument of the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the parent flag cuts in every selected band and AND them.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Forwarded to `FlagSelector.__call__` with ``band`` overridden
            for each selected band.

        Returns
        -------
        result : `Vector`
            Combined mask; `None` is returned when the selected band
            sequence is empty.
        """
        bands: tuple[str, ...]
        # Keyword arguments are consulted only while the configured list is
        # an empty list (both tests must hold, so a `None` value does not
        # count as "unset").
        useKwargs = not self.bands and self.bands == []
        if useKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif useKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # Nothing configured and nothing passed: format with an empty band
            bands = ("",)

        combined: Optional[Vector] = None
        for current in bands:
            perBand = super().__call__(data, **{**kwargs, "band": current})
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

143 

144 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single parent call and nothing to combine it with, so
        # its mask is returned as-is.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

170 

171 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` is the spelling available in every NumPy release; the
    # ``np.Inf`` alias was removed in NumPy 2.0.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default minimum is the float immediately above negative infinity
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name; ``self.key`` is looked up
            as-is (no formatting with ``kwargs``).
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)

198 

199 

200class SnSelector(VectorAction): 

201 """Selects points that have S/N > threshold in the given flux type""" 

202 

203 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

204 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

205 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

206 uncertaintySuffix = Field[str]( 

207 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

208 ) 

209 bands = ListField[str]( 

210 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

211 default=[], 

212 ) 

213 

214 def getInputSchema(self) -> KeyedDataSchema: 

215 yield (fluxCol := self.fluxType), Vector 

216 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

217 

218 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

219 """Makes a mask of objects that have S/N greater than 

220 self.threshold in self.fluxType 

221 Parameters 

222 ---------- 

223 data : `KeyedData` 

224 Returns 

225 ------- 

226 result : `Vector` 

227 A mask of the objects that satisfy the given 

228 S/N cut. 

229 """ 

230 mask: Optional[Vector] = None 

231 bands: tuple[str, ...] 

232 match kwargs: 

233 case {"band": band} if not self.bands and self.bands == []: 

234 bands = (band,) 

235 case {"bands": bands} if not self.bands and self.bands == []: 

236 bands = bands 

237 case _ if self.bands: 

238 bands = tuple(self.bands) 

239 case _: 

240 bands = ("",) 

241 for band in bands: 

242 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

243 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

244 vec = cast(Vector, data[fluxCol]) / data[errCol] 

245 temp = (vec > self.threshold) & (vec < self.maxSN) 

246 if mask is not None: 

247 mask &= temp # type: ignore 

248 else: 

249 mask = temp 

250 

251 # It should not be possible for mask to be a None now 

252 return np.array(cast(Vector, mask)) 

253 

254 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)"""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """AND together the parent flag cuts evaluated once per band.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Forwarded to `FlagSelector.__call__` with ``band`` replaced by
            each selected band in turn.

        Returns
        -------
        result : `Vector`
            Combined mask; `None` is returned when the selected band
            sequence is empty.
        """
        bands: tuple[str, ...]
        # The call's own band arguments matter only when the configured
        # list is an empty list (both tests are required to hold).
        unset = not self.bands and self.bands == []
        if unset and "band" in kwargs:
            bands = (kwargs["band"],)
        elif unset and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        selection: Optional[Vector] = None
        for name in bands:
            bandMask = super().__call__(data, **{**kwargs, "band": name})
            if selection is None:
                selection = bandMask
            else:
                selection &= bandMask  # type: ignore
        return cast(Vector, selection)

    def setDefaults(self):
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]

291 

292 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # Only one mask is ever computed, so it is returned directly
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

313 

314 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag cuts.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # Only one mask is ever computed, so it is returned directly
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

339 

340 

class ExtendednessSelector(VectorAction):
    """Fetch the extendedness column named by ``vectorKey``.

    On its own this action returns the raw column; the subclasses in this
    module turn it into a boolean mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # A single column is required; its name is reported unformatted
        required = (self.vectorKey, Vector)
        return (required,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column whose name is ``vectorKey`` formatted with
        ``kwargs``.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The looked-up column, unmodified.
        """
        column = self.vectorKey.format(**kwargs)
        return cast(Vector, data[column])

352 

353 

class StarSelector(ExtendednessSelector):
    """Select rows whose extendedness lies in ``[0, extendedness_maximum)``."""

    # The comparison in ``__call__`` is strict, so the doc states that the
    # maximum itself is excluded.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, exclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with
        ``0 <= extendedness < extendedness_maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            Boolean mask of the rows within the range.
        """
        extendedness = super().__call__(data, **kwargs)
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness < self.extendedness_maximum)))

362 

363 

class GalaxySelector(ExtendednessSelector):
    """Select rows whose extendedness is strictly above a minimum."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with
        ``extendedness > extendedness_minimum``.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            Result of the element-wise comparison.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)

372 

373 

class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness equals 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``extendedness == 9``.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            Result of the element-wise comparison.
        """
        # NOTE(review): 9 is presumably the sentinel for an undetermined
        # classification - confirm against the catalog schema.
        values = super().__call__(data, **kwargs)
        unknown = values == 9
        return unknown

378 

379 

class VectorSelector(VectorAction):
    """Fetch an existing boolean vector from KeyedData so that it can be
    used directly as a selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        # A single column is required; its name is reported unformatted
        required = (self.vectorKey, Vector)
        return (required,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under ``vectorKey`` formatted with
        ``kwargs``.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name.
        **kwargs
            Values substituted into the ``vectorKey`` template.

        Returns
        -------
        result : `Vector`
            The looked-up vector, unmodified.
        """
        column = self.vectorKey.format(**kwargs)
        return cast(Vector, data[column])

392 

393 

class ThresholdSelector(VectorAction):
    """Compare a column against a threshold and return the result as a mask."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        required = (self.vectorKey, Vector)
        return (required,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured operator to the column and the threshold.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by column name; ``vectorKey`` is looked up
            as-is (no formatting with ``kwargs``).
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            Whatever ``operator.<op>(column, threshold)`` returns.
        """
        # ``op`` names a function of the standard ``operator`` module
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

407 

408 

409class BandSelector(VectorAction): 

410 """Makes a mask for sources observed in a specified set of bands.""" 

411 

412 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

413 bands = ListField[str]( 

414 doc="The bands to select. `None` indicates no band selection applied.", 

415 default=[], 

416 ) 

417 

418 def getInputSchema(self) -> KeyedDataSchema: 

419 return ((self.vectorKey, Vector),) 

420 

421 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

422 bands: Optional[tuple[str, ...]] 

423 match kwargs: 

424 case {"band": band} if not self.bands and self.bands == []: 

425 bands = (band,) 

426 case {"bands": bands} if not self.bands and self.bands == []: 

427 bands = bands 

428 case _ if self.bands: 

429 bands = tuple(self.bands) 

430 case _: 

431 bands = None 

432 if bands: 

433 mask = np.in1d(data[self.vectorKey], bands) 

434 else: 

435 # No band selection is applied, i.e., select all rows 

436 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

437 return cast(Vector, mask)