Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 16%

213 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-30 03:26 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008-2022 AURA/LSST. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22import numpy as np 

23import pandas as pd 

24from smatch.matcher import Matcher 

25 

26import lsst.pex.config as pexConfig 

27import lsst.pipe.base as pipeBase 

28from lsst.skymap import BaseSkyMap 

29from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry 

30 

31 

# Public names exported by this module.
__all__ = [
    "IsolatedStarAssociationConnections",
    "IsolatedStarAssociationConfig",
    "IsolatedStarAssociationTask",
]

35 

36 

class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=("instrument", "tract", "skymap"),
                                         defaultTemplates={}):
    """Butler connections for IsolatedStarAssociationTask.

    Inputs are the per-visit source tables and the skymap; outputs are the
    per-tract catalogs of isolated-star sources and isolated-star positions.
    """

    # Per-visit source tables. deferLoad=True hands the task deferred
    # references, so it can choose which columns to read; multiple=True
    # collects every visit overlapping the quantum.
    source_table_visit = pipeBase.connectionTypes.Input(
        name="sourceTable_visit",
        doc="Source table in parquet format, per visit",
        dimensions=("instrument", "visit"),
        storageClass="DataFrame",
        multiple=True,
        deferLoad=True,
    )
    # Skymap used to look up the tract geometry.
    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )
    # Output: one row per individual source associated with an isolated star.
    isolated_star_sources = pipeBase.connectionTypes.Output(
        name="isolated_star_sources",
        doc="Catalog of individual sources for the isolated stars",
        dimensions=("instrument", "tract", "skymap"),
        storageClass="DataFrame",
    )
    # Output: one row per isolated star (mean position plus source indices).
    isolated_star_cat = pipeBase.connectionTypes.Output(
        name="isolated_star_cat",
        doc="Catalog of isolated star positions",
        dimensions=("instrument", "tract", "skymap"),
        storageClass="DataFrame",
    )

66 

67 

class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    inst_flux_field = pexConfig.Field(
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicitly included in '
             'source_selector.flags.bad. '
             'Note that this is expected to end in ``instFlux``.'),
        dtype=str,
        default='apFlux_12_0_instFlux',
    )
    match_radius = pexConfig.Field(
        doc='Match radius (arcseconds)',
        dtype=float,
        default=1.0,
    )
    isolation_radius = pexConfig.Field(
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
        dtype=float,
        default=2.0,
    )
    band_order = pexConfig.ListField(
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
    )
    id_column = pexConfig.Field(
        doc='Name of column with source id.',
        dtype=str,
        default='sourceId',
    )
    ra_column = pexConfig.Field(
        doc='Name of column with right ascension.',
        dtype=str,
        default='ra',
    )
    dec_column = pexConfig.Field(
        doc='Name of column with declination.',
        dtype=str,
        default='decl',
    )
    physical_filter_column = pexConfig.Field(
        doc='Name of column with physical filter name',
        dtype=str,
        default='physical_filter',
    )
    band_column = pexConfig.Field(
        doc='Name of column with band name',
        dtype=str,
        default='band',
    )
    extra_columns = pexConfig.ListField(
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
        dtype=str,
        default=['x',
                 'y',
                 'apFlux_17_0_instFlux',
                 'apFlux_17_0_instFluxErr',
                 'apFlux_17_0_flag',
                 'localBackground_instFlux',
                 'localBackground_flag']
    )
    source_selector = sourceSelectorRegistry.makeField(
        doc='How to select sources. Under normal usage this should not be changed.',
        default='science'
    )

    def setDefaults(self):
        """Configure the 'science' source selector for isolated-star selection."""
        super().setDefaults()

        source_selector = self.source_selector['science']
        source_selector.setDefaults()

        # Enable the flag, unresolved (star/galaxy), signal-to-noise and
        # deblend-isolation cuts of the science source selector.
        source_selector.doFlags = True
        source_selector.doUnresolved = True
        source_selector.doSignalToNoise = True
        source_selector.doIsolated = True

        source_selector.signalToNoise.minimum = 10.0
        source_selector.signalToNoise.maximum = 1000.0

        # The flag for the chosen flux measurement is obtained by replacing
        # the ``instFlux`` suffix with ``flag``.
        # NOTE(review): this flag name and the s/n field names below are
        # derived from inst_flux_field as it is when setDefaults runs;
        # presumably an override of inst_flux_field applied afterwards does
        # not update them automatically.
        flux_flag_name = self.inst_flux_field.replace("instFlux", "flag")

        source_selector.flags.bad = ['pixelFlags_edge',
                                     'pixelFlags_interpolatedCenter',
                                     'pixelFlags_saturatedCenter',
                                     'pixelFlags_crCenter',
                                     'pixelFlags_bad',
                                     'pixelFlags_interpolated',
                                     'pixelFlags_saturated',
                                     'centroid_flag',
                                     flux_flag_name]

        source_selector.signalToNoise.fluxField = self.inst_flux_field
        source_selector.signalToNoise.errField = self.inst_flux_field + 'Err'

        # Columns used by the deblend-isolation cut.
        source_selector.isolated.parentName = 'parentSourceId'
        source_selector.isolated.nChildName = 'deblend_nChild'

        # Keep sources with extendedness at most 0.5 (star-like).
        source_selector.unresolved.maximum = 0.5
        source_selector.unresolved.name = 'extendedness'

172 

173 

class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.

    Per-visit source tables are read and filtered with the configured source
    selector. The selected sources are grouped into "primary" stars band by
    band, in ``band_order`` priority. Stars with neighbors inside
    ``isolation_radius`` are rejected and the remainder are cropped to the
    inner tract. Finally, the individual sources are re-associated with the
    surviving stars.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        # Warn (in a deterministic order) about input bands that will be
        # ignored because they are not part of band_order.
        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in sorted(bands):
            if band not in self.config.band_order:
                self.log.warning('Input data has data from band %s but that band is not '
                                 'configured for matching', band)

        source_table_ref_dict_temp = {source_table_ref.dataId['visit']: source_table_ref
                                      for source_table_ref in source_table_refs}

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        source_table_ref_dict = {visit: source_table_ref_dict_temp[visit]
                                 for visit in sorted(source_table_ref_dict_temp)}

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence: ``star_source_cat`` (matched
            individual sources) and ``star_cat`` (isolated stars). Both are
            empty arrays of the appropriate dtype if no stars survive.
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Crop to inner tract region: keep only stars whose position maps
        # back to this tract.
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of star sources: the persisted columns, plus
            ``source_row`` (row in the visit source table) and ``obj_index``
            (placeholder, filled in by `_match_sources`).
        """
        # ``numpy.lib.recfunctions`` is a submodule that a plain
        # ``import numpy`` is not guaranteed to load, so import it explicitly.
        import numpy.lib.recfunctions as rfn

        # Internally, we use a numpy recarray, they are by far the fastest
        # option in testing for relatively narrow tables.
        # (have not tested wide tables)
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon

        tables = []
        for source_table_ref in source_table_ref_dict.values():
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)

            goodSrc = self.source_selector.selectSources(df)

            table = df[persist_columns][goodSrc.selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = rfn.append_fields(table,
                                      ['source_row',
                                       'obj_index'],
                                      [np.where(goodSrc.selected)[0],
                                       np.zeros(goodSrc.selected.sum(), dtype=np.int32)],
                                      dtypes=['i4', 'i4'],
                                      usemask=False)

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[self.config.ra_column]),
                                      np.deg2rad(table[self.config.dec_column]))

            tables.append(table[tract_use])

        # Combine tables
        star_source_cat = np.concatenate(tables)

        return star_source_cat

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Duplicates are dropped, keeping the first occurrence. Duplicates can
        arise when ``extra_columns`` repeats a core column or a column needed
        by the source selector. Such duplicates would yield repeated
        DataFrame columns and repeated record field names.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns)
        """
        columns = [self.config.id_column,
                   'visit', 'detector',
                   self.config.ra_column, self.config.dec_column,
                   self.config.physical_filter_column, self.config.band_column,
                   self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        columns.extend(self.config.extra_columns)

        all_columns = columns.copy()
        if self.source_selector.config.doFlags:
            all_columns.extend(self.source_selector.config.flags.bad)
        if self.source_selector.config.doUnresolved:
            all_columns.append(self.source_selector.config.unresolved.name)
        if self.source_selector.config.doIsolated:
            all_columns.append(self.source_selector.config.isolated.parentName)
            all_columns.append(self.source_selector.config.isolated.nChildName)

        # dict.fromkeys preserves insertion order while removing duplicates.
        return list(dict.fromkeys(all_columns)), list(dict.fromkeys(columns))

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Sources are grouped within ``match_radius`` band by band. Groups in
        later bands that lie within ``match_radius`` of a star already found
        in an earlier (higher-priority) band are discarded.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        band_col = self.config.band_column

        dtype = self._get_primary_dtype(primary_bands)

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat[band_col] == primary_band)

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            if ra.size == 0:
                # No sources in this band; there is nothing to group, so
                # skip building a matcher.
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            with Matcher(ra, dec) as matcher:
                try:
                    # New smatch API
                    idx = matcher.query_groups(self.config.match_radius/3600., min_match=1)
                except AttributeError:
                    # Old smatch API
                    idx = matcher.query_self(self.config.match_radius/3600., min_match=1)

            count = len(idx)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract cross ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra

            # Compute mean position for each primary star
            for i, row in enumerate(idx):
                row = np.array(row)
                band_cat[ra_col][i] = np.mean(ra_temp[row])
                band_cat[dec_col][i] = np.mean(dec[row])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    idx = matcher.query_radius(primary_star_cat[ra_col],
                                               primary_star_cat[dec_col],
                                               self.config.match_radius/3600.)
                # idx has one entry per matcher (band_cat) position, listing
                # the primary_star_cat rows within the radius. Any band_cat
                # object with a match should be removed.
                match_indices = np.array([i for i, row in enumerate(idx) if len(row) > 0])
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Every star that has another star within ``isolation_radius`` is
        removed; both members of a close pair are dropped.

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            try:
                # New smatch API
                idx = matcher.query_groups(self.config.isolation_radius/3600., min_match=2)
            except AttributeError:
                # Old smatch API
                idx = matcher.query_self(self.config.isolation_radius/3600., min_match=2)

        try:
            neighbor_indices = np.concatenate(idx)
        except ValueError:
            # np.concatenate raises ValueError on an empty list (no groups).
            neighbor_indices = np.zeros(0, dtype=int)

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        The output source catalog is ordered by star. Within each star it is
        ordered by band, following ``bands``. This means the sources for star
        ``j`` in band ``b`` are the contiguous slice starting at
        ``source_cat_index_{b}[j]`` with length ``nsource_{b}[j]``.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars.

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with indexes to star_source_cat_cut.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        band_col = self.config.band_column

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((len(bands),
                                              len(primary_star_cat)),
                                             dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat[band_col] == band)

                # idx has one entry per primary star (the matcher positions),
                # listing indices into band_use.
                idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                           star_source_cat[dec_col][band_use],
                                           self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)

        primary_star_cat['nsource'] = n_source_per_obj
        # Element 0 stays at 0 (the catalog was created with np.zeros).
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = primary_star_cat['source_cat_index'][-1] + primary_star_cat['nsource'][-1]

        # Temporary arrays until we crop/sort the source catalog
        source_index = np.zeros(n_tot_source, dtype=np.int32)
        obj_index = np.zeros(n_tot_source, dtype=np.int32)

        ctr = 0
        for i in range(len(primary_star_cat)):
            obj_index[ctr: ctr + n_source_per_obj[i]] = i
            for b in range(len(bands)):
                source_index[ctr: ctr + n_source_per_band_per_obj[b, i]] = band_uses[b][idxs[b][i]]
                ctr += n_source_per_band_per_obj[b, i]

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed is the same as the overall star
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        This is a simple hash of the tract and star to provide an
        id that is unique for a given processing: ``(star + 1) * mult + tract``,
        where ``mult`` is a power of ten larger than the number of tracts.

        Parameters
        ----------
        skymap : `lsst.skymap.Skymap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids.
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)

        return (np.arange(nstar) + 1)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `list` [`tuple`]
            Datatype specification of the primary catalog: global columns
            followed by a ``source_cat_index_{band}`` / ``nsource_{band}``
            pair per band.
        """
        max_len = max([len(primary_band) for primary_band in primary_bands])

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype