Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 14%

216 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-04-14 10:26 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ['IsolatedStarAssociationConnections', 

23 'IsolatedStarAssociationConfig', 

24 'IsolatedStarAssociationTask'] 

25 

26import numpy as np 

27import pandas as pd 

28from smatch.matcher import Matcher 

29 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32from lsst.skymap import BaseSkyMap 

33from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry 

34 

35 

class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=('instrument', 'tract', 'skymap',),
                                         defaultTemplates={}):
    """Butler connections for IsolatedStarAssociationTask.

    A quantum is defined per (instrument, tract, skymap). It consumes the
    per-visit source tables together with the skymap, and produces two
    DataFrame catalogs: the individual sources of the isolated stars and
    the isolated star positions.
    """

    # ---- Inputs ----

    # Loaded lazily (deferLoad) and one handle per visit (multiple).
    source_table_visit = pipeBase.connectionTypes.Input(
        name='sourceTable_visit',
        doc='Source table in parquet format, per visit',
        storageClass='DataFrame',
        dimensions=('instrument', 'visit'),
        multiple=True,
        deferLoad=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        storageClass='SkyMap',
        dimensions=('skymap',),
    )

    # ---- Outputs ----

    isolated_star_sources = pipeBase.connectionTypes.Output(
        name='isolated_star_sources',
        doc='Catalog of individual sources for the isolated stars',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
    )
    isolated_star_cat = pipeBase.connectionTypes.Output(
        name='isolated_star_cat',
        doc='Catalog of isolated star positions',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
    )

65 

66 

class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    # ---- Flux used for the signal-to-noise selection ----
    inst_flux_field = pexConfig.Field(
        dtype=str,
        default='apFlux_12_0_instFlux',
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicity included in bad_flags. '
             'Note that this is expected to end in ``instFlux``.'),
    )

    # ---- Matching geometry ----
    match_radius = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc='Match radius (arcseconds)',
    )
    isolation_radius = pexConfig.Field(
        dtype=float,
        default=2.0,
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
    )
    band_order = pexConfig.ListField(
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
    )

    # ---- Column names in the input source tables ----
    id_column = pexConfig.Field(
        dtype=str,
        default='sourceId',
        doc='Name of column with source id.',
    )
    ra_column = pexConfig.Field(
        dtype=str,
        default='ra',
        doc='Name of column with right ascension.',
    )
    dec_column = pexConfig.Field(
        dtype=str,
        default='decl',
        doc='Name of column with declination.',
    )
    physical_filter_column = pexConfig.Field(
        dtype=str,
        default='physical_filter',
        doc='Name of column with physical filter name',
    )
    band_column = pexConfig.Field(
        dtype=str,
        default='band',
        doc='Name of column with band name',
    )
    extra_columns = pexConfig.ListField(
        dtype=str,
        default=[
            'x',
            'y',
            'apFlux_17_0_instFlux',
            'apFlux_17_0_instFluxErr',
            'apFlux_17_0_flag',
            'localBackground_instFlux',
            'localBackground_flag',
        ],
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
    )

    # ---- Source selection ----
    source_selector = sourceSelectorRegistry.makeField(
        default='science',
        doc='How to select sources. Under normal usage this should not be changed.',
    )

    def setDefaults(self):
        """Configure the ``science`` source selector for isolated stars.

        The selector is set up to require good flags, unresolved sources,
        a signal-to-noise window, deblender isolation and finite ra/dec.
        """
        super().setDefaults()

        selector = self.source_selector['science']
        selector.setDefaults()

        # Switch on every selection stage used by this task.
        selector.doFlags = True
        selector.doUnresolved = True
        selector.doSignalToNoise = True
        selector.doIsolated = True
        selector.doRequireFiniteRaDec = True

        # NOTE(review): the flag/flux names below are derived from the value
        # of inst_flux_field at the moment setDefaults runs; presumably a
        # later override of inst_flux_field is not propagated to the selector
        # -- confirm against pex_config override ordering.
        flux_field = self.inst_flux_field
        flux_flag_name = flux_field.replace("instFlux", "flag")

        # Flag cuts: pixel-level problems, bad centroid, and the flag that
        # belongs to the selection flux.
        pixel_and_centroid_flags = [
            'pixelFlags_edge',
            'pixelFlags_interpolatedCenter',
            'pixelFlags_saturatedCenter',
            'pixelFlags_crCenter',
            'pixelFlags_bad',
            'pixelFlags_interpolated',
            'pixelFlags_saturated',
            'centroid_flag',
        ]
        selector.flags.bad = pixel_and_centroid_flags + [flux_flag_name]

        # Signal-to-noise window on the selection flux.
        selector.signalToNoise.fluxField = flux_field
        selector.signalToNoise.errField = flux_field + 'Err'
        selector.signalToNoise.minimum = 10.0
        selector.signalToNoise.maximum = 1000.0

        # Star/galaxy separation.
        selector.unresolved.name = 'extendedness'
        selector.unresolved.maximum = 0.5

        # Deblender isolation.
        selector.isolated.parentName = 'parentSourceId'
        selector.isolated.nChildName = 'deblend_nChild'

        # Finite coordinate requirement uses the configured column names.
        selector.requireFiniteRaDec.raColName = self.ra_column
        selector.requireFiniteRaDec.decColName = self.dec_column

175 

176 

class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.

    Sources that pass the configured source selector are read from every
    input visit, grouped into stars band by band (in ``band_order``), stars
    with a neighbor inside ``isolation_radius`` are dropped, the remainder
    is cropped to the inner tract region, and finally all selected sources
    are matched back to the surviving stars.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        # Warn (in a deterministic order) about input bands that will be
        # ignored because they are not configured for matching.
        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in sorted(bands):
            if band not in self.config.band_order:
                self.log.warning('Input data has data from band %s but that band is not '
                                 'configured for matching', band)

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        refs_by_visit = {source_table_ref.dataId['visit']: source_table_ref
                         for source_table_ref in source_table_refs}
        source_table_ref_dict = {visit: refs_by_visit[visit]
                                 for visit in sorted(refs_by_visit)}

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence: ``star_source_cat`` and
            ``star_cat``. Both are zero-length arrays (with the correct
            dtype) if no isolated star survives any of the cuts.
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)

        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat.dtype, primary_star_cat.dtype)

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)

        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat.dtype, primary_star_cat.dtype)

        # Crop to inner tract region
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]

        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat.dtype, primary_star_cat.dtype)

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    @staticmethod
    def _make_empty_struct(source_dtype, star_dtype):
        """Build the output struct used when no isolated star remains.

        Parameters
        ----------
        source_dtype : `numpy.dtype`
            Datatype of the star source catalog.
        star_dtype : `numpy.dtype`
            Datatype of the primary star catalog.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with zero-length ``star_source_cat`` and ``star_cat``.
        """
        return pipeBase.Struct(star_source_cat=np.zeros(0, source_dtype),
                               star_cat=np.zeros(0, star_dtype))

    @staticmethod
    def _query_matcher_groups(matcher, radius_deg, min_match):
        """Self-match a Matcher, supporting both smatch API generations.

        ``query_groups`` is used when the matcher provides it, otherwise
        the older ``query_self`` is called. Looking the method up with
        ``getattr`` (rather than catching `AttributeError` around the call)
        avoids hiding an `AttributeError` raised inside the query itself.

        Parameters
        ----------
        matcher : `smatch.matcher.Matcher`
            Matcher built on the catalog to self-match.
        radius_deg : `float`
            Match radius (degrees).
        min_match : `int`
            Minimum group size to report.

        Returns
        -------
        idx : `list`
            Match groups as returned by the matcher.
        """
        query = getattr(matcher, 'query_groups', None)
        if query is None:
            # Old smatch API
            query = matcher.query_self
        return query(radius_deg, min_match=min_match)

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of star sources.
        """
        # numpy does not import this submodule as part of ``import numpy``,
        # so import it explicitly instead of relying on another package
        # having loaded it.
        from numpy.lib import recfunctions

        # Internally, we use a numpy recarray, they are by far the fastest
        # option in testing for relatively narrow tables.
        # (have not tested wide tables)
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon

        tables = []
        for source_table_ref in source_table_ref_dict.values():
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)

            selected = self.source_selector.selectSources(df).selected

            table = df[persist_columns][selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = recfunctions.append_fields(table,
                                               ['source_row',
                                                'obj_index'],
                                               [np.where(selected)[0],
                                                np.zeros(selected.sum(), dtype=np.int32)],
                                               dtypes=['i4', 'i4'],
                                               usemask=False)

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[self.config.ra_column]),
                                      np.deg2rad(table[self.config.dec_column]))

            tables.append(table[tract_use])

        # Combine tables
        return np.concatenate(tables)

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns)
        """
        persist_columns = [self.config.id_column,
                           'visit', 'detector',
                           self.config.ra_column, self.config.dec_column,
                           self.config.physical_filter_column, self.config.band_column,
                           self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        persist_columns.extend(self.config.extra_columns)

        # The selector needs additional columns that are read but not
        # persisted.
        selector_config = self.source_selector.config
        all_columns = list(persist_columns)
        if selector_config.doFlags:
            all_columns.extend(selector_config.flags.bad)
        if selector_config.doUnresolved:
            all_columns.append(selector_config.unresolved.name)
        if selector_config.doIsolated:
            all_columns.append(selector_config.isolated.parentName)
            all_columns.append(selector_config.isolated.nChildName)

        return all_columns, persist_columns

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        # The band column is persisted under its configured name, so it
        # must be looked up with that name rather than a literal 'band'.
        band_col = self.config.band_column
        match_radius_deg = self.config.match_radius/3600.

        dtype = self._get_primary_dtype(primary_bands)

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat[band_col] == primary_band)

            # Nothing to match in this band; skip before building a Matcher
            # on empty coordinate arrays.
            if not np.any(use):
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            with Matcher(ra, dec) as matcher:
                groups = self._query_matcher_groups(matcher, match_radius_deg, min_match=1)

            count = len(groups)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract cross ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra

            # Compute mean position for each primary star
            for i, row in enumerate(groups):
                row = np.array(row)
                band_cat[ra_col][i] = np.mean(ra_temp[row])
                band_cat[dec_col][i] = np.mean(dec[row])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    matches = matcher.query_radius(primary_star_cat[ra_col],
                                                   primary_star_cat[dec_col],
                                                   match_radius_deg)
                # Any object with a match should be removed.
                match_indices = np.array([i for i, row in enumerate(matches) if len(row) > 0],
                                         dtype=int)
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            groups = self._query_matcher_groups(matcher,
                                                self.config.isolation_radius/3600.,
                                                min_match=2)

        try:
            neighbor_indices = np.concatenate(groups)
        except ValueError:
            # np.concatenate raises ValueError when given no arrays,
            # i.e. when no neighbor groups were found.
            neighbor_indices = np.zeros(0, dtype=int)

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars. Must not be empty.

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with indexes to the returned
            (sorted and cropped) star source catalog.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        # Use the configured band column name, matching what was read and
        # persisted into star_source_cat.
        band_col = self.config.band_column
        n_band = len(bands)
        n_star = len(primary_star_cat)

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((n_band, n_star), dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat[band_col] == band)

                idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                           star_source_cat[dec_col][band_use],
                                           self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)

        # Each star's sources are stored contiguously; source_cat_index is
        # the offset of its first source (the first star starts at 0).
        primary_star_cat['nsource'] = n_source_per_obj
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = primary_star_cat['source_cat_index'][-1] + primary_star_cat['nsource'][-1]

        # Temporary arrays until we crop/sort the source catalog
        source_index = np.zeros(n_tot_source, dtype=np.int32)
        obj_index = np.zeros(n_tot_source, dtype=np.int32)

        # Fill star by star, and within each star band by band, so that the
        # sources of one star (and of one band within it) are contiguous.
        ctr = 0
        for i in range(n_star):
            obj_index[ctr: ctr + n_source_per_obj[i]] = i
            for b in range(n_band):
                n_match = n_source_per_band_per_obj[b, i]
                # Translate indices within the band subset back to rows of
                # the full star_source_cat.
                source_index[ctr: ctr + n_match] = band_uses[b][idxs[b][i]]
                ctr += n_match

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed is the same as the overall star
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        This is a simple hash of the tract and star to provide an
        id that is unique for a given processing.

        Parameters
        ----------
        skymap : `lsst.skymap.Skymap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids (64-bit integers).
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)

        # Request int64 explicitly so the ids do not depend on the
        # platform-default integer width; they are stored in an 'i8' column.
        return (np.arange(nstar, dtype=np.int64) + 1)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `list` [`tuple`]
            Field specification (name, type code) for the primary catalog,
            suitable for use as a numpy structured dtype.
        """
        # The band name column must be wide enough for the longest name.
        max_len = max(len(primary_band) for primary_band in primary_bands)

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        # One (index, count) pair of columns per band.
        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype