Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 29%

196 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-17 02:43 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
)

38 

39import operator 

40from typing import Optional, cast 

41 

42import numpy as np 

43from lsst.pex.config import Field 

44from lsst.pex.config.listField import ListField 

45 

46from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

47 

48 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA.

    A row is selected only if every column in ``selectWhenFalse`` equals 0
    and every column in ``selectWhenTrue`` equals 1.
    """

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        allCols = list(self.selectWhenFalse) + list(self.selectWhenTrue)
        return ((col, Vector) for col in allCols)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Select on the given flags.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the flag columns.
        **kwargs
            Values used to fill ``str.format`` placeholders (for example
            ``{band}``) in the configured column names.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy the given flag cuts.

        Raises
        ------
        RuntimeError
            Raised if neither ``selectWhenFalse`` nor ``selectWhenTrue``
            names any columns.

        Notes
        -----
        Uses the columns in selectWhenFalse and selectWhenTrue to decide
        which columns to select on in each circumstance.
        """
        if not self.selectWhenFalse and not self.selectWhenTrue:
            raise RuntimeError("No column keys specified")

        # Pair each flag column with the value it must have for a row to be
        # kept; the False-columns are evaluated first, as before.
        requirements = [(flag, 0) for flag in self.selectWhenFalse]  # type: ignore
        requirements += [(flag, 1) for flag in self.selectWhenTrue]  # type: ignore

        results: Optional[Vector] = None
        for flag, wanted in requirements:
            temp = np.array(data[flag.format(**kwargs)] == wanted)
            if results is not None:
                results &= temp  # type: ignore
            else:
                results = temp
        # The check at the beginning assures this can never be None
        return cast(Vector, results)

99 

100 

101class CoaddPlotFlagSelector(FlagSelector): 

102 """This default setting makes it take the band from 

103 the kwargs.""" 

104 

105 bands = ListField[str]( 

106 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs", 

107 default=[], 

108 ) 

109 

110 def getInputSchema(self) -> KeyedDataSchema: 

111 yield from super().getInputSchema() 

112 

113 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

114 result: Optional[Vector] = None 

115 bands: tuple[str, ...] 

116 match kwargs: 

117 case {"band": band} if not self.bands and self.bands == []: 

118 bands = (band,) 

119 case {"bands": bands} if not self.bands and self.bands == []: 

120 bands = bands 

121 case _ if self.bands: 

122 bands = tuple(self.bands) 

123 case _: 

124 bands = ("",) 

125 for band in bands: 

126 temp = super().__call__(data, **(kwargs | dict(band=band))) 

127 if result is not None: 

128 result &= temp # type: ignore 

129 else: 

130 result = temp 

131 return cast(Vector, result) 

132 

133 def setDefaults(self): 

134 self.selectWhenFalse = [ 

135 "{band}_psfFlux_flag", 

136 "{band}_pixelFlags_saturatedCenter", 

137 "{band}_extendedness_flag", 

138 "xy_flag", 

139 ] 

140 self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"] 

141 

142 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask of rows passing the visit-level flag cuts.

        The default column names carry no band prefix, so the base-class
        selection is applied once, directly.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

168 

169 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    key = Field[str](doc="Key to select from data")
    # ``np.inf`` rather than the ``np.Inf`` alias, which NumPy 2.0 removed.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # The most negative finite float, so that -inf values are not selected
    # by default even though the minimum is inclusive.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data containing the ``key`` column.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.key])
        mask = (values >= self.minimum) & (values < self.maximum)

        return np.array(mask)

196 

197 

198class SnSelector(VectorAction): 

199 """Selects points that have S/N > threshold in the given flux type""" 

200 

201 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

202 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

203 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

204 uncertaintySuffix = Field[str]( 

205 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

206 ) 

207 bands = ListField[str]( 

208 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

209 default=[], 

210 ) 

211 

212 def getInputSchema(self) -> KeyedDataSchema: 

213 yield (fluxCol := self.fluxType), Vector 

214 yield f"{fluxCol}{self.uncertaintySuffix}", Vector 

215 

216 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

217 """Makes a mask of objects that have S/N greater than 

218 self.threshold in self.fluxType 

219 Parameters 

220 ---------- 

221 data : `KeyedData` 

222 Returns 

223 ------- 

224 result : `Vector` 

225 A mask of the objects that satisfy the given 

226 S/N cut. 

227 """ 

228 mask: Optional[Vector] = None 

229 bands: tuple[str, ...] 

230 match kwargs: 

231 case {"band": band} if not self.bands and self.bands == []: 

232 bands = (band,) 

233 case {"bands": bands} if not self.bands and self.bands == []: 

234 bands = bands 

235 case _ if self.bands: 

236 bands = tuple(self.bands) 

237 case _: 

238 bands = ("",) 

239 for band in bands: 

240 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

241 errCol = f"{fluxCol}{self.uncertaintySuffix.format(**kwargs)}" 

242 vec = cast(Vector, data[fluxCol]) / data[errCol] 

243 temp = (vec > self.threshold) & (vec < self.maxSN) 

244 if mask is not None: 

245 mask &= temp # type: ignore 

246 else: 

247 mask = temp 

248 

249 # It should not be possible for mask to be a None now 

250 return np.array(cast(Vector, mask)) 

251 

252 

253class SkyObjectSelector(FlagSelector): 

254 """Selects sky objects in the given band(s)""" 

255 

256 bands = ListField[str]( 

257 doc="The bands to apply the flags in, takes precedence if band supplied in kwargs", 

258 default=["i"], 

259 ) 

260 

261 def getInputSchema(self) -> KeyedDataSchema: 

262 yield from super().getInputSchema() 

263 

264 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

265 result: Optional[Vector] = None 

266 bands: tuple[str, ...] 

267 match kwargs: 

268 case {"band": band} if not self.bands and self.bands == []: 

269 bands = (band,) 

270 case {"bands": bands} if not self.bands and self.bands == []: 

271 bands = bands 

272 case _ if self.bands: 

273 bands = tuple(self.bands) 

274 case _: 

275 bands = ("",) 

276 for band in bands: 

277 temp = super().__call__(data, **(kwargs | dict(band=band))) 

278 if result is not None: 

279 result &= temp # type: ignore 

280 else: 

281 result = temp 

282 return cast(Vector, result) 

283 

284 def setDefaults(self): 

285 self.selectWhenFalse = [ 

286 "{band}_pixelFlags_edge", 

287 ] 

288 self.selectWhenTrue = ["sky_object"] 

289 

290 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask of rows passing the sky-source flag cuts.

        The default column names carry no band prefix, so the base-class
        selection is applied once, directly.
        """
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        super().setDefaults()
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

311 

312 

class ExtendednessSelector(VectorAction):
    """Return the extendedness vector named by ``vectorKey``.

    ``vectorKey`` is a format string filled in from the call's keyword
    arguments (by default it needs ``band``). Subclasses apply cuts to the
    returned vector.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        columnName = self.vectorKey.format(**kwargs)
        extendedness = data[columnName]
        return cast(Vector, extendedness)

324 

325 

class StarSelector(ExtendednessSelector):
    """Select unresolved sources (stars) by their extendedness.

    A row is selected when ``0 <= extendedness <= extendedness_maximum``.
    Negative and NaN extendedness values are never selected.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        extendedness = super().__call__(data, **kwargs)
        # The maximum is documented as inclusive, so use ``<=``. With the
        # default thresholds this is also the complement of GalaxySelector's
        # exclusive ``>`` cut for non-negative values. Previously a value
        # exactly at the threshold fell into neither class.
        return np.array(
            cast(Vector, (extendedness >= 0) & (extendedness <= self.extendedness_maximum))
        )

334 

335 

class GalaxySelector(ExtendednessSelector):
    """Select resolved sources: extendedness strictly above a minimum."""

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        cutoff = self.extendedness_minimum
        isResolved = super().__call__(data, **kwargs) > cutoff
        return cast(Vector, isResolved)  # type: ignore

344 

345 

class UnknownSelector(ExtendednessSelector):
    """Select rows whose extendedness value is exactly 9."""

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        # NOTE(review): 9 is presumably a sentinel meaning "unclassified";
        # confirm against whatever produces the extendedness column.
        sentinel = 9
        values = super().__call__(data, **kwargs)
        return values == sentinel

350 

351 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.

    ``vectorKey`` is formatted with the call's keyword arguments before the
    lookup.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        resolvedKey = self.vectorKey.format(**kwargs)
        selection = data[resolvedKey]
        return cast(Vector, selection)

364 

365 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold.

    ``op`` names a function in the standard ``operator`` module (for
    example ``"lt"`` or ``"ge"``). It is called as
    ``op(data[vectorKey], threshold)``.
    """

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

379 

380 

381class BandSelector(VectorAction): 

382 """Makes a mask for sources observed in a specified set of bands.""" 

383 

384 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

385 bands = ListField[str]( 

386 doc="The bands to select. `None` indicates no band selection applied.", 

387 default=[], 

388 ) 

389 

390 def getInputSchema(self) -> KeyedDataSchema: 

391 return ((self.vectorKey, Vector),) 

392 

393 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

394 bands: Optional[tuple[str, ...]] 

395 match kwargs: 

396 case {"band": band} if not self.bands and self.bands == []: 

397 bands = (band,) 

398 case {"bands": bands} if not self.bands and self.bands == []: 

399 bands = bands 

400 case _ if self.bands: 

401 bands = tuple(self.bands) 

402 case _: 

403 bands = None 

404 if bands: 

405 mask = np.in1d(data[self.vectorKey], bands) 

406 else: 

407 # No band selection is applied, i.e., select all rows 

408 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

409 return cast(Vector, mask)