Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 29%

196 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-18 02:10 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public API of this module. Every selector class defined in the file is
# listed so that ``from ... import *`` exposes all of them.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)

38 

39import operator 

40from typing import Optional, cast 

41 

42import numpy as np 

43from lsst.pex.config import Field 

44from lsst.pex.config.listField import ListField 

45 

46from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

47 

48 

class FlagSelector(VectorAction):
    """Build a boolean mask from flag columns, to select valid sources for QA.

    A row is kept only if every column listed in ``selectWhenFalse`` compares
    equal to 0 and every column listed in ``selectWhenTrue`` compares equal
    to 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column name, Vector)`` pairs for every configured flag.

        The names are reported exactly as configured, i.e. before any
        ``str.format`` substitution is applied in ``__call__``.
        """
        columnNames = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columnNames)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select on the configured flags.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Values substituted into each configured column name with
            ``str.format`` before the lookup in ``data``.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy all of the flag cuts.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        mask: Optional[Vector] = None
        # Pair each list of column templates with the value its flags must
        # equal, so one loop covers both the "False" and the "True" cuts.
        requirements = ((self.selectWhenFalse, 0), (self.selectWhenTrue, 1))
        for templates, wanted in requirements:
            for template in templates:  # type: ignore
                current = np.array(data[template.format(**kwargs)] == wanted)
                if mask is None:
                    mask = current
                else:
                    mask &= current  # type: ignore

        # The emptiness check above guarantees at least one column was used,
        # so the mask cannot be None here.
        return cast(Vector, mask)

99 

100 

class CoaddPlotFlagSelector(FlagSelector):
    """Apply the `FlagSelector` cuts once per band and AND the results.

    The bands are resolved in this order: a ``band`` keyword argument, then
    a ``bands`` keyword argument (only when the ``bands`` config field is
    empty), then the ``bands`` config field, and finally a single empty
    band name.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector."""
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent flag mask evaluated for each band.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Formatting values for the column names; ``band`` / ``bands``
            also steer which bands are evaluated. ``band`` is overridden
            with the current band for each per-band evaluation.

        Returns
        -------
        result : `Vector`
            The combined mask over all selected bands.
        """
        selected: tuple[str, ...]
        if "band" in kwargs:
            selected = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            selected = kwargs["bands"]
        elif self.bands:
            selected = tuple(self.bands)
        else:
            selected = ("",)

        combined: Optional[Vector] = None
        for name in selected:
            perBand = super().__call__(data, **(kwargs | dict(band=name)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default flag columns used by this selector."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

138 

139 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent flag selector's mask.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is only a single evaluation here (no per-band loop), so the
        # parent result is returned directly; no accumulation is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns used by this selector."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

165 

166 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    column = Field[str](doc="Column to select from")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which NumPy 2.0
    # removed.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default minimum is the float adjacent to -inf in the direction of zero.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single ``(column, Vector)`` entry this action reads."""
        yield self.column, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which ``column`` is read.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.column])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)

193 

194 

class SnSelector(VectorAction):
    """Selects points that have threshold < S/N < maxSN in the given flux
    type, in every selected band.
    """

    fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
    threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
    maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
    uncertaintySuffix = Field[str](
        doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
    )
    bands = ListField[str](
        doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the flux column and its uncertainty column, as configured
        (i.e. before any ``str.format`` substitution).
        """
        yield (fluxCol := self.fluxType), Vector
        yield f"{fluxCol}{self.uncertaintySuffix}", Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Make a mask of objects whose S/N lies strictly between
        ``threshold`` and ``maxSN`` in ``fluxType``.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flux and uncertainty columns are read.
        **kwargs
            Formatting values for the column names. A ``band`` entry selects
            a single band; a ``bands`` entry is used only when the ``bands``
            config field is empty.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the S/N cut in every
            selected band.
        """
        bands: tuple[str, ...]
        if "band" in kwargs:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = ("",)

        mask: Optional[Vector] = None
        for band in bands:
            # Use the same substitutions (including the per-band override)
            # for both the flux column and the uncertainty suffix, so the
            # two names always refer to the same band.
            formatArgs = kwargs | dict(band=band)
            fluxCol = self.fluxType.format(**formatArgs)
            errCol = f"{fluxCol}{self.uncertaintySuffix.format(**formatArgs)}"
            vec = cast(Vector, data[fluxCol]) / data[errCol]
            temp = (vec > self.threshold) & (vec < self.maxSN)
            if mask is not None:
                mask &= temp  # type: ignore
            else:
                mask = temp

        # It should not be possible for mask to be a None now
        return np.array(cast(Vector, mask))

248 

249 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)"""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector."""
        for entry in super().getInputSchema():
            yield entry

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the parent flag mask evaluated for each band.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Formatting values for the column names. A ``band`` entry selects
            a single band; a ``bands`` entry is used only when the ``bands``
            config field is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all selected bands.
        """
        chosen: tuple[str, ...]
        if "band" in kwargs:
            chosen = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            chosen = kwargs["bands"]
        elif self.bands:
            chosen = tuple(self.bands)
        else:
            chosen = ("",)

        accumulated: Optional[Vector] = None
        for bandName in chosen:
            bandMask = super().__call__(data, **(kwargs | dict(band=bandName)))
            if accumulated is None:
                accumulated = bandMask
            else:
                accumulated &= bandMask  # type: ignore
        return cast(Vector, accumulated)

    def setDefaults(self):
        """Set the default flag columns used by this selector."""
        self.selectWhenFalse = ["{band}_pixelFlags_edge"]
        self.selectWhenTrue = ["sky_object"]

286 

287 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the schema entries of the parent flag selector."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent flag selector's mask.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the flag columns are read.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # A single evaluation with nothing to combine, so the parent result
        # is returned as-is.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default flag columns used by this selector."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

308 

309 

class ExtendednessSelector(VectorAction):
    """Load the vector named by ``vectorKey`` and return it unchanged.

    Subclasses apply their own comparison to the returned values.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return ``data`` at ``vectorKey`` after ``str.format`` substitution
        of ``kwargs`` into the key.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        column = data[resolvedKey]
        return cast(Vector, column)

321 

322 

class StarSelector(ExtendednessSelector):
    """Select rows with ``0 <= extendedness < extendedness_maximum``."""

    # NOTE(review): the field doc describes the maximum as inclusive, while
    # the comparison in ``__call__`` is strict (<) -- confirm which is
    # intended.
    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness is non-negative and
        below ``extendedness_maximum``.
        """
        values = super().__call__(data, **kwargs)
        nonNegative = values >= 0
        belowMaximum = values < self.extendedness_maximum
        return np.array(cast(Vector, nonNegative & belowMaximum))

331 

332 

class GalaxySelector(ExtendednessSelector):
    """Select rows with extendedness strictly above ``extendedness_minimum``."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the result of ``extendedness > extendedness_minimum``."""
        values = super().__call__(data, **kwargs)
        aboveMinimum = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, aboveMinimum)

341 

342 

class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value equals 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the result of comparing the extendedness vector to 9."""
        # NOTE(review): 9 is presumably the value assigned to unclassified
        # objects upstream -- confirm against the catalog definition.
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown

347 

348 

class VectorSelector(VectorAction):
    """Return a vector stored in KeyedData so it can serve directly as a
    selection mask.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return ``data`` at ``vectorKey`` after ``str.format`` substitution
        of ``kwargs`` into the key.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])

361 

362 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads."""
        entry = (self.vectorKey, Vector)
        return (entry,)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured comparison to the column.

        ``op`` is looked up by name in the standard ``operator`` module and
        called as ``op(column, threshold)``. The column key is used exactly
        as configured; ``kwargs`` are not substituted into it.
        """
        compare = getattr(operator, self.op)
        result = compare(data[self.vectorKey], self.threshold)
        return cast(Vector, result)

376 

377 

class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands."""

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single ``(vectorKey, Vector)`` entry this action reads."""
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose band value is one of the selected
        bands.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from which the ``vectorKey`` column is read.
        **kwargs
            A ``band`` entry selects a single band; a ``bands`` entry is
            used only when the ``bands`` config field is empty.

        Returns
        -------
        result : `Vector`
            Boolean mask; all True when no band selection is in effect.
        """
        bands: Optional[tuple[str, ...]]
        if "band" in kwargs:
            bands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = None

        values = data[self.vectorKey]
        if bands:
            # ``np.isin`` replaces the deprecated ``np.in1d``. The input is
            # flattened first so the result is 1-D, exactly as ``in1d``
            # produced.
            mask = np.isin(np.asarray(values).ravel(), bands)
        else:
            # No band selection is applied, i.e., select all rows
            mask = np.full(len(values), True)  # type: ignore
        return cast(Vector, mask)