Coverage for python/lsst/analysis/tools/actions/vector/selectors.py: 30%

229 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-05 14:05 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public API of this module.  ParentObjectSelector is a documented, public
# selector defined in this file, so it is exported alongside the others.
__all__ = (
    "FlagSelector",
    "CoaddPlotFlagSelector",
    "RangeSelector",
    "SnSelector",
    "ExtendednessSelector",
    "SkyObjectSelector",
    "SkySourceSelector",
    "GoodDiaSourceSelector",
    "StarSelector",
    "GalaxySelector",
    "UnknownSelector",
    "VectorSelector",
    "VisitPlotFlagSelector",
    "ThresholdSelector",
    "BandSelector",
    "ParentObjectSelector",
)

40 

41import operator 

42from typing import Optional, cast 

43 

44import numpy as np 

45from lsst.pex.config import Field 

46from lsst.pex.config.listField import ListField 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, Vector, VectorAction 

49from ...math import divide 

50 

51 

class SelectorBase(VectorAction):
    """Base class for selectors that can report a value into the plot info."""

    plotLabelKey = Field[str](
        doc="Key to use when populating plot info, ignored if empty string", optional=True, default=""
    )

    def _addValueToPlotInfo(self, value, **kwargs):
        """Record ``value`` in ``kwargs["plotInfo"]`` under ``plotLabelKey``.

        Parameters
        ----------
        value
            The value to store.
        **kwargs
            Keyword arguments of the calling action; only ``plotInfo`` is
            looked at.

        Notes
        -----
        Nothing is stored when no ``plotInfo`` was passed or when
        ``plotLabelKey`` is empty or unset.
        """
        if "plotInfo" not in kwargs or not self.plotLabelKey:
            return
        kwargs["plotInfo"][self.plotLabelKey] = value

60 

61 

