Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 14%

213 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-01 01:33 -0800

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: the connections, config, and task classes.
__all__ = [
    "IsolatedStarAssociationConnections",
    "IsolatedStarAssociationConfig",
    "IsolatedStarAssociationTask",
]

25 

26import numpy as np 

27import pandas as pd 

28from smatch.matcher import Matcher 

29 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32from lsst.skymap import BaseSkyMap 

33from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry 

34 

35 

class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=('instrument', 'tract', 'skymap'),
                                         defaultTemplates={}):
    """Butler connections for IsolatedStarAssociationTask.

    Inputs are all per-visit source tables for a tract plus the skymap.
    Outputs are two per-tract DataFrames: the individual sources belonging
    to isolated stars, and the isolated star positions themselves.
    """

    # Loading is deferred so the task can request only the columns it
    # needs from each per-visit table (via ``parameters={'columns': ...}``).
    source_table_visit = pipeBase.connectionTypes.Input(
        name='sourceTable_visit',
        storageClass='DataFrame',
        dimensions=('instrument', 'visit'),
        multiple=True,
        deferLoad=True,
        doc='Source table in parquet format, per visit',
    )
    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass='SkyMap',
        dimensions=('skymap',),
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
    )
    # Per-source rows for every isolated star (sorted by star, then band).
    isolated_star_sources = pipeBase.connectionTypes.Output(
        name='isolated_star_sources',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
        doc='Catalog of individual sources for the isolated stars',
    )
    # One row per isolated star, with indices into isolated_star_sources.
    isolated_star_cat = pipeBase.connectionTypes.Output(
        name='isolated_star_cat',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
        doc='Catalog of isolated star positions',
    )

65 

66 

class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    inst_flux_field = pexConfig.Field(
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicitly included in bad_flags. '
             'Note that this is expected to end in ``instFlux``.'),
        dtype=str,
        default='apFlux_12_0_instFlux',
    )
    match_radius = pexConfig.Field(
        doc='Match radius (arcseconds)',
        dtype=float,
        default=1.0,
    )
    isolation_radius = pexConfig.Field(
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
        dtype=float,
        default=2.0,
    )
    band_order = pexConfig.ListField(
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
    )
    id_column = pexConfig.Field(
        doc='Name of column with source id.',
        dtype=str,
        default='sourceId',
    )
    ra_column = pexConfig.Field(
        doc='Name of column with right ascension.',
        dtype=str,
        default='ra',
    )
    dec_column = pexConfig.Field(
        doc='Name of column with declination.',
        dtype=str,
        default='decl',
    )
    physical_filter_column = pexConfig.Field(
        doc='Name of column with physical filter name',
        dtype=str,
        default='physical_filter',
    )
    band_column = pexConfig.Field(
        doc='Name of column with band name',
        dtype=str,
        default='band',
    )
    extra_columns = pexConfig.ListField(
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
        dtype=str,
        default=['x',
                 'y',
                 'apFlux_17_0_instFlux',
                 'apFlux_17_0_instFluxErr',
                 'apFlux_17_0_flag',
                 'localBackground_instFlux',
                 'localBackground_flag']
    )
    source_selector = sourceSelectorRegistry.makeField(
        doc='How to select sources. Under normal usage this should not be changed.',
        default='science'
    )

    def setDefaults(self):
        """Configure the 'science' source selector for isolated stars.

        Enables flag, unresolved (extendedness), signal-to-noise, and
        isolation (deblend parent/child) cuts.

        NOTE: the signal-to-noise flux/error field names and the flux flag
        appended to ``flags.bad`` are derived from ``inst_flux_field`` once,
        when this method runs. Nothing in this class recomputes them, so a
        later change to ``inst_flux_field`` must be accompanied by matching
        overrides of those source selector fields.
        """
        super().setDefaults()

        source_selector = self.source_selector['science']
        source_selector.setDefaults()

        source_selector.doFlags = True
        source_selector.doUnresolved = True
        source_selector.doSignalToNoise = True
        source_selector.doIsolated = True

        source_selector.signalToNoise.minimum = 10.0
        source_selector.signalToNoise.maximum = 1000.0

        # e.g. 'apFlux_12_0_instFlux' -> 'apFlux_12_0_flag'
        flux_flag_name = self.inst_flux_field.replace("instFlux", "flag")

        source_selector.flags.bad = ['pixelFlags_edge',
                                     'pixelFlags_interpolatedCenter',
                                     'pixelFlags_saturatedCenter',
                                     'pixelFlags_crCenter',
                                     'pixelFlags_bad',
                                     'pixelFlags_interpolated',
                                     'pixelFlags_saturated',
                                     'centroid_flag',
                                     flux_flag_name]

        source_selector.signalToNoise.fluxField = self.inst_flux_field
        source_selector.signalToNoise.errField = self.inst_flux_field + 'Err'

        # Column names as they appear in sourceTable_visit.
        source_selector.isolated.parentName = 'parentSourceId'
        source_selector.isolated.nChildName = 'deblend_nChild'

        source_selector.unresolved.maximum = 0.5
        source_selector.unresolved.name = 'extendedness'

171 

172 

class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.

    Selected sources from every per-visit source table overlapping a tract
    are matched band by band (in ``config.band_order`` priority) into
    primary stars. Stars with close neighbors are removed, the catalog is
    cropped to the inner tract region, and every selected source is then
    associated with its primary star.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        """Construct the task and its ``source_selector`` subtask."""
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Gather inputs, run the association, and persist DataFrame outputs.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler quantum context used to read inputs and write outputs.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            References to the input datasets.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            References to the output datasets.
        """
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in bands:
            if band not in self.config.band_order:
                self.log.warning('Input data has data from band %s but that band is not '
                                 'configured for matching', band)

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        refs_by_visit = {ref.dataId['visit']: ref for ref in source_table_refs}
        source_table_ref_dict = {visit: refs_by_visit[visit] for visit in sorted(refs_by_visit)}

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence:

            ``star_source_cat``
                Individual sources matched to isolated stars (`np.ndarray`).
            ``star_cat``
                Isolated star catalog (`np.ndarray`). Zero-length arrays
                with the correct dtypes are returned if no stars survive.
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)

        if len(primary_star_cat) == 0:
            return self._empty_struct(star_source_cat, primary_star_cat)

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)

        if len(primary_star_cat) == 0:
            return self._empty_struct(star_source_cat, primary_star_cat)

        # Crop to inner tract region: keep only stars whose averaged position
        # falls in this tract according to the skymap.
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]

        if len(primary_star_cat) == 0:
            return self._empty_struct(star_source_cat, primary_star_cat)

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    @staticmethod
    def _empty_struct(star_source_cat, primary_star_cat):
        """Make an output struct of zero-length catalogs with the given dtypes.

        Parameters
        ----------
        star_source_cat : `np.ndarray`
            Star source catalog whose dtype is used for the empty output.
        primary_star_cat : `np.ndarray`
            Primary star catalog whose dtype is used for the empty output.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with empty ``star_source_cat`` and ``star_cat``.
        """
        return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                               star_cat=np.zeros(0, primary_star_cat.dtype))

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Raises
        ------
        ValueError
            Raised if ``source_table_ref_dict`` is empty.
        """
        # ``numpy.lib.recfunctions`` is a submodule that ``import numpy`` is
        # not guaranteed to load, so import it explicitly rather than relying
        # on ``np.lib.recfunctions`` having been imported elsewhere.
        import numpy.lib.recfunctions as rfn

        if len(source_table_ref_dict) == 0:
            raise ValueError('No source tables were provided for star association.')

        # Internally, we use a numpy recarray, they are by far the fastest
        # option in testing for relatively narrow tables.
        # (have not tested wide tables)
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        tables = []
        for source_table_ref in source_table_ref_dict.values():
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)

            good_src = self.source_selector.selectSources(df)
            selected = good_src.selected

            table = df[persist_columns][selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = rfn.append_fields(table,
                                      ['source_row',
                                       'obj_index'],
                                      [np.where(selected)[0],
                                       np.zeros(selected.sum(), dtype=np.int32)],
                                      dtypes=['i4', 'i4'],
                                      usemask=False)

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[ra_col]),
                                      np.deg2rad(table[dec_col]))

            tables.append(table[tract_use])

        # Combine tables
        star_source_cat = np.concatenate(tables)

        return star_source_cat

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read (persisted columns plus any columns the
            enabled source selector cuts need).
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns)
        """
        columns = [self.config.id_column,
                   'visit', 'detector',
                   self.config.ra_column, self.config.dec_column,
                   self.config.physical_filter_column, self.config.band_column,
                   self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        columns.extend(self.config.extra_columns)

        selector_config = self.source_selector.config
        all_columns = columns.copy()
        if selector_config.doFlags:
            all_columns.extend(selector_config.flags.bad)
        if selector_config.doUnresolved:
            all_columns.append(selector_config.unresolved.name)
        if selector_config.doIsolated:
            all_columns.append(selector_config.isolated.parentName)
            all_columns.append(selector_config.isolated.nChildName)

        return all_columns, columns

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Sources in each band are grouped within ``match_radius``. Each group
        becomes a candidate star at its mean position. Candidates matching a
        star already found in an earlier (higher-priority) band are dropped.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        band_col = self.config.band_column

        dtype = self._get_primary_dtype(primary_bands)

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat[band_col] == primary_band)

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            if len(ra) == 0:
                # No sources in this band; do not build a matcher on empty input.
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            with Matcher(ra, dec) as matcher:
                try:
                    # New smatch API
                    idx = matcher.query_groups(self.config.match_radius/3600., min_match=1)
                except AttributeError:
                    # Old smatch API
                    idx = matcher.query_self(self.config.match_radius/3600., min_match=1)

            count = len(idx)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract cross ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra

            # Compute mean position for each primary star
            for i, group in enumerate(idx):
                group = np.array(group)
                band_cat[ra_col][i] = np.mean(ra_temp[group])
                band_cat[dec_col][i] = np.mean(dec[group])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    idx = matcher.query_radius(primary_star_cat[ra_col],
                                               primary_star_cat[dec_col],
                                               self.config.match_radius/3600.)
                # Any object with a match should be removed. As in
                # _match_sources, each entry of idx is taken to correspond to
                # a row of the catalog the Matcher was built on (band_cat).
                match_indices = np.array([i for i, matches in enumerate(idx) if len(matches) > 0])
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Any star with another star within ``isolation_radius`` is removed,
        along with that neighbor.

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            try:
                # New smatch API
                idx = matcher.query_groups(self.config.isolation_radius/3600., min_match=2)
            except AttributeError:
                # Old smatch API
                idx = matcher.query_self(self.config.isolation_radius/3600., min_match=2)

        try:
            neighbor_indices = np.concatenate(idx)
        except ValueError:
            # np.concatenate raises on an empty list: no groups, no neighbors.
            neighbor_indices = np.zeros(0, dtype=int)

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars. Must be non-empty, and its
            ``source_cat_index`` column must start zero-filled.

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with indexes to star_source_cat_cut.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        band_col = self.config.band_column

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((len(bands),
                                              len(primary_star_cat)),
                                             dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat[band_col] == band)

                idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                           star_source_cat[dec_col][band_use],
                                           self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)

        primary_star_cat['nsource'] = n_source_per_obj
        # Element 0 stays 0 (the catalog was created with np.zeros); the
        # rest are exclusive prefix sums of the per-object source counts.
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = primary_star_cat['source_cat_index'][-1] + primary_star_cat['nsource'][-1]

        # Temporary arrays until we crop/sort the source catalog
        source_index = np.zeros(n_tot_source, dtype=np.int32)
        obj_index = np.zeros(n_tot_source, dtype=np.int32)

        # Lay out sources grouped by object, and within each object by band.
        ctr = 0
        for i in range(len(primary_star_cat)):
            obj_index[ctr: ctr + n_source_per_obj[i]] = i
            for b in range(len(bands)):
                n_match = n_source_per_band_per_obj[b, i]
                source_index[ctr: ctr + n_match] = band_uses[b][idxs[b][i]]
                ctr += n_match

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed is the same as the overall star
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        This is a simple hash of the tract and star to provide an
        id that is unique for a given processing. The low decimal digits hold
        the tract number; the star counter (starting at 1) sits above them.

        Parameters
        ----------
        skymap : `lsst.skymap.Skymap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids (int64).
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)

        # Explicit int64 so ids cannot overflow where the default int is 32-bit.
        return np.arange(1, nstar + 1, dtype=np.int64)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `list` [`tuple`]
            Datatype specification of the primary catalog.
        """
        max_len = max(len(primary_band) for primary_band in primary_bands)

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype