Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 13%

219 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-16 08:24 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Names exported by ``from ... import *``: the connections, config and task
# that make up the isolated star association pipeline task.
__all__ = [
    "IsolatedStarAssociationConnections",
    "IsolatedStarAssociationConfig",
    "IsolatedStarAssociationTask",
]

25 

26import numpy as np 

27import pandas as pd 

28from smatch.matcher import Matcher 

29 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32from lsst.skymap import BaseSkyMap 

33from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry 

34 

35 

class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=("instrument", "tract", "skymap"),
                                         defaultTemplates={}):
    """Butler connections for IsolatedStarAssociationTask.

    Quanta have (instrument, tract, skymap) dimensions.  The per-visit
    source tables are a multiple, lazily-loaded (deferLoad) input; the
    skymap is a regular input.  Two per-tract DataFrame catalogs are
    written out.
    """

    # --- Inputs ---
    source_table_visit = pipeBase.connectionTypes.Input(
        name="sourceTable_visit",
        doc="Source table in parquet format, per visit",
        dimensions=("instrument", "visit"),
        storageClass="DataFrame",
        multiple=True,
        deferLoad=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    # --- Outputs ---
    isolated_star_sources = pipeBase.connectionTypes.Output(
        name="isolated_star_sources",
        doc="Catalog of individual sources for the isolated stars",
        dimensions=("instrument", "tract", "skymap"),
        storageClass="DataFrame",
    )
    isolated_star_cat = pipeBase.connectionTypes.Output(
        name="isolated_star_cat",
        doc="Catalog of isolated star positions",
        dimensions=("instrument", "tract", "skymap"),
        storageClass="DataFrame",
    )

65 

66 

class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    inst_flux_field = pexConfig.Field(
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicitly included in source_selector.flags.bad. '
             'Note that this is expected to end in ``instFlux``.'),
        dtype=str,
        default='apFlux_12_0_instFlux',
    )
    match_radius = pexConfig.Field(
        doc='Match radius (arcseconds)',
        dtype=float,
        default=1.0,
    )
    isolation_radius = pexConfig.Field(
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
        dtype=float,
        default=2.0,
    )
    band_order = pexConfig.ListField(
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
    )
    id_column = pexConfig.Field(
        doc='Name of column with source id.',
        dtype=str,
        default='sourceId',
    )
    ra_column = pexConfig.Field(
        doc='Name of column with right ascension.',
        dtype=str,
        default='ra',
    )
    dec_column = pexConfig.Field(
        doc='Name of column with declination.',
        dtype=str,
        default='dec',
    )
    physical_filter_column = pexConfig.Field(
        doc='Name of column with physical filter name',
        dtype=str,
        default='physical_filter',
    )
    band_column = pexConfig.Field(
        doc='Name of column with band name',
        dtype=str,
        default='band',
    )
    extra_columns = pexConfig.ListField(
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
        dtype=str,
        default=['x',
                 'y',
                 'apFlux_17_0_instFlux',
                 'apFlux_17_0_instFluxErr',
                 'apFlux_17_0_flag',
                 'localBackground_instFlux',
                 'localBackground_flag']
    )
    source_selector = sourceSelectorRegistry.makeField(
        doc='How to select sources. Under normal usage this should not be changed.',
        default='science'
    )

    def setDefaults(self):
        """Configure the ``science`` source selector for isolated stars.

        Enables flag, unresolved, signal-to-noise, isolation, finite
        ra/dec and primary cuts.  The flux flag, the S/N flux/error fields
        and the ra/dec column names are derived from ``inst_flux_field``,
        ``ra_column`` and ``dec_column`` as they are when this method runs.
        If those fields are overridden afterwards, the corresponding
        ``source_selector`` settings are not updated here and may need to
        be overridden as well.
        """
        super().setDefaults()

        source_selector = self.source_selector['science']
        source_selector.setDefaults()

        source_selector.doFlags = True
        source_selector.doUnresolved = True
        source_selector.doSignalToNoise = True
        source_selector.doIsolated = True
        source_selector.doRequireFiniteRaDec = True
        source_selector.doRequirePrimary = True

        source_selector.signalToNoise.minimum = 10.0
        source_selector.signalToNoise.maximum = 1000.0

        # e.g. 'apFlux_12_0_instFlux' -> 'apFlux_12_0_flag'.
        flux_flag_name = self.inst_flux_field.replace("instFlux", "flag")

        source_selector.flags.bad = ['pixelFlags_edge',
                                     'pixelFlags_interpolatedCenter',
                                     'pixelFlags_saturatedCenter',
                                     'pixelFlags_crCenter',
                                     'pixelFlags_bad',
                                     'pixelFlags_interpolated',
                                     'pixelFlags_saturated',
                                     'centroid_flag',
                                     flux_flag_name]

        source_selector.signalToNoise.fluxField = self.inst_flux_field
        source_selector.signalToNoise.errField = self.inst_flux_field + 'Err'

        source_selector.isolated.parentName = 'parentSourceId'
        source_selector.isolated.nChildName = 'deblend_nChild'

        source_selector.unresolved.maximum = 0.5
        source_selector.unresolved.name = 'extendedness'

        source_selector.requireFiniteRaDec.raColName = self.ra_column
        source_selector.requireFiniteRaDec.decColName = self.dec_column

176 

177 

class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.

    For one tract, sources from every input per-visit source table are
    selected with the ``source_selector`` subtask and cropped to the outer
    tract polygon.  They are grouped band by band (in ``band_order``
    priority) into a catalog of star positions.  Stars with another star
    within ``isolation_radius`` are rejected, and the remainder are cropped
    to the inner tract region.  Finally, each selected source within
    ``match_radius`` of a surviving star is associated with it.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        # Only bands in band_order are ever matched; warn about the rest.
        # Sorted so that the log output is deterministic.
        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in sorted(bands):
            if band not in self.config.band_order:
                self.log.warning('Input data has data from band %s but that band is not '
                                 'configured for matching', band)

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        source_table_ref_dict_temp = {source_table_ref.dataId['visit']: source_table_ref
                                      for source_table_ref in source_table_refs}
        source_table_ref_dict = dict(sorted(source_table_ref_dict_temp.items()))

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence: ``star_source_cat``
            (associated sources) and ``star_cat`` (isolated stars).  Both
            are zero-length arrays of the appropriate dtype if no isolated
            stars survive.
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)

        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat, primary_star_cat)

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)

        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat, primary_star_cat)

        # Crop to inner tract region
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]

        if len(primary_star_cat) == 0:
            return self._make_empty_struct(star_source_cat, primary_star_cat)

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    @staticmethod
    def _make_empty_struct(star_source_cat, primary_star_cat):
        """Make an output struct with zero-length catalogs.

        Parameters
        ----------
        star_source_cat : `np.ndarray`
            Source catalog whose dtype is used for the empty source output.
        primary_star_cat : `np.ndarray`
            Star catalog whose dtype is used for the empty star output.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with zero-length ``star_source_cat`` and ``star_cat``.
        """
        return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                               star_cat=np.zeros(0, primary_star_cat.dtype))

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of selected star sources inside the outer tract polygon,
            with extra ``source_row`` (row in the visit source table) and
            ``obj_index`` (placeholder, filled during source matching)
            columns.
        """
        # Imported explicitly: a plain ``import numpy`` is not guaranteed to
        # load the ``numpy.lib.recfunctions`` submodule.
        import numpy.lib.recfunctions as rfn

        # Internally, we use a numpy recarray, they are by far the fastest
        # option in testing for relatively narrow tables.
        # (have not tested wide tables)
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        tables = []
        for source_table_ref in source_table_ref_dict.values():
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)

            selected = self.source_selector.selectSources(df).selected

            table = df[persist_columns][selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = rfn.append_fields(table,
                                      ['source_row',
                                       'obj_index'],
                                      [np.where(selected)[0],
                                       np.zeros(selected.sum(), dtype=np.int32)],
                                      dtypes=['i4', 'i4'],
                                      usemask=False)

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[ra_col]),
                                      np.deg2rad(table[dec_col]))

            tables.append(table[tract_use])

        # Combine tables
        star_source_cat = np.concatenate(tables)

        return star_source_cat

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read (persisted columns plus the columns the
            enabled source-selector cuts need).
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns)
        """
        columns = [self.config.id_column,
                   'visit', 'detector',
                   self.config.ra_column, self.config.dec_column,
                   self.config.physical_filter_column, self.config.band_column,
                   self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        columns.extend(self.config.extra_columns)

        all_columns = columns.copy()
        if self.source_selector.config.doFlags:
            all_columns.extend(self.source_selector.config.flags.bad)
        if self.source_selector.config.doUnresolved:
            all_columns.append(self.source_selector.config.unresolved.name)
        if self.source_selector.config.doIsolated:
            all_columns.append(self.source_selector.config.isolated.parentName)
            all_columns.append(self.source_selector.config.isolated.nChildName)
        if self.source_selector.config.doRequirePrimary:
            all_columns.append(self.source_selector.config.requirePrimary.primaryColName)

        return all_columns, columns

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Sources are grouped within ``match_radius`` one band at a time, in
        the given order.  A star found in a later band is dropped if it lies
        within ``match_radius`` of a star already found in an earlier band.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions (zero-length, with the correct
            dtype, if no stars are found).
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        band_col = self.config.band_column
        match_radius_deg = self.config.match_radius/3600.

        dtype = self._get_primary_dtype(primary_bands)

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat[band_col] == primary_band)

            # Nothing to match in this band; skip building a Matcher on an
            # empty input.
            if not use.any():
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            with Matcher(ra, dec) as matcher:
                try:
                    # New smatch API
                    idx = matcher.query_groups(match_radius_deg, min_match=1)
                except AttributeError:
                    # Old smatch API
                    idx = matcher.query_self(match_radius_deg, min_match=1)

            count = len(idx)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract cross ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra

            # Compute mean position for each primary star; each group is a
            # list of indices into ra/dec.
            for i, row in enumerate(idx):
                row = np.array(row)
                band_cat[ra_col][i] = np.mean(ra_temp[row])
                band_cat[dec_col][i] = np.mean(dec[row])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    idx = matcher.query_radius(primary_star_cat[ra_col],
                                               primary_star_cat[dec_col],
                                               match_radius_deg)
                # idx[i] lists the already-found primary stars near band_cat[i];
                # any band_cat object with a match should be removed.
                match_indices = np.array([i for i, row in enumerate(idx) if len(row) > 0])
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Every star that has another star within ``isolation_radius`` is
        removed (both members of a close pair are dropped).

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            try:
                # New smatch API
                idx = matcher.query_groups(self.config.isolation_radius/3600., min_match=2)
            except AttributeError:
                # Old smatch API
                idx = matcher.query_self(self.config.isolation_radius/3600., min_match=2)

        try:
            neighbor_indices = np.concatenate(idx)
        except ValueError:
            # No groups at all: np.concatenate cannot take an empty sequence.
            neighbor_indices = np.zeros(0, dtype=int)

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars (must be non-empty).

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources: grouped by star, and
            by band (in ``bands`` order) within each star, with
            ``obj_index`` set to the star's row.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with ``source_cat_index``/``nsource``
            (overall and per band) indexing into star_source_cat_sorted.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column
        band_col = self.config.band_column
        n_star = len(primary_star_cat)

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((len(bands), n_star), dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat[band_col] == band)

                if len(band_use) == 0:
                    # No sources in this band, so no star can match one.
                    idx = [[] for _ in range(n_star)]
                else:
                    # idx[i] lists the positions in band_use within
                    # match_radius of star i.
                    idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                               star_source_cat[dec_col][band_use],
                                               self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)

        primary_star_cat['nsource'] = n_source_per_obj
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = int(n_source_per_obj.sum())

        # Temporary arrays until we crop/sort the source catalog.
        # Each star's sources are contiguous, so obj_index is simply each star
        # row repeated nsource times.
        obj_index = np.repeat(np.arange(n_star, dtype=np.int32), n_source_per_obj)
        source_index = np.zeros(n_tot_source, dtype=np.int32)

        ctr = 0
        for i in range(n_star):
            for b in range(len(bands)):
                n_match = n_source_per_band_per_obj[b, i]
                source_index[ctr: ctr + n_match] = band_uses[b][idxs[b][i]]
                ctr += n_match

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed is the same as the overall star
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        The tract number is stored in the low decimal digits and the
        1-based star counter in the high digits, giving an id that is
        unique for a given processing.

        Parameters
        ----------
        skymap : `lsst.skymap.Skymap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids.
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)

        return (np.arange(nstar) + 1)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `list` [`tuple`]
            Datatype specification of the primary catalog, usable with
            ``np.zeros``.
        """
        # ``default`` keeps an (unusual) empty band list from raising here.
        max_len = max((len(primary_band) for primary_band in primary_bands), default=1)

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype