Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%

224 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-30 14:27 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ( 

24 "FlagSelector", 

25 "CoaddPlotFlagSelector", 

26 "RangeSelector", 

27 "SnSelector", 

28 "ExtendednessSelector", 

29 "SkyObjectSelector", 

30 "SkySourceSelector", 

31 "GoodDiaSourceSelector", 

32 "StarSelector", 

33 "GalaxySelector", 

34 "UnknownSelector", 

35 "VectorSelector", 

36 "VisitPlotFlagSelector", 

37 "ThresholdSelector", 

38 "BandSelector", 

39) 

40 

41import operator 

42from typing import Optional, cast 

43 

44import numpy as np 

45from lsst.pex.config import Field 

46from lsst.pex.config.listField import ListField 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

49 

50 

class SelectorBase(VectorAction):
    """Base class for selectors that can record a value in the plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Store ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Nothing happens when no ``plotInfo`` entry was supplied or when
        ``plotLabelKey`` is empty/unset.

        Parameters
        ----------
        value
            The value to record.
        **kwargs
            Keyword arguments of the calling action; only ``plotInfo`` is
            inspected (and mutated in place).
        """
        if "plotInfo" not in kwargs or not self.plotLabelKey:
            return
        plotInfo = kwargs["plotInfo"]
        plotInfo[self.plotLabelKey] = value

59 

60 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return ``(column, Vector)`` pairs for every configured flag."""
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Used to fill format placeholders in the column names.

        Returns
        -------
        result : `Vector`
            A mask that is True where every ``selectWhenFalse`` column
            equals 0 and every ``selectWhenTrue`` column equals 1.

        Raises
        ------
        RuntimeError
            Raised if both flag lists are empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column with the value it must have; the "False"
        # columns are processed before the "True" columns.
        requirements = [(name, 0) for name in self.selectWhenFalse]
        requirements += [(name, 1) for name in self.selectWhenTrue]

        selection: Optional[Vector] = None
        for name, required in requirements:
            current = np.array(data[name.format(**kwargs)] == required)
            if selection is None:
                selection = current
            else:
                selection &= current  # type: ignore

        # The emptiness check above guarantees at least one iteration.
        return cast(Vector, selection)

115 

116 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector with per-band column names.

    When ``bands`` is left empty the band(s) are taken from the call's
    keyword arguments.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag lists to their ``_target``-suffixed columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag selection in every relevant band and AND the masks.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are only used when the
            configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all bands.
        """
        configuredEmpty = not self.bands and self.bands == []
        bandNames: tuple[str, ...]
        if "band" in kwargs and configuredEmpty:
            bandNames = (kwargs["band"],)
        elif "bands" in kwargs and configuredEmpty:
            bandNames = kwargs["bands"]
        elif self.bands:
            bandNames = tuple(self.bands)
        else:
            # No band information at all: format with an empty band name.
            bandNames = ("",)

        combined: Optional[Vector] = None
        for name in bandNames:
            perBand = super().__call__(data, **(kwargs | dict(band=name)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default per-band flag columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

167 

168 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): this field is not read by any method of this class.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch ``selectWhenFalse`` to its ``_target``-suffixed columns."""
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the flag cuts.
        """
        # There is only a single mask to compute, so no accumulation with a
        # previous result is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default visit-level flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

204 

205 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # ``np.inf`` is used rather than the ``np.Inf`` alias, which no longer
    # exists in NumPy 2.0 and later.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Smallest finite float, so the default lower bound accepts every finite
    # value.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the ``vectorKey`` column.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)

232 

233 

234class SnSelector(SelectorBase): 

235 """Selects points that have S/N > threshold in the given flux type.""" 

236 

237 fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux") 

238 threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0) 

239 maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6) 

240 uncertaintySuffix = Field[str]( 

241 doc="Suffix to add to fluxType to specify uncertainty column", default="Err" 

242 ) 

243 bands = ListField[str]( 

244 doc="The bands to apply the signal to noise cut in." "Takes precedence if bands passed to call", 

245 default=[], 

246 ) 

247 

248 def getInputSchema(self) -> KeyedDataSchema: 

249 fluxCol = self.fluxType 

250 fluxInd = fluxCol.find("lux") + len("lux") 

251 errCol = f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix}" + f"{fluxCol}"[fluxInd:] 

252 yield fluxCol, Vector 

253 yield errCol, Vector 

254 

255 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

256 """Makes a mask of objects that have S/N greater than 

257 self.threshold in self.fluxType 

258 

259 Parameters 

260 ---------- 

261 data : `KeyedData` 

262 The data to perform the selection on. 

263 

264 Returns 

265 ------- 

266 result : `Vector` 

267 A mask of the objects that satisfy the given 

268 S/N cut. 

269 """ 

270 

271 self._addValueToPlotInfo(self.threshold, **kwargs) 

272 mask: Optional[Vector] = None 

273 bands: tuple[str, ...] 

274 match kwargs: 

275 case {"band": band} if not self.bands and self.bands == []: 

276 bands = (band,) 

277 case {"bands": bands} if not self.bands and self.bands == []: 

278 bands = bands 

279 case _ if self.bands: 

280 bands = tuple(self.bands) 

281 case _: 

282 bands = ("",) 

283 for band in bands: 

284 fluxCol = self.fluxType.format(**(kwargs | dict(band=band))) 

285 fluxInd = fluxCol.find("lux") + len("lux") 

286 errCol = ( 

287 f"{fluxCol}"[:fluxInd] + f"{self.uncertaintySuffix.format(**kwargs)}" + f"{fluxCol}"[fluxInd:] 

288 ) 

289 vec = cast(Vector, data[fluxCol]) / data[errCol] 

290 temp = (vec > self.threshold) & (vec < self.maxSN) 

291 if mask is not None: 

292 mask &= temp # type: ignore 

293 else: 

294 mask = temp 

295 

296 # It should not be possible for mask to be a None now 

297 return np.array(cast(Vector, mask)) 

298 

299 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag selection in every relevant band and AND the masks.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            May carry ``band`` or ``bands``; these are only used when the
            configured ``bands`` list is empty.

        Returns
        -------
        result : `Vector`
            The combined mask over all bands.
        """
        configuredEmpty = not self.bands and self.bands == []
        bandNames: tuple[str, ...]
        if "band" in kwargs and configuredEmpty:
            bandNames = (kwargs["band"],)
        elif "bands" in kwargs and configuredEmpty:
            bandNames = kwargs["bands"]
        elif self.bands:
            bandNames = tuple(self.bands)
        else:
            # No band information at all: format with an empty band name.
            bandNames = ("",)

        combined: Optional[Vector] = None
        for name in bandNames:
            perBand = super().__call__(data, **(kwargs | dict(band=name)))
            if combined is None:
                combined = perBand
            else:
                combined &= perBand  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default sky-object flag columns."""
        self.selectWhenFalse = [
            "{band}_pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_object"]

336 

337 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the flag cuts.
        """
        # There is only a single mask to compute, so no accumulation with a
        # previous result is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default sky-source flag columns."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

358 

359 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask produced by the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Forwarded unchanged to `FlagSelector.__call__`.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the flag cuts.
        """
        # There is only a single mask to compute, so no accumulation with a
        # previous result is needed.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel-flag columns that must all be False."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

384 

385 

class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources."""

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (unformatted) extendedness column as the only input."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness column itself.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Used to fill format placeholders in ``vectorKey``.

        Returns
        -------
        result : `Vector`
            The column found under the formatted ``vectorKey``.
        """
        column = data[self.vectorKey.format(**kwargs)]
        return cast(Vector, column)

399 

400 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``0 <= extendedness <= maximum``.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Used to fill format placeholders in ``vectorKey``.

        Returns
        -------
        result : `Vector`
            A boolean array, True for rows classified as unresolved.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as documented on
        # ``extendedness_maximum``.
        isStar = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, isStar))

413 

414 

class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with ``extendedness > minimum``.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Used to fill format placeholders in ``vectorKey``.

        Returns
        -------
        result : `Vector`
            True for rows whose extendedness exceeds
            ``extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        isResolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, isResolved)

427 

428 

class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Used to fill format placeholders in ``vectorKey``.

        Returns
        -------
        result : `Vector`
            True for rows whose extendedness value is exactly 9.
        """
        values = super().__call__(data, **kwargs)
        isUnknown = values == 9
        return isUnknown

437 

438 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (unformatted) mask column as the only input."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the column stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the mask column.
        **kwargs
            Used to fill format placeholders in ``vectorKey``.

        Returns
        -------
        result : `Vector`
            The stored vector, returned as-is.
        """
        key = self.vectorKey.format(**kwargs)
        return cast(Vector, data[key])

451 

452 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the threshold column as the only input."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the column against the threshold.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the ``vectorKey`` column.
        **kwargs
            Unused.

        Returns
        -------
        result : `Vector`
            The result of ``operator.<op>(column, threshold)``.
        """
        # ``op`` names a function in the standard ``operator`` module.
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

466 

467 

468class BandSelector(VectorAction): 

469 """Makes a mask for sources observed in a specified set of bands.""" 

470 

471 vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band") 

472 bands = ListField[str]( 

473 doc="The bands to select. `None` indicates no band selection applied.", 

474 default=[], 

475 ) 

476 

477 def getInputSchema(self) -> KeyedDataSchema: 

478 return ((self.vectorKey, Vector),) 

479 

480 def __call__(self, data: KeyedData, **kwargs) -> Vector: 

481 bands: Optional[tuple[str, ...]] 

482 match kwargs: 

483 case {"band": band} if not self.bands and self.bands == []: 

484 bands = (band,) 

485 case {"bands": bands} if not self.bands and self.bands == []: 

486 bands = bands 

487 case _ if self.bands: 

488 bands = tuple(self.bands) 

489 case _: 

490 bands = None 

491 if bands: 

492 mask = np.in1d(data[self.vectorKey], bands) 

493 else: 

494 # No band selection is applied, i.e., select all rows 

495 mask = np.full(len(data[self.vectorKey]), True) # type: ignore 

496 return cast(Vector, mask)