Coverage for python/lsst/pipe/tasks/isolatedStarAssociation.py: 13%

219 statements  

coverage.py v7.4.3, created at 2024-02-28 12:33 +0000

# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ['IsolatedStarAssociationConnections',
           'IsolatedStarAssociationConfig',
           'IsolatedStarAssociationTask']

import numpy as np
import pandas as pd
from smatch.matcher import Matcher

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.skymap import BaseSkyMap
from lsst.meas.algorithms.sourceSelector import sourceSelectorRegistry


class IsolatedStarAssociationConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=('instrument', 'tract', 'skymap',),
                                         defaultTemplates={}):
    source_table_visit = pipeBase.connectionTypes.Input(
        doc='Source table in parquet format, per visit',
        name='sourceTable_visit',
        storageClass='DataFrame',
        dimensions=('instrument', 'visit'),
        deferLoad=True,
        multiple=True,
    )
    skymap = pipeBase.connectionTypes.Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass='SkyMap',
        dimensions=('skymap',),
    )
    isolated_star_sources = pipeBase.connectionTypes.Output(
        doc='Catalog of individual sources for the isolated stars',
        name='isolated_star_sources',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
    )
    isolated_star_cat = pipeBase.connectionTypes.Output(
        doc='Catalog of isolated star positions',
        name='isolated_star_cat',
        storageClass='DataFrame',
        dimensions=('instrument', 'tract', 'skymap'),
    )


class IsolatedStarAssociationConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=IsolatedStarAssociationConnections):
    """Configuration for IsolatedStarAssociationTask."""

    inst_flux_field = pexConfig.Field(
        doc=('Full name of instFlux field to use for s/n selection and persistence. '
             'The associated flag will be implicitly included in bad_flags. '
             'Note that this is expected to end in ``instFlux``.'),
        dtype=str,
        default='apFlux_12_0_instFlux',
    )
    match_radius = pexConfig.Field(
        doc='Match radius (arcseconds)',
        dtype=float,
        default=1.0,
    )
    isolation_radius = pexConfig.Field(
        doc=('Isolation radius (arcseconds). Any stars with average centroids '
             'within this radius of another star will be rejected from the final '
             'catalog. This radius should be at least 2x match_radius.'),
        dtype=float,
        default=2.0,
    )
    band_order = pexConfig.ListField(
        doc=('Ordered list of bands to use for matching/storage. '
             'Any bands not listed will not be matched.'),
        dtype=str,
        default=['i', 'z', 'r', 'g', 'y', 'u'],
    )
    id_column = pexConfig.Field(
        doc='Name of column with source id.',
        dtype=str,
        default='sourceId',
    )
    ra_column = pexConfig.Field(
        doc='Name of column with right ascension.',
        dtype=str,
        default='ra',
    )
    dec_column = pexConfig.Field(
        doc='Name of column with declination.',
        dtype=str,
        default='dec',
    )
    physical_filter_column = pexConfig.Field(
        doc='Name of column with physical filter name',
        dtype=str,
        default='physical_filter',
    )
    band_column = pexConfig.Field(
        doc='Name of column with band name',
        dtype=str,
        default='band',
    )
    extra_columns = pexConfig.ListField(
        doc='Extra names of columns to read and persist (beyond instFlux and error).',
        dtype=str,
        default=['x',
                 'y',
                 'xErr',
                 'yErr',
                 'apFlux_17_0_instFlux',
                 'apFlux_17_0_instFluxErr',
                 'apFlux_17_0_flag',
                 'localBackground_instFlux',
                 'localBackground_flag',
                 'ixx',
                 'iyy',
                 'ixy'],
    )
    source_selector = sourceSelectorRegistry.makeField(
        doc='How to select sources. Under normal usage this should not be changed.',
        default='science'
    )

    def setDefaults(self):
        super().setDefaults()

        source_selector = self.source_selector['science']
        source_selector.setDefaults()

        source_selector.doFlags = True
        source_selector.doUnresolved = True
        source_selector.doSignalToNoise = True
        source_selector.doIsolated = True
        source_selector.doRequireFiniteRaDec = True
        source_selector.doRequirePrimary = True

        source_selector.signalToNoise.minimum = 10.0
        source_selector.signalToNoise.maximum = 1000.0

        flux_flag_name = self.inst_flux_field.replace("instFlux", "flag")

        source_selector.flags.bad = ['pixelFlags_edge',
                                     'pixelFlags_interpolatedCenter',
                                     'pixelFlags_saturatedCenter',
                                     'pixelFlags_crCenter',
                                     'pixelFlags_bad',
                                     'pixelFlags_interpolated',
                                     'pixelFlags_saturated',
                                     'centroid_flag',
                                     flux_flag_name]

        source_selector.signalToNoise.fluxField = self.inst_flux_field
        source_selector.signalToNoise.errField = self.inst_flux_field + 'Err'

        source_selector.isolated.parentName = 'parentSourceId'
        source_selector.isolated.nChildName = 'deblend_nChild'

        source_selector.unresolved.maximum = 0.5
        source_selector.unresolved.name = 'extendedness'

        source_selector.requireFiniteRaDec.raColName = self.ra_column
        source_selector.requireFiniteRaDec.decColName = self.dec_column


class IsolatedStarAssociationTask(pipeBase.PipelineTask):
    """Associate sources into isolated star catalogs.
    """
    ConfigClass = IsolatedStarAssociationConfig
    _DefaultName = 'isolatedStarAssociation'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.makeSubtask('source_selector')
        # Only log warning and fatal errors from the source_selector
        self.source_selector.log.setLevel(self.source_selector.log.WARN)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        input_ref_dict = butlerQC.get(inputRefs)

        tract = butlerQC.quantum.dataId['tract']

        source_table_refs = input_ref_dict['source_table_visit']

        self.log.info('Running with %d source_table_visit dataRefs',
                      len(source_table_refs))

        source_table_ref_dict_temp = {source_table_ref.dataId['visit']: source_table_ref for
                                      source_table_ref in source_table_refs}

        bands = {source_table_ref.dataId['band'] for source_table_ref in source_table_refs}
        for band in bands:
            if band not in self.config.band_order:
                self.log.warning('Input data includes band %s, but that band is not '
                                 'configured for matching', band)

        # TODO: Sort by visit until DM-31701 is done and we have deterministic
        # dataset ordering.
        source_table_ref_dict = {visit: source_table_ref_dict_temp[visit] for
                                 visit in sorted(source_table_ref_dict_temp.keys())}

        struct = self.run(input_ref_dict['skymap'], tract, source_table_ref_dict)
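        # The run method returns numpy record arrays; they are converted to
        # pandas DataFrames below to match the DataFrame storage class of the
        # output connections.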

        butlerQC.put(pd.DataFrame(struct.star_source_cat),
                     outputRefs.isolated_star_sources)
        butlerQC.put(pd.DataFrame(struct.star_cat),
                     outputRefs.isolated_star_cat)

    def run(self, skymap, tract, source_table_ref_dict):
        """Run the isolated star association task.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract number.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        struct : `lsst.pipe.base.Struct`
            Struct with outputs for persistence.
        """
        star_source_cat = self._make_all_star_sources(skymap[tract], source_table_ref_dict)

        primary_bands = self.config.band_order

        # Do primary matching
        primary_star_cat = self._match_primary_stars(primary_bands, star_source_cat)

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Remove neighbors
        primary_star_cat = self._remove_neighbors(primary_star_cat)

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Crop to inner tract region
        inner_tract_ids = skymap.findTractIdArray(primary_star_cat[self.config.ra_column],
                                                  primary_star_cat[self.config.dec_column],
                                                  degrees=True)
        use = (inner_tract_ids == tract)
        self.log.info('Total of %d isolated stars in inner tract.', use.sum())

        primary_star_cat = primary_star_cat[use]

        if len(primary_star_cat) == 0:
            return pipeBase.Struct(star_source_cat=np.zeros(0, star_source_cat.dtype),
                                   star_cat=np.zeros(0, primary_star_cat.dtype))

        # Set the unique ids.
        primary_star_cat['isolated_star_id'] = self._compute_unique_ids(skymap,
                                                                        tract,
                                                                        len(primary_star_cat))

        # Match to sources.
        star_source_cat, primary_star_cat = self._match_sources(primary_bands,
                                                                star_source_cat,
                                                                primary_star_cat)

        return pipeBase.Struct(star_source_cat=star_source_cat,
                               star_cat=primary_star_cat)

    def _make_all_star_sources(self, tract_info, source_table_ref_dict):
        """Make a catalog of all the star sources.

        Parameters
        ----------
        tract_info : `lsst.skymap.TractInfo`
            Information about the tract.
        source_table_ref_dict : `dict`
            Dictionary of source_table refs. Key is visit, value is dataref.

        Returns
        -------
        star_source_cat : `np.ndarray`
            Catalog of star sources.
        """
        # Internally, we use a numpy recarray; in testing it is by far the
        # fastest option for relatively narrow tables (wide tables have not
        # been tested).
        all_columns, persist_columns = self._get_source_table_visit_column_names()
        poly = tract_info.outer_sky_polygon

        tables = []
        for visit in source_table_ref_dict:
            source_table_ref = source_table_ref_dict[visit]
            df = source_table_ref.get(parameters={'columns': all_columns})
            df.reset_index(inplace=True)
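            # reset_index moves the DataFrame index (nominally the source id)
            # into a regular column so that it can be selected and persisted.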

            goodSrc = self.source_selector.selectSources(df)

            table = df[persist_columns][goodSrc.selected].to_records()

            # Append columns that include the row in the source table
            # and the matched object index (to be filled later).
            table = np.lib.recfunctions.append_fields(table,
                                                      ['source_row',
                                                       'obj_index'],
                                                      [np.where(goodSrc.selected)[0],
                                                       np.zeros(goodSrc.selected.sum(), dtype=np.int32)],
                                                      dtypes=['i4', 'i4'],
                                                      usemask=False)

            # We cut to the outer tract polygon to ensure consistent matching
            # from tract to tract.
            tract_use = poly.contains(np.deg2rad(table[self.config.ra_column]),
                                      np.deg2rad(table[self.config.dec_column]))

            tables.append(table[tract_use])

        # Combine tables
        star_source_cat = np.concatenate(tables)

        return star_source_cat

    def _get_source_table_visit_column_names(self):
        """Get the list of sourceTable_visit columns from the config.

        Returns
        -------
        all_columns : `list` [`str`]
            All columns to read.
        persist_columns : `list` [`str`]
            Columns to persist (excluding selection columns).
        """
        columns = [self.config.id_column,
                   'visit', 'detector',
                   self.config.ra_column, self.config.dec_column,
                   self.config.physical_filter_column, self.config.band_column,
                   self.config.inst_flux_field, self.config.inst_flux_field + 'Err']
        columns.extend(self.config.extra_columns)

        all_columns = columns.copy()
        if self.source_selector.config.doFlags:
            all_columns.extend(self.source_selector.config.flags.bad)
        if self.source_selector.config.doUnresolved:
            all_columns.append(self.source_selector.config.unresolved.name)
        if self.source_selector.config.doIsolated:
            all_columns.append(self.source_selector.config.isolated.parentName)
            all_columns.append(self.source_selector.config.isolated.nChildName)
        if self.source_selector.config.doRequirePrimary:
            all_columns.append(self.source_selector.config.requirePrimary.primaryColName)

        return all_columns, columns

    def _match_primary_stars(self, primary_bands, star_source_cat):
        """Match primary stars.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            Ordered list of primary bands.
        star_source_cat : `np.ndarray`
            Catalog of star sources.

        Returns
        -------
        primary_star_cat : `np.ndarray`
            Catalog of primary star positions.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        dtype = self._get_primary_dtype(primary_bands)

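        # Bands are processed in the configured priority order; a star found
        # in an earlier band takes precedence over later-band detections of
        # the same star.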

        primary_star_cat = None
        for primary_band in primary_bands:
            use = (star_source_cat['band'] == primary_band)

            ra = star_source_cat[ra_col][use]
            dec = star_source_cat[dec_col][use]

            with Matcher(ra, dec) as matcher:
                try:
                    # New smatch API
                    idx = matcher.query_groups(self.config.match_radius/3600., min_match=1)
                except AttributeError:
                    # Old smatch API
                    idx = matcher.query_self(self.config.match_radius/3600., min_match=1)

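            # Each entry of idx is a group: the indices (into ra/dec) of this
            # band's sources that lie within match_radius of each other.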

            count = len(idx)

            if count == 0:
                self.log.info('Found 0 primary stars in %s band.', primary_band)
                continue

            band_cat = np.zeros(count, dtype=dtype)
            band_cat['primary_band'] = primary_band

            # If the tract crosses ra=0 (that is, it has both low ra and high ra)
            # then we need to remap all ra values from [0, 360) to [-180, 180)
            # before doing any position averaging.
            remapped = False
            if ra.min() < 60.0 and ra.max() > 300.0:
                ra_temp = (ra + 180.0) % 360. - 180.
                remapped = True
            else:
                ra_temp = ra

            # Compute mean position for each primary star
            for i, row in enumerate(idx):
                row = np.array(row)
                band_cat[ra_col][i] = np.mean(ra_temp[row])
                band_cat[dec_col][i] = np.mean(dec[row])

            if remapped:
                # Remap ra back to [0, 360)
                band_cat[ra_col] %= 360.0

            # Match to previous band catalog(s), and remove duplicates.
            if primary_star_cat is None or len(primary_star_cat) == 0:
                primary_star_cat = band_cat
            else:
                with Matcher(band_cat[ra_col], band_cat[dec_col]) as matcher:
                    idx = matcher.query_radius(primary_star_cat[ra_col],
                                               primary_star_cat[dec_col],
                                               self.config.match_radius/3600.)
                # Any object with a match should be removed.
                match_indices = np.array([i for i in range(len(idx)) if len(idx[i]) > 0])
                if len(match_indices) > 0:
                    band_cat = np.delete(band_cat, match_indices)

                primary_star_cat = np.append(primary_star_cat, band_cat)
            self.log.info('Found %d primary stars in %s band.', len(band_cat), primary_band)

        # If everything was cut, we still want the correct datatype.
        if primary_star_cat is None:
            primary_star_cat = np.zeros(0, dtype=dtype)

        return primary_star_cat

    def _remove_neighbors(self, primary_star_cat):
        """Remove neighbors from the primary star catalog.

        Parameters
        ----------
        primary_star_cat : `np.ndarray`
            Primary star catalog.

        Returns
        -------
        primary_star_cat_cut : `np.ndarray`
            Primary star cat with neighbors removed.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            # By setting min_match=2 objects that only match to themselves
            # will not be recorded.
            try:
                # New smatch API
                idx = matcher.query_groups(self.config.isolation_radius/3600., min_match=2)
            except AttributeError:
                # Old smatch API
                idx = matcher.query_self(self.config.isolation_radius/3600., min_match=2)

        try:
            neighbor_indices = np.concatenate(idx)
        except ValueError:
            neighbor_indices = np.zeros(0, dtype=int)
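            # np.concatenate raises ValueError when idx is an empty list,
            # i.e. when no object has a neighbor within the isolation radius.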

        if len(neighbor_indices) > 0:
            neighbored = np.unique(neighbor_indices)
            self.log.info('Cutting %d objects with close neighbors.', len(neighbored))
            primary_star_cat = np.delete(primary_star_cat, neighbored)

        return primary_star_cat

    def _match_sources(self, bands, star_source_cat, primary_star_cat):
        """Match individual sources to primary stars.

        Parameters
        ----------
        bands : `list` [`str`]
            List of bands.
        star_source_cat : `np.ndarray`
            Array of star sources.
        primary_star_cat : `np.ndarray`
            Array of primary stars.

        Returns
        -------
        star_source_cat_sorted : `np.ndarray`
            Sorted and cropped array of star sources.
        primary_star_cat : `np.ndarray`
            Catalog of isolated stars, with indexes into star_source_cat_sorted.
        """
        ra_col = self.config.ra_column
        dec_col = self.config.dec_column

        # We match sources per-band because it allows us to have sorted
        # sources for easy retrieval of per-band matches.
        n_source_per_band_per_obj = np.zeros((len(bands),
                                              len(primary_star_cat)),
                                             dtype=np.int32)
        band_uses = []
        idxs = []
        with Matcher(primary_star_cat[ra_col], primary_star_cat[dec_col]) as matcher:
            for b, band in enumerate(bands):
                band_use, = np.where(star_source_cat['band'] == band)

                idx = matcher.query_radius(star_source_cat[ra_col][band_use],
                                           star_source_cat[dec_col][band_use],
                                           self.config.match_radius/3600.)
                n_source_per_band_per_obj[b, :] = np.array([len(row) for row in idx])
                idxs.append(idx)
                band_uses.append(band_use)

        n_source_per_obj = np.sum(n_source_per_band_per_obj, axis=0)
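        # source_cat_index will hold the starting row of each star's contiguous
        # block of sources in the sorted source catalog; nsource is the length
        # of that block.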

        primary_star_cat['nsource'] = n_source_per_obj
        primary_star_cat['source_cat_index'][1:] = np.cumsum(n_source_per_obj)[:-1]

        n_tot_source = primary_star_cat['source_cat_index'][-1] + primary_star_cat['nsource'][-1]

        # Temporary arrays until we crop/sort the source catalog
        source_index = np.zeros(n_tot_source, dtype=np.int32)
        obj_index = np.zeros(n_tot_source, dtype=np.int32)
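        # Pack the matched sources star by star; within each star, sources are
        # grouped by band in the configured band order.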

        ctr = 0
        for i in range(len(primary_star_cat)):
            obj_index[ctr: ctr + n_source_per_obj[i]] = i
            for b in range(len(bands)):
                source_index[ctr: ctr + n_source_per_band_per_obj[b, i]] = band_uses[b][idxs[b][i]]
                ctr += n_source_per_band_per_obj[b, i]

        source_cat_index_band_offset = np.cumsum(n_source_per_band_per_obj, axis=0)

        for b, band in enumerate(bands):
            primary_star_cat[f'nsource_{band}'] = n_source_per_band_per_obj[b, :]
            if b == 0:
                # The first band listed shares the overall source_cat_index.
                primary_star_cat[f'source_cat_index_{band}'] = primary_star_cat['source_cat_index']
            else:
                # Other band indices are offset from the previous band
                primary_star_cat[f'source_cat_index_{band}'] = (primary_star_cat['source_cat_index']
                                                                + source_cat_index_band_offset[b - 1, :])

        star_source_cat = star_source_cat[source_index]
        star_source_cat['obj_index'] = obj_index

        return star_source_cat, primary_star_cat

    def _compute_unique_ids(self, skymap, tract, nstar):
        """Compute unique star ids.

        This is a simple hash of the tract and star to provide an
        id that is unique for a given processing.

        Parameters
        ----------
        skymap : `lsst.skymap.SkyMap`
            Skymap object.
        tract : `int`
            Tract id number.
        nstar : `int`
            Number of stars.

        Returns
        -------
        ids : `np.ndarray`
            Array of unique star ids.
        """
        # The end of the id will be big enough to hold the tract number
        mult = 10**(int(np.log10(len(skymap))) + 1)
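        # For example, a skymap with ~20000 tracts gives mult = 10**5, so
        # star number k (1-based) in tract t gets id k*100000 + t.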

        return (np.arange(nstar) + 1)*mult + tract

    def _get_primary_dtype(self, primary_bands):
        """Get the numpy datatype for the primary star catalog.

        Parameters
        ----------
        primary_bands : `list` [`str`]
            List of primary bands.

        Returns
        -------
        dtype : `numpy.dtype`
            Datatype of the primary catalog.
        """
        max_len = max([len(primary_band) for primary_band in primary_bands])
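        # The 'U{max_len}' string field below is sized to the longest band name.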

        dtype = [('isolated_star_id', 'i8'),
                 (self.config.ra_column, 'f8'),
                 (self.config.dec_column, 'f8'),
                 ('primary_band', f'U{max_len}'),
                 ('source_cat_index', 'i4'),
                 ('nsource', 'i4')]

        for band in primary_bands:
            dtype.append((f'source_cat_index_{band}', 'i4'))
            dtype.append((f'nsource_{band}', 'i4'))

        return dtype
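
# Minimal usage sketch (illustrative only; not part of the module). The task
# is normally driven by the pipeline middleware through runQuantum; a direct
# call to run() assumes you already have a skymap, a tract id, and a dict
# mapping visit -> deferred-load handles for sourceTable_visit:
#
#     config = IsolatedStarAssociationConfig()
#     task = IsolatedStarAssociationTask(config=config)
#     struct = task.run(skymap, tract, source_table_ref_dict)
#     star_cat = struct.star_cat                # isolated star positions
#     star_sources = struct.star_source_cat     # per-source rows, grouped by star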