Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 35%

186 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-06 10:00 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public names exported by ``from <module> import *``.
# ``SkySourceSelector`` is exported alongside the other selector classes.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)

37 

38import operator 

39from typing import Optional, cast 

40 

41import numpy as np 

42from lsst.pex.config import Field 

43from lsst.pex.config.listField import ListField 

44 

45from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

46 

47 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA"""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Describe the flag columns this selector reads.

        Returns
        -------
        schema : `KeyedDataSchema`
            ``(name, Vector)`` pairs: the ``selectWhenFalse`` names first,
            then the ``selectWhenTrue`` names. Names are emitted exactly as
            configured (no formatting is applied here).
        """
        columnNames = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columnNames)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask of the rows that pass every configured flag cut.

        Parameters
        ----------
        data : `KeyedData`
            Container the flag columns are looked up in.
        **kwargs
            Values substituted into each configured column name with
            `str.format` before the lookup.

        Returns
        -------
        result : `Vector`
            Logical AND over all cuts: each ``selectWhenFalse`` column must
            compare equal to 0 and each ``selectWhenTrue`` column must
            compare equal to 1.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        mask: Optional[Vector] = None
        # Pair each column list with the value its flags must equal.
        cuts = ((self.selectWhenFalse, 0), (self.selectWhenTrue, 1))
        for columns, required in cuts:
            for column in columns:  # type: ignore
                passed = np.array(data[column.format(**kwargs)] == required)
                if mask is None:
                    mask = passed
                else:
                    mask &= passed  # type: ignore

        # The guard at the top guarantees at least one cut was applied.
        return cast(Vector, mask)

98 

99 

class CoaddPlotFlagSelector(FlagSelector):
    """Apply the `FlagSelector` cuts once per band and AND the per-band masks.

    Column names set in ``setDefaults`` may contain a ``{band}`` placeholder,
    which is filled in with each band at call time.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent selector's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag mask evaluated in every chosen band.

        Parameters
        ----------
        data : `KeyedData`
            Container the flag columns are looked up in.
        **kwargs
            Formatting values for the column names. Band choice, in order:
            a ``band`` entry here; a ``bands`` entry here if the configured
            ``bands`` is empty; the configured ``bands``; otherwise a single
            empty-string band.

        Returns
        -------
        result : `Vector`
            Combined mask over all chosen bands.
        """
        chosenBands: tuple[str, ...]
        if "band" in kwargs:
            chosenBands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            chosenBands = kwargs["bands"]
        elif self.bands:
            chosenBands = tuple(self.bands)
        else:
            chosenBands = ("",)

        combined: Optional[Vector] = None
        for bandName in chosenBands:
            # Override (or add) ``band`` so ``{band}`` templates resolve.
            bandMask = super().__call__(data, **(kwargs | dict(band=bandName)))
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Install the default flag column lists for this selector."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

137 

138 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent selector's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent `FlagSelector` mask for ``data``.

        Parameters
        ----------
        data : `KeyedData`
            Container the flag columns are looked up in.
        **kwargs
            Forwarded to the parent call, which uses them to format column
            names.

        Returns
        -------
        result : `Vector`
            Mask of the rows that pass every configured flag cut.
        """
        # There is a single evaluation and nothing to combine it with, so the
        # parent's mask is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Install the default ``selectWhenFalse`` flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

164 

165 

166class SnSelector(VectorAction): 

167 """Selects points that have S/N > threshold in the given flux type""" 

168 

169 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

170 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

171 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

172 uncertaintySuffix = Field[str]( 

173 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

174 ) 

175 bands = ListField[str]( 

176 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

177 default=[], 

178 ) 

179 

180 def getInputSchema(self) -> KeyedDataSchema: 

181 yield (fluxCol := self.fluxType), Vector 

182 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

183 

184 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

185 """Makes a mask of objects that have S/N greater than 

186 self.threshold in self.fluxType 

187 Parameters 

188 ---------- 

189 df : `Tabular` 

190 Returns 

191 ------- 

192 result : `Vector` 

193 A mask of the objects that satisfy the given 

194 S/N cut. 

195 """ 

196 mask: Optional[Vector] = None 

197 bands: tuple[str, ...] 

198 match kwargs: 

199 case {"band": band}: 

200 bands = (band,) 

201 case {"bands": bands} if not self.bands: 

202 bands = bands 

203 case _ if self.bands: 

204 bands = tuple(self.bands) 

205 case _: 

206 bands = ("",) 

207 for band in bands: 

208 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

209 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

210 vec = cast(Vector, data[fluxCol]) / data[errCol] 

211 temp = (vec > self.threshold) & (vec < self.maxSN) 

212 if mask is not None: 

213 mask &= temp # type: ignore 

214 else: 

215 mask = temp 

216 

217 # It should not be possible for mask to be a None now 

218 return np.array(cast(Vector, mask)) 

219 

220 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)"""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent selector's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the AND of the flag mask evaluated in every chosen band.

        Parameters
        ----------
        data : `KeyedData`
            Container the flag columns are looked up in.
        **kwargs
            Formatting values for the column names. Band choice, in order:
            a ``band`` entry here; a ``bands`` entry here if the configured
            ``bands`` is empty; the configured ``bands``; otherwise a single
            empty-string band.

        Returns
        -------
        result : `Vector`
            Combined mask over all chosen bands.
        """
        bandList: tuple[str, ...]
        if "band" in kwargs:
            bandList = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            bandList = kwargs["bands"]
        elif self.bands:
            bandList = tuple(self.bands)
        else:
            bandList = ("",)

        accumulated: Optional[Vector] = None
        for currentBand in bandList:
            perBand = super().__call__(data, **(kwargs | dict(band=currentBand)))
            if accumulated is None:
                accumulated = perBand
                continue
            accumulated &= perBand  # type: ignore
        return cast(Vector, accumulated)

    def setDefaults(self):
        """Install the default flag column lists for this selector."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]

257 

258 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the parent selector's schema entries unchanged."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the parent `FlagSelector` mask for ``data``.

        Parameters
        ----------
        data : `KeyedData`
            Container the flag columns are looked up in.
        **kwargs
            Forwarded to the parent call, which uses them to format column
            names.

        Returns
        -------
        result : `Vector`
            Mask of the rows that pass every configured flag cut.
        """
        # There is a single evaluation and nothing to combine it with, so the
        # parent's mask is returned directly.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Install the default flag column lists for this selector."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

279 

280 

class ExtendednessSelector(VectorAction):
    """Fetch the vector named by ``vectorKey`` from the input data.

    The value is returned as stored; no comparison is applied here.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-entry schema naming the configured (unformatted) key."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Look up the vector stored under the formatted key.

        Parameters
        ----------
        data : `KeyedData`
            Container the vector is looked up in.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.

        Returns
        -------
        result : `Vector`
            The value stored under the formatted key.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        extendedness = data[resolvedKey]
        return cast(Vector, extendedness)

292 

293 

class StarSelector(ExtendednessSelector):
    """Select rows whose extendedness lies in ``[0, extendedness_maximum]``."""

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness <= maximum``.

        Parameters
        ----------
        data : `KeyedData`
            Container the extendedness vector is looked up in.
        **kwargs
            Values used by the parent call to format the vector key.

        Returns
        -------
        result : `Vector`
            Boolean array, True where the extendedness is non-negative and
            does not exceed ``extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on the config field.
        isUnresolved = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isUnresolved))

302 

303 

class GalaxySelector(ExtendednessSelector):
    """Select rows whose extendedness is strictly above a minimum."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the result of ``extendedness > extendedness_minimum``.

        Parameters
        ----------
        data : `KeyedData`
            Container the extendedness vector is looked up in.
        **kwargs
            Values used by the parent call to format the vector key.

        Returns
        -------
        result : `Vector`
            Comparison result, True where the extendedness exceeds the
            configured minimum.
        """
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

312 

313 

class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value equals 9.

    NOTE(review): 9 is presumably a sentinel for "not classified" -- confirm
    against the catalog schema.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the result of comparing the extendedness vector to 9.

        Parameters
        ----------
        data : `KeyedData`
            Container the extendedness vector is looked up in.
        **kwargs
            Values used by the parent call to format the vector key.
        """
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown

318 

319 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-entry schema naming the configured (unformatted) key."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            Container the vector is looked up in.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`.
        """
        resolvedKey = self.vectorKey.format(**kwargs)
        selection = data[resolvedKey]
        return cast(Vector, selection)

332 

333 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return a one-entry schema naming the configured (unformatted) key."""
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the configured operator to the column and the threshold.

        Parameters
        ----------
        data : `KeyedData`
            Container the column is looked up in.
        **kwargs
            Values substituted into ``vectorKey`` with `str.format`, so a
            templated key such as ``"{band}_psfFlux"`` resolves the same way
            it does in the other selectors of this module.

        Returns
        -------
        result : `Vector`
            The value of ``operator.<op>(column, threshold)``.

        Raises
        ------
        AttributeError
            Raised if ``op`` does not name an attribute of the `operator`
            module.
        """
        compare = getattr(operator, self.op)
        column = data[self.vectorKey.format(**kwargs)]
        mask = compare(column, self.threshold)
        return cast(Vector, mask)

347 

348 

349class BandSelector(VectorAction): 

350 """Makes a mask for sources observed in a specified set of bands.""" 

351 

352 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

353 bands = ListField[str]( 

354 doc="The bands to select. `None` indicates no band selection applied.", 

355 default=[], 

356 ) 

357 

358 def getInputSchema(self) -> KeyedDataSchema: 

359 return ((self.vectorKey, Vector),) 

360 

361 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

362 bands: Optional[tuple[str, ...]] 

363 match kwargs: 

364 case {"band": band}: 

365 bands = (band,) 

366 case {"bands": bands} if not self.bands: 

367 bands = bands 

368 case _ if self.bands: 

369 bands = tuple(self.bands) 

370 case _: 

371 bands = None 

372 if bands: 

373 mask = np.in1d(data[self.vectorKey], bands) 

374 else: 

375 # No band selection is applied, i.e., select all rows 

376 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

377 return cast(Vector, mask)