class FlagSelector(VectorAction):
    """The base flag selector to use to select valid sources for QA."""

    selectWhenFalse = ListField[str](
        doc="Names of the flag columns to select on when False", optional=False, default=[]
    )

    selectWhenTrue = ListField[str](
        doc="Names of the flag columns to select on when True", optional=False, default=[]
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return (column name, `Vector`) pairs for every configured flag.

        The names are returned as configured, i.e. any ``{...}`` templates
        are left unformatted.
        """
        columns = [*self.selectWhenFalse, *self.selectWhenTrue]
        return ((name, Vector) for name in columns)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Build a mask from the configured flag columns.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Values used to format the configured column names.

        Returns
        -------
        result : `Vector`
            A mask that is True where every column in ``selectWhenFalse``
            equals 0 and every column in ``selectWhenTrue`` equals 1.

        Raises
        ------
        RuntimeError
            Raised if both ``selectWhenFalse`` and ``selectWhenTrue`` are
            empty.
        """
        if not (self.selectWhenFalse or self.selectWhenTrue):
            raise RuntimeError("No column keys specified")

        # Pair each column with the value it has to equal to be selected;
        # the False columns are processed first, then the True columns.
        requirements = [(name, 0) for name in self.selectWhenFalse]  # type: ignore
        requirements.extend((name, 1) for name in self.selectWhenTrue)

        combined: Optional[Vector] = None
        for name, wanted in requirements:
            current = np.array(data[name.format(**kwargs)] == wanted)
            if combined is None:
                combined = current
            else:
                combined &= current  # type: ignore

        # The guard at the top guarantees at least one column was processed.
        return cast(Vector, combined)

116 

117 

class CoaddPlotFlagSelector(FlagSelector):
    """Flag selector for coadd-level plots, applied once per band.

    With the default (empty) ``bands`` configuration the band is taken from
    the keyword arguments passed to the call.
    """

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the same schema entries as the parent flag selector."""
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the flag columns to their ``_target``-suffixed versions."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag_target",
            "{band}_pixelFlags_saturatedCenter_target",
            "{band}_extendedness_flag_target",
            "xy_flag_target",
        ]
        self.selectWhenTrue = ["detect_isPatchInner_target", "detect_isDeblendedSource_target"]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the flag cuts in every relevant band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            May contain ``band`` or ``bands``; these are only used when the
            configured ``bands`` is an empty list.

        Returns
        -------
        result : `Vector`
            A mask of the rows passing the flag cuts in all of the bands.
        """
        # Only fall back on the call arguments when bands is an empty list.
        deferToKwargs = not self.bands and self.bands == []

        bands: tuple[str, ...]
        if deferToKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format the templates with "".
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            current = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = current
            else:
                combined &= current  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default coadd flag columns."""
        self.selectWhenFalse = [
            "{band}_psfFlux_flag",
            "{band}_pixelFlags_saturatedCenter",
            "{band}_extendedness_flag",
            "xy_flag",
        ]
        self.selectWhenTrue = ["detect_isPatchInner", "detect_isDeblendedSource"]

168 

169 

class VisitPlotFlagSelector(FlagSelector):
    """Select on a set of flags appropriate for making visit-level plots
    (i.e., using sourceTable_visit catalogs).
    """

    # NOTE(review): no method of this class reads catalogSuffix; it is kept
    # because it is part of the configuration interface.
    catalogSuffix = Field[str](doc="The suffix to apply to all the keys.", default="")

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the same schema entries as the parent flag selector."""
        yield from super().getInputSchema()

    def refMatchContext(self):
        """Switch the ``selectWhenFalse`` columns to their ``_target``
        versions.
        """
        self.selectWhenFalse = [
            "psfFlux_flag_target",
            "pixelFlags_saturatedCenter_target",
            "extendedness_flag_target",
            "centroid_flag_target",
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent flag selector.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Values used to format the configured column names.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # There is a single selection pass here, so the parent result is
        # returned directly; the previous accumulate-into-None scaffolding
        # had a branch that could never run.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default visit-level flag columns."""
        self.selectWhenFalse = [
            "psfFlux_flag",
            "pixelFlags_saturatedCenter",
            "extendedness_flag",
            "centroid_flag",
        ]

205 

206 

class RangeSelector(VectorAction):
    """Selects rows within a range, inclusive of min/exclusive of max."""

    vectorKey = Field[str](doc="Key to select from data")
    # np.inf is used rather than the np.Inf alias, which NumPy 2.0 removed.
    maximum = Field[float](doc="The maximum value", default=np.inf)
    # Default minimum is the float adjacent to -inf in the direction of 0.
    minimum = Field[float](doc="The minimum value", default=np.nextafter(-np.inf, 0.0))

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the single (``vectorKey``, `Vector`) schema entry."""
        yield self.vectorKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of rows with values within the specified range.

        Parameters
        ----------
        data : `KeyedData`
            The data; ``data[self.vectorKey]`` is the vector that is tested.
        **kwargs
            Unused by this selector.

        Returns
        -------
        result : `Vector`
            A mask of the rows with ``minimum <= value < maximum``.
        """
        values = cast(Vector, data[self.vectorKey])
        mask = (values >= self.minimum) & (values < self.maximum)

        return cast(Vector, mask)

233 

234 

class SnSelector(SelectorBase):
    """Selects points that have S/N > threshold in the given flux type."""

    fluxType = Field[str](doc="Flux type to calculate the S/N in.", default="{band}_psfFlux")
    threshold = Field[float](doc="The S/N threshold to remove sources with.", default=500.0)
    maxSN = Field[float](doc="Maximum S/N to include in the sample (to allow S/N ranges).", default=1e6)
    uncertaintySuffix = Field[str](
        doc="Suffix to add to fluxType to specify uncertainty column", default="Err"
    )
    bands = ListField[str](
        doc="The bands to apply the signal to noise cut in. Takes precedence if bands passed to call",
        default=[],
    )

    @staticmethod
    def _errColumn(fluxCol: str, suffix: str) -> str:
        """Return the uncertainty column name belonging to ``fluxCol``.

        The suffix is inserted directly after the first occurrence of
        ``"lux"`` (e.g. ``psfFlux_target`` with suffix ``Err`` becomes
        ``psfFluxErr_target``).
        """
        # NOTE(review): when "lux" does not occur, find() returns -1 and the
        # suffix ends up inserted at index 2; this long-standing behaviour is
        # deliberately left untouched.
        splitAt = fluxCol.find("lux") + len("lux")
        return fluxCol[:splitAt] + suffix + fluxCol[splitAt:]

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the flux column and its uncertainty column.

        Both names are built from the configured values as-is, i.e. any
        ``{...}`` templates are left unformatted.
        """
        fluxCol = self.fluxType
        yield fluxCol, Vector
        yield self._errColumn(fluxCol, self.uncertaintySuffix), Vector

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Makes a mask of objects that have S/N greater than
        self.threshold in self.fluxType

        Parameters
        ----------
        data : `KeyedData`
            The data to perform the selection on.
        **kwargs
            May contain ``band`` or ``bands`` (only used when the configured
            ``bands`` is an empty list) and ``plotInfo``.

        Returns
        -------
        result : `Vector`
            A mask of the objects that satisfy
            ``threshold < S/N < maxSN`` in every band considered.
        """
        self._addValueToPlotInfo(self.threshold, **kwargs)

        # Only fall back on the call arguments when bands is an empty list.
        deferToKwargs = not self.bands and self.bands == []

        bands: tuple[str, ...]
        if deferToKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format the templates with "".
            bands = ("",)

        mask: Optional[Vector] = None
        for band in bands:
            # Format the flux name and the suffix with the same values so a
            # "{band}" in either always refers to the band being processed.
            formatArgs = kwargs | dict(band=band)
            fluxCol = self.fluxType.format(**formatArgs)
            errCol = self._errColumn(fluxCol, self.uncertaintySuffix.format(**formatArgs))

            vec = divide(cast(Vector, data[fluxCol]), cast(Vector, data[errCol]))
            temp = (vec > self.threshold) & (vec < self.maxSN)
            if mask is None:
                mask = temp
            else:
                mask &= temp  # type: ignore

        # It should not be possible for mask to be a None now
        return np.array(cast(Vector, mask))

299 

300 

class SkyObjectSelector(FlagSelector):
    """Selects sky objects in the given band(s)."""

    bands = ListField[str](
        doc="The bands to apply the flags in, takes precedence if band supplied in kwargs",
        default=["i"],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the same schema entries as the parent flag selector."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Apply the sky-object flag cuts per band and AND the results.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            May contain ``band`` or ``bands``; these are only used when the
            configured ``bands`` is an empty list.

        Returns
        -------
        result : `Vector`
            A mask of the rows passing the flag cuts in all of the bands.
        """
        # Only fall back on the call arguments when bands is an empty list.
        deferToKwargs = not self.bands and self.bands == []

        bands: tuple[str, ...]
        if deferToKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            # No band information at all: format the templates with "".
            bands = ("",)

        combined: Optional[Vector] = None
        for band in bands:
            current = super().__call__(data, **(kwargs | dict(band=band)))
            if combined is None:
                combined = current
            else:
                combined &= current  # type: ignore
        return cast(Vector, combined)

    def setDefaults(self):
        """Set the default sky-object flag columns."""
        self.selectWhenFalse = ["{band}_pixelFlags_edge"]
        self.selectWhenTrue = ["sky_object"]

337 

338 

class SkySourceSelector(FlagSelector):
    """Selects sky sources from sourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the same schema entries as the parent flag selector."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent flag selector.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Values used to format the configured column names.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # A single selection pass: return the parent result directly instead
        # of accumulating into a variable that was always None beforehand.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default sky-source flag columns."""
        self.selectWhenFalse = [
            "pixelFlags_edge",
        ]
        self.selectWhenTrue = ["sky_source"]

359 

360 

class GoodDiaSourceSelector(FlagSelector):
    """Selects good DIA sources from diaSourceTables."""

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the same schema entries as the parent flag selector."""
        yield from super().getInputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the mask computed by the parent flag selector.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the flag columns.
        **kwargs
            Values used to format the configured column names.

        Returns
        -------
        result : `Vector`
            A mask of the rows that satisfy the configured flag cuts.
        """
        # A single selection pass: return the parent result directly instead
        # of accumulating into a variable that was always None beforehand.
        return super().__call__(data, **kwargs)

    def setDefaults(self):
        """Set the default pixel-flag columns that must all be False."""
        self.selectWhenFalse = [
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_suspect",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_edge",
        ]

385 

386 

class ExtendednessSelector(VectorAction):
    """A selector that picks between extended and point sources.

    This base class only loads the extendedness vector; the subclasses
    turn it into a mask.
    """

    vectorKey = Field[str](
        doc="Key of the Vector which defines extendedness metric", default="{band}_extendedness"
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single (``vectorKey``, `Vector`) schema entry."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the extendedness vector.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Values used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            The column stored under the formatted ``vectorKey``.
        """
        column = self.vectorKey.format(**kwargs)
        extendedness = data[column]
        return cast(Vector, extendedness)

400 

401 

class StarSelector(ExtendednessSelector):
    """A selector that picks out stars based off of their
    extendedness values.
    """

    extendedness_maximum = Field[float](
        doc="Maximum extendedness to qualify as unresolved, inclusive.", default=0.5, dtype=float
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the unresolved rows.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Values used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            A mask that is True where
            ``0 <= extendedness <= extendedness_maximum``.
        """
        extendedness = super().__call__(data, **kwargs)
        # The upper bound is inclusive, as the configuration documents; the
        # comparison previously used "<", which contradicted that and left
        # values exactly at the boundary selected by neither this selector
        # nor GalaxySelector (whose minimum is documented as not inclusive).
        unresolved = (extendedness >= 0) & (extendedness <= self.extendedness_maximum)
        return np.array(cast(Vector, unresolved))

414 

415 

class GalaxySelector(ExtendednessSelector):
    """A selector that picks out galaxies based off of their
    extendedness values.
    """

    extendedness_minimum = Field[float](
        doc="Minimum extendedness to qualify as resolved, not inclusive.", default=0.5
    )

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the resolved rows.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Values used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            A mask that is True where the extendedness is strictly greater
            than ``extendedness_minimum``.
        """
        values = super().__call__(data, **kwargs)
        resolved = values > self.extendedness_minimum  # type: ignore
        return cast(Vector, resolved)

428 

429 

class UnknownSelector(ExtendednessSelector):
    """A selector that picks out unclassified objects based off of their
    extendedness values.
    """

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows whose extendedness equals 9.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the extendedness column.
        **kwargs
            Values used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            The result of comparing the extendedness vector with 9.
        """
        values = super().__call__(data, **kwargs)
        # 9 is the extendedness value this selector treats as unclassified.
        unclassified = values == 9
        return unclassified

438 

439 

class VectorSelector(VectorAction):
    """Load a boolean vector from KeyedData and return it for use as a
    selector.
    """

    vectorKey = Field[str](doc="Key corresponding to boolean vector to use as a selection mask")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single (``vectorKey``, `Vector`) schema entry."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return the vector stored under the formatted ``vectorKey``.

        Parameters
        ----------
        data : `KeyedData`
            The data holding the mask column.
        **kwargs
            Values used to format ``vectorKey``.

        Returns
        -------
        result : `Vector`
            The stored vector, returned unchanged.
        """
        column = self.vectorKey.format(**kwargs)
        return cast(Vector, data[column])

452 

453 

class ThresholdSelector(VectorAction):
    """Return a mask corresponding to an applied threshold."""

    op = Field[str](doc="Operator name.")
    threshold = Field[float](doc="Threshold to apply.")
    vectorKey = Field[str](doc="Name of column")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single (``vectorKey``, `Vector`) schema entry."""
        schema = ((self.vectorKey, Vector),)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Compare the configured column against the threshold.

        Parameters
        ----------
        data : `KeyedData`
            The data; ``data[self.vectorKey]`` is the vector that is
            compared.
        **kwargs
            Unused by this selector.

        Returns
        -------
        result : `Vector`
            The result of calling the function named ``op`` from the
            `operator` module with the column and ``threshold``.
        """
        # ``op`` names a function in the standard ``operator`` module.
        compare = getattr(operator, self.op)
        column = data[self.vectorKey]
        return cast(Vector, compare(column, self.threshold))

467 

468 

class BandSelector(VectorAction):
    """Makes a mask for sources observed in a specified set of bands."""

    vectorKey = Field[str](doc="Key of the Vector which defines the band", default="band")
    bands = ListField[str](
        doc="The bands to select. `None` indicates no band selection applied.",
        default=[],
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single (``vectorKey``, `Vector`) schema entry."""
        return ((self.vectorKey, Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Vector:
        """Return a mask of the rows observed in the selected bands.

        Parameters
        ----------
        data : `KeyedData`
            The data; ``data[self.vectorKey]`` holds the band of each row.
        **kwargs
            May contain ``band`` or ``bands``; these are only used when the
            configured ``bands`` is an empty list.

        Returns
        -------
        result : `Vector`
            A mask that is True for rows whose band is one of the selected
            bands, or True everywhere when no bands are selected.
        """
        # Only fall back on the call arguments when bands is an empty list.
        deferToKwargs = not self.bands and self.bands == []

        bands: Optional[tuple[str, ...]]
        if deferToKwargs and "band" in kwargs:
            bands = (kwargs["band"],)
        elif deferToKwargs and "bands" in kwargs:
            bands = kwargs["bands"]
        elif self.bands:
            bands = tuple(self.bands)
        else:
            bands = None

        if bands:
            # np.isin supersedes np.in1d, which NumPy 2.0 deprecates.
            mask = np.isin(data[self.vectorKey], bands)
        else:
            # No band selection is applied, i.e., select all rows
            mask = np.full(len(data[self.vectorKey]), True)  # type: ignore
        return cast(Vector, mask)

498 

499 

class ParentObjectSelector(FlagSelector):
    """Select only parent objects that are not sky objects."""

    def setDefaults(self):
        """Configure the flag columns that pick out the parents."""
        # Parents: neither a deblended model source nor a sky object ...
        self.selectWhenFalse = ["detect_isDeblendedModelSource", "sky_object"]
        # ... and flagged as being in the inner region of the patch.
        self.selectWhenTrue = [
            "detect_isPatchInner",
        ]