Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 34%

157 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-04 03:18 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public selectors exported by this module. Every public selector class
# defined below is listed here, including SkySourceSelector.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
)

35 

36from typing import Optional, cast 

37 

38import numpy as np 

39from lsst.pex.config import Field 

40from lsst.pex.config.listField import ListField 

41 

42from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

43 

44 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is kept only if every column named in ``selectWhenFalse`` equals
    0 and every column named in ``selectWhenTrue`` equals 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # Column names are reported unformatted: any ``{...}`` placeholders
        # are only resolved at call time from the keyword arguments.
        flagColumns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((column, Vector) for column in flagColumns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select on the given flags.

        Parameters
        ----------
        data : `KeyedData`
            Data indexed by column name.
        **kwargs
            Values substituted into each column name with `str.format`
            (for example ``band``).

        Returns
        -------
        result : `Vector`
            A boolean mask of the objects that satisfy all of the given
            flag cuts.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue``
            names any column.

        Notes
        -----
        Columns in ``selectWhenFalse`` must equal 0 and columns in
        ``selectWhenTrue`` must equal 1 for a row to be selected.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        mask: Optional[Vector] = None
        # Each group of columns is paired with the value that keeps a row;
        # the "False" group is evaluated first, then the "True" group.
        for columns, keepValue in ((self.selectWhenFalse, 0), (self.selectWhenTrue, 1)):
            for column in columns:  # type: ignore
                passes = np.array(data[column.format(**kwargs)] == keepValue)
                if mask is None:
                    mask = passes
                else:
                    mask &= passes  # type: ignore
        # The guard above ensures at least one column was processed, so the
        # mask is never None here.
        return cast(Vector, mask)

95 

96 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots, applied in one or more bands.

    The `FlagSelector` cuts are evaluated once per band, with the band
    substituted into the ``{band}`` placeholder of the column names. The
    per-band masks are combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Band resolution order:
        #   1. an explicit ``band`` keyword argument;
        #   2. a ``bands`` keyword argument, only if no bands are configured;
        #   3. the configured ``bands``;
        #   4. a single empty band name.
        chosenBands: tuple[str, ...]
        if "band" in kwargs:
            chosenBands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            chosenBands = kwargs["bands"]
        elif self.bands:
            chosenBands = tuple(self.bands)
        else:
            chosenBands = ("",)

        combined: Optional[Vector] = None
        for bandName in chosenBands:
            bandMask = super().__call__(data, **{**kwargs, "band": bandName})
            if combined is None:
                combined = bandMask
            else:
                combined &= bandMask  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        # Per-band and band-independent flags that must all be unset.
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        # Flags that must all be set.
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

134 

135 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        The default flag columns carry no ``{band}`` placeholder, so no
        per-band loop is needed. The base-class result is returned as is.
        The previous accumulator with an always-false ``result is not None``
        test was dead code and has been removed.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # Flags that must all be unset for a source to be selected.
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

161 

162 

163class SnSelector(VectorAction): 

164 """Selects points that have S/N > threshold in the given flux type""" 

165 

166 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

167 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

168 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

169 uncertaintySuffix = Field[str]( 

170 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

171 ) 

172 bands = ListField[str]( 

173 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

174 default=[], 

175 ) 

176 

177 def getInputSchema(self) -> KeyedDataSchema: 

178 yield (fluxCol := self.fluxType), Vector 

179 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

180 

181 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

182 """Makes a mask of objects that have S/N greater than 

183 self.threshold in self.fluxType 

184 Parameters 

185 ---------- 

186 df : `Tabular` 

187 Returns 

188 ------- 

189 result : `Vector` 

190 A mask of the objects that satisfy the given 

191 S/N cut. 

192 """ 

193 mask: Optional[Vector] = None 

194 bands: tuple[str, ...] 

195 match kwargs: 

196 case {"band": band}: 

197 bands = (band,) 

198 case {"bands": bands} if not self.bands: 

199 bands = bands 

200 case _ if self.bands: 

201 bands = tuple(self.bands) 

202 case _: 

203 bands = ("",) 

204 for band in bands: 

205 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

206 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

207 vec = cast(Vector, data[fluxCol]) / data[errCol] 

208 temp = (vec > self.threshold) & (vec < self.maxSN) 

209 if mask is not None: 

210 mask &= temp # type: ignore 

211 else: 

212 mask = temp 

213 

214 # It should not be possible for mask to be a None now 

215 return np.array(cast(Vector, mask)) 

216 

217 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s).

    The `FlagSelector` cuts are evaluated once per resolved band, and the
    per-band masks are combined with a logical AND.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Band resolution order: explicit ``band`` kwarg, then a ``bands``
        # kwarg (only when no bands are configured), then the configured
        # bands, then a single empty band name.
        selectedBands: tuple[str, ...]
        if "band" in kwargs:
            selectedBands = (kwargs["band"],)
        elif "bands" in kwargs and not self.bands:
            selectedBands = kwargs["bands"]
        elif self.bands:
            selectedBands = tuple(self.bands)
        else:
            selectedBands = ("",)

        accumulated: Optional[Vector] = None
        for currentBand in selectedBands:
            perBand = super().__call__(data, **{**kwargs, "band": currentBand})
            if accumulated is None:
                accumulated = perBand
            else:
                accumulated &= perBand  # type: ignore
        return cast(Vector, accumulated)

    def setDefaults(self):
        # Reject objects on the image edge in the given band...
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        # ...and keep only rows flagged as sky objects.
        self.selectWhenTrue = ["sky_object"]

254 

255 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables"""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the `FlagSelector` mask for the configured flags.

        The default columns carry no ``{band}`` placeholder, so the
        base-class result is returned directly. The previous accumulator
        branch (``result`` was always ``None``) was unreachable and has
        been removed.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        # Reject sources on the image edge...
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        # ...and keep only rows flagged as sky sources.
        self.selectWhenTrue = ["sky_source"]

276 

277 

class ExtendednessSelector(VectorAction):
    """Return the extendedness vector named by ``vectorKey``.

    Subclasses apply cuts to the returned vector.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        # The key is reported unformatted; placeholders are filled at call time.
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        return cast(Vector, data[self.vectorKey.format(**kwargs)])

289 

290 

class StarSelector(ExtendednessSelector):
    """Select unresolved objects.

    An object is selected when ``0 <= extendedness <= extendedness_maximum``.
    The upper bound is inclusive, as the field documentation states. It is
    therefore the exact complement of `GalaxySelector`'s exclusive lower
    bound at the same threshold.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # ``>= 0`` excludes negative values and NaN (NaN comparisons are False).
        # ``<=`` honours the documented inclusive maximum; the previous strict
        # ``<`` left objects exactly at the threshold unclassified.
        return np.array(cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum)))

299 

300 

class GalaxySelector(ExtendednessSelector):
    """Select resolved objects: extendedness strictly greater than
    ``extendedness_minimum``.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        # NOTE(review): an extendedness of 9 (what UnknownSelector matches)
        # also passes this cut - confirm that is intended.
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

309 

310 

class UnknownSelector(ExtendednessSelector):
    """Select objects whose extendedness equals exactly 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        values = super().__call__(data, **kwargs)
        # NOTE(review): 9 is presumably an "unclassified" sentinel value -
        # confirm against the producer of the extendedness column.
        isUnknown = values == 9
        return isUnknown

315 

316 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # Fill any format placeholders in the key from the call kwargs.
        resolvedKey = self.vectorKey.format(**kwargs)
        return cast(Vector, data[resolvedKey])