Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 14% of 218 statements
coverage.py v7.4.4, created at 2024-04-19 05:10 -0700

# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ['IsolatedStarAssociationConnections',
           'IsolatedStarAssociationConfig',
           'IsolatedStarAssociationTask']

import numpy as np
import pandas as pd
from smatch.matcher import Matcher

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.skymap import BaseSkyMap
from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry


class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=('instrument', 'tract', 'skymap',),
                                         defaultTemplates={}):
    source_table_visit = pipeBase.connectionTypes.Input(
        doc='Source table in parquet format, per visit',
        name='sourceTable_visit',
        storageClass='DataFrame',
        dimensions=('instrument', 'visit'),
        deferLoad=True,
        multiple=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass='SkyMap',
        dimensions=('skymap',),
    )
    isolated_star_sources = pipeBase.connectionTypes.Output(
        doc='Catalog of individual sources for the isolated stars',
        name='isolated_star_sources',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
    )
    isolated_star_cat = pipeBase.connectionTypes.Output(
        doc='Catalog of isolated star positions',
        name='isolated_star_cat',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
    )


class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    inst_flux_field = pexConfig.Field(
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicitly included in bad_flags. '
             'Note that this is expected to end in ``instFlux``.'),
        dtype=str,
        default='apFlux_12_0_instFlux',
    )
    match_radius = pexConfig.Field(
        doc='Match radius (arcseconds)',
        dtype=float,
        default=1.0,
    )
    isolation_radius = pexConfig.Field(
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
        dtype=float,
        default=2.0,
    )
    band_order = pexConfig.ListField(
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
    )
    id_column = pexConfig.Field(
        doc='Name of column with source id.',
        dtype=str,
        default='sourceId',
    )
    ra_column = pexConfig.Field(
        doc='Name of column with right ascension.',
        dtype=str,
        default='ra',
    )
    dec_column = pexConfig.Field(
        doc='Name of column with declination.',
        dtype=str,
        default='dec',
    )
    physical_filter_column = pexConfig.Field(
        doc='Name of column with physical filter name',
        dtype=str,
        default='physical_filter',
    )
    band_column = pexConfig.Field(
        doc='Name of column with band name',
        dtype=str,
        default='band',
    )
    extra_columns = pexConfig.ListField(
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
        dtype=str,
        default=['x',
                 'y',
                 'xErr',
                 'yErr',
                 'apFlux_17_0_instFlux',
                 'apFlux_17_0_instFluxErr',
                 'apFlux_17_0_flag',
                 'localBackground_instFlux',
                 'localBackground_flag',
                 'ixx',
                 'iyy',
                 'ixy'],
    )
    source_selector = sourceSelectorRegistry.makeField(
        doc='How to select sources. Under normal usage this should not be changed.',
        default='science'
    )

    def setDefaults(self):
        super().setDefaults()

        source_selector = self.source_selector['science']
        source_selector.setDefaults()

        source_selector.doFlags = True
        source_selector.doUnresolved = True
        source_selector.doSignalToNoise = True
        source_selector.doIsolated = True
        source_selector.doRequireFiniteRaDec = True
        source_selector.doRequirePrimary = True

        source_selector.signalToNoise.minimum = 10.0
        source_selector.signalToNoise.maximum = 1000.0

        flux_flag_name = self.inst_flux_field.replace("instFlux", "flag")

        source_selector.flags.bad = ['pixelFlags_edge',
                                     'pixelFlags_interpolatedCenter',
                                     'pixelFlags_saturatedCenter',
                                     'pixelFlags_crCenter',
                                     'pixelFlags_bad',
                                     'pixelFlags_interpolated',
                                     'pixelFlags_saturated',
                                     'centroid_flag',
                                     flux_flag_name]
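        # With the default inst_flux_field of 'apFlux_12_0_instFlux', the
        # derived flux_flag_name appended above is 'apFlux_12_0_flag'; this is
        # how the config doc's promise of an implicit flag cut is fulfilled.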

        source_selector.signalToNoise.fluxField = self.inst_flux_field
        source_selector.signalToNoise.errField = self.inst_flux_field + 'Err'

        source_selector.isolated.parentName = 'parentSourceId'
        source_selector.isolated.nChildName = 'deblend_nChild'

        source_selector.unresolved.name = 'sizeExtendedness'

        source_selector.requireFiniteRaDec.raColName = self.ra_column
        source_selector.requireFiniteRaDec.decColName = self.dec_column


class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        source_table_ref_dict_temp = {source_table_ref.dataId['visit']: source_table_ref for
                                      source_table_ref in source_table_refs}

        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in bands:
            if band not in self.config.band_order:
                self.log.warning('Input data has data from band %s but that band is not '
                                 'configured for matching', band)

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        source_table_ref_dict = {visit: source_table_ref_dict_temp[visit] for
                                 visit in sorted(source_table_ref_dict_temp.keys())}

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence.
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Crop to inner tract region
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of star sources.
        """
        # Internally, we use a numpy recarray; in testing these are by far the
        # fastest option for relatively narrow tables (wide tables have not
        # been tested).
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon

        tables = []
        for visit in source_table_ref_dict:
            source_table_ref = source_table_ref_dict[visit]
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)

            goodSrc = self.source_selector.selectSources(df)

            table = df[persist_columns][goodSrc.selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = np.lib.recfunctions.append_fields(table,
                                                      ['source_row',
                                                       'obj_index'],
                                                      [np.where(goodSrc.selected)[0],
                                                       np.zeros(goodSrc.selected.sum(), dtype=np.int32)],
                                                      dtypes=['i4', 'i4'],
                                                      usemask=False)
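            # 'source_row' records each selected source's row index within its
            # visit's source table; 'obj_index' is a placeholder that is filled
            # in _match_sources with the index of the matched isolated star.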

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[self.config.ra_column]),
                                      np.deg2rad(table[self.config.dec_column]))

            tables.append(table[tract_use])

        # Combine tables
        star_source_cat = np.concatenate(tables)

        return star_source_cat

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns)
        """
        columns = [self.config.id_column,
                   'visit', 'detector',
                   self.config.ra_column, self.config.dec_column,
                   self.config.physical_filter_column, self.config.band_column,
                   self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        columns.extend(self.config.extra_columns)

        all_columns = columns.copy()
        if self.source_selector.config.doFlags:
            all_columns.extend(self.source_selector.config.flags.bad)
        if self.source_selector.config.doUnresolved:
            all_columns.append(self.source_selector.config.unresolved.name)
        if self.source_selector.config.doIsolated:
            all_columns.append(self.source_selector.config.isolated.parentName)
            all_columns.append(self.source_selector.config.isolated.nChildName)
        if self.source_selector.config.doRequirePrimary:
            all_columns.append(self.source_selector.config.requirePrimary.primaryColName)

        return all_columns, columns

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        dtype = self._get_primary_dtype(primary_bands)

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat['band'] == primary_band)

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            with Matcher(ra, dec) as matcher:
                try:
                    # New smatch API
                    idx = matcher.query_groups(self.config.match_radius/3600., min_match=1)
                except AttributeError:
                    # Old smatch API
                    idx = matcher.query_self(self.config.match_radius/3600., min_match=1)
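            # Each entry of idx is a group of source indices belonging to one
            # unique star candidate in this band; with min_match=1 unmatched
            # sources appear as singleton groups, so len(idx) counts candidates.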

            count = len(idx)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract crosses ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra
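            # For example, with the remap an ra of 359.9 becomes -0.1 while an
            # ra of 0.1 is unchanged, so a pair straddling ra=0 averages to 0.0
            # rather than to a spurious 180.0.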

            # Compute mean position for each primary star
            for i, row in enumerate(idx):
                row = np.array(row)
                band_cat[ra_col][i] = np.mean(ra_temp[row])
                band_cat[dec_col][i] = np.mean(dec[row])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    idx = matcher.query_radius(primary_star_cat[ra_col],
                                               primary_star_cat[dec_col],
                                               self.config.match_radius/3600.)
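                # Here idx has one entry per band_cat star, listing any stars
                # already in primary_star_cat within match_radius; such stars
                # were found in an earlier (higher-priority) band.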

                # Any object with a match should be removed.
                match_indices = np.array([i for i in range(len(idx)) if len(idx[i]) > 0])
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            try:
                # New smatch API
                idx = matcher.query_groups(self.config.isolation_radius/3600., min_match=2)
            except AttributeError:
                # Old smatch API
                idx = matcher.query_self(self.config.isolation_radius/3600., min_match=2)

        # np.concatenate raises ValueError when idx is an empty list,
        # i.e. when no star has a neighbor within the isolation radius.
        try:
            neighbor_indices = np.concatenate(idx)
        except ValueError:
            neighbor_indices = np.zeros(0, dtype=int)

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars.

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with indexes to star_source_cat_sorted.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((len(bands),
                                              len(primary_star_cat)),
                                             dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat['band'] == band)

                idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                           star_source_cat[dec_col][band_use],
                                           self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)

        primary_star_cat['nsource'] = n_source_per_obj
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = primary_star_cat['source_cat_index'][-1] + primary_star_cat['nsource'][-1]

        # Temporary arrays until we crop/sort the source catalog
        source_index = np.zeros(n_tot_source, dtype=np.int32)
        obj_index = np.zeros(n_tot_source, dtype=np.int32)
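        # source_index will hold, for every matched source, its row in
        # star_source_cat; obj_index holds the row of its isolated star in
        # primary_star_cat. Sources are laid out star by star and, within
        # each star, band by band in the order of the bands argument.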

        ctr = 0
        for i in range(len(primary_star_cat)):
            obj_index[ctr: ctr + n_source_per_obj[i]] = i
            for b in range(len(bands)):
                source_index[ctr: ctr + n_source_per_band_per_obj[b, i]] = band_uses[b][idxs[b][i]]
                ctr += n_source_per_band_per_obj[b, i]

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)
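        # Running totals over bands: source_cat_index_band_offset[b - 1, i] is
        # the offset of band b's block of sources within star i's contiguous
        # span of the source catalog.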

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed is the same as the overall star
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        This combines the tract and star index to provide an id that is
        unique for a given processing.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids.
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)
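        # For example, a skymap with 20000 tracts gives mult = 100000, so the
        # (1-based) star index occupies the digits above the tract number and
        # ids from different tracts can never collide.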

        return (np.arange(nstar) + 1)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `numpy.dtype`
            Datatype of the primary catalog.
        """
        max_len = max([len(primary_band) for primary_band in primary_bands])
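        # 'primary_band' is stored as a fixed-width unicode column just wide
        # enough for the longest band name (U1 for the default single-letter
        # band_order).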

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype