Coverage for python/lsst/source/injection/utils/consolidate_injected_deepCoadd_catalogs.py: 21%

143 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-23 02:56 -0700

1# This file is part of source_injection. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Names exported by ``from ... import *``: the pipeline connections, config
# and task classes, plus the stand-alone consolidation function.
__all__ = [
    "ConsolidateInjectedCatalogsConnections",
    "ConsolidateInjectedCatalogsConfig",
    "ConsolidateInjectedCatalogsTask",
    "consolidate_injected_deepCoadd_catalogs",
]

30 

31import numpy as np 

32from astropy.table import Table, vstack 

33from astropy.table.column import MaskedColumn 

34from lsst.geom import Box2D, SpherePoint, degrees 

35from lsst.pex.config import Field 

36from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

37from lsst.pipe.base.connections import InputQuantizedConnection 

38from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput 

39from lsst.skymap import BaseSkyMap 

40from smatch.matcher import Matcher # type: ignore [import-not-found] 

41 

42 

class ConsolidateInjectedCatalogsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "skymap", "tract"),
    defaultTemplates={"injected_prefix": "injected_"},
):
    """Butler connections for consolidating the per-patch, per-band injected
    catalogs of one tract into a single per-tract catalog.

    All dataset type names are prefixed with the ``injected_prefix`` template
    (default ``"injected_"``).
    """

    # One or more per-patch, per-band injected catalogs (``multiple=True``,
    # at least one required).
    input_catalogs = PrerequisiteInput(
        name="{injected_prefix}deepCoadd_catalog",
        doc="Per-patch and per-band injected catalogs to draw inputs from.",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch", "band"),
        multiple=True,
        minimum=1,
    )

    # Sky map; the task uses it to build tract information and to look up
    # which tract each source falls in.
    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    # The consolidated multiband catalog for the whole tract.
    output_catalog = Output(
        name="{injected_prefix}deepCoadd_catalog_tract",
        doc="Per-tract multiband catalog of injected sources.",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
    )

72 

73 

class ConsolidateInjectedCatalogsConfig(  # type: ignore [call-arg]
    PipelineTaskConfig, pipelineConnections=ConsolidateInjectedCatalogsConnections
):
    """Configuration for `ConsolidateInjectedCatalogsTask`.

    Collects the input column names, the names of the flag columns written to
    the output catalog, and options controlling how per-band catalogs are
    combined.
    """

    # --- Input catalog column names -------------------------------------
    col_ra = Field[str](
        default="ra",
        doc="Column name for right ascension (in degrees).",
    )
    col_dec = Field[str](
        default="dec",
        doc="Column name for declination (in degrees).",
    )
    col_mag = Field[str](
        default="mag",
        doc="Column name for magnitude.",
    )
    # NOTE(review): not read by ConsolidateInjectedCatalogsTask.runQuantum;
    # confirm whether anything else consumes it.
    col_source_type = Field[str](
        default="source_type",
        doc="Column name for the source type used in the input catalog. Must match one of the surface "
        "brightness profiles defined by GalSim. For more information see the Galsim docs at "
        "https://galsim-developers.github.io/GalSim/_build/html/sb.html",
    )

    # --- Flag column names --------------------------------------------
    # Sources whose value in this column equals 0 are eligible to be primary.
    injectionKey = Field[str](
        default="injection_flag",
        doc="True if the source was successfully injected.",
    )
    isPatchInnerKey = Field[str](
        default="injected_isPatchInner",
        doc="True if source is in the inner region of a coadd patch.",
    )
    isTractInnerKey = Field[str](
        default="injected_isTractInner",
        doc="True if source is in the inner region of a coadd tract.",
    )
    isPrimaryKey = Field[str](
        default="injected_isPrimary",
        doc="True if the source was successfully injected and is in both the inner region of a coadd patch "
        "and tract.",
    )

    # --- Behavior options ---------------------------------------------
    # NOTE(review): not read by ConsolidateInjectedCatalogsTask.runQuantum;
    # confirm whether anything else consumes it.
    remove_patch_overlap_duplicates = Field[bool](
        default=False,
        doc="Optional parameter to remove patch overlap duplicate sources.",
    )
    # When False, patch numbers are recomputed from source positions.
    get_catalogs_from_butler = Field[bool](
        default=True,
        doc="Optional parameter to specify whether or not the input catalogs are loaded with a Butler.",
    )
    # Expressed in pixels; converted to degrees with the tract pixel scale.
    pixel_match_radius = Field[float](
        default=0.1,
        doc="Radius for matching catalogs across different bands.",
    )

126 

127 

def _get_catalogs(
    inputs: dict,
    input_refs: InputQuantizedConnection,
) -> tuple[dict, int]:
    """Group the input catalogs by photometric band and stack each group.

    Every input catalog is tagged in place with a ``patch`` column holding
    the patch ID from its dataset reference before stacking.

    Parameters
    ----------
    inputs : `dict`
        Input datasets; must contain an ``"input_catalogs"`` entry whose
        items pair one-to-one with ``input_refs.input_catalogs``.
    input_refs : `lsst.pipe.base.connections.InputQuantizedConnection`
        The input dataset references used by the butler.

    Returns
    -------
    catalog_dict : `dict` [`str`, `astropy.table.Table`]
        One vertically stacked table per band name.
    tract : `int`
        The ID of the single tract covering every input catalog.

    Raises
    ------
    RuntimeError
        Raised if the inputs do not come from exactly one tract.
    """
    grouped: dict[str, list] = {}
    tracts = set()
    for ref, catalog in zip(input_refs.input_catalogs, inputs["input_catalogs"]):
        data_id = ref.dataId
        tables_for_band = grouped.setdefault(data_id.band.name, [])
        # Record the patch so patch-overlap duplicates can be identified later.
        catalog["patch"] = data_id.patch.id
        tables_for_band.append(catalog)
        tracts.add(data_id.tract.id)

    catalog_dict = {band: vstack(tables) for band, tables in grouped.items()}

    # Only catalogs from a single tract may be consolidated together.
    if len(tracts) != 1:
        raise RuntimeError(f"len({tracts=}) != 1")
    (tract,) = tracts
    return (catalog_dict, tract)

169 

170 

def _get_patches(
    catalog_dict: dict,
    tractInfo,
    col_ra: str = "ra",
    col_dec: str = "dec",
):
    """Write the sequential patch index of every row into a ``patch`` column.

    The tables in ``catalog_dict`` are modified in place. A ``patch`` column
    is created when missing, and every row's value is then set from the
    patch of ``tractInfo`` that contains the row's position.

    Parameters
    ----------
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and astropy tables for
        items.
    tractInfo : `lsst.skymap.tractInfo.ExplicitTractInfo`
        Information for a tract specified explicitly.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    """
    for table in catalog_dict.values():
        if "patch" not in table.colnames:
            # Placeholder; overwritten for every row below.
            table.add_column(col=-9999, name="patch")
        for row in table:
            position = SpherePoint(row[col_ra], row[col_dec], degrees)
            row["patch"] = int(tractInfo.findPatch(position).getSequentialIndex())

198 

199 

def getPatchInner(
    catalog: Table,
    patchInfo,
    col_ra: str = "ra",
    col_dec: str = "dec",
):
    """Flag sources whose centroid lies inside a patch's inner bounding box.

    Parameters
    ----------
    catalog : `astropy.table.Table`
        A catalog of sources.
    patchInfo : `lsst.skymap.PatchInfo`
        Information about a `SkyMap` `Patch`.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).

    Returns
    -------
    isPatchInner : array-like of `bool`
        `True` for each source whose pixel position, in the patch WCS, falls
        inside the patch's inner bounding box.
    """
    sky_ra = catalog[col_ra]
    sky_dec = catalog[col_dec]

    # Project the sky positions (degrees) onto the patch pixel grid.
    patch_wcs = patchInfo.getWcs()
    pix_x, pix_y = patch_wcs.skyToPixelArray(sky_ra, sky_dec, degrees=True)

    # Floating-point box so that sub-pixel positions are tested directly.
    inner_box = Box2D(patchInfo.getInnerBBox())
    return inner_box.contains(pix_x, pix_y)

238 

239 

def getTractInner(
    catalog,
    tractInfo,
    skyMap,
    col_ra: str = "ra",
    col_dec: str = "dec",
):
    """Flag sources that the sky map assigns to the given tract.

    Parameters
    ----------
    catalog : `astropy.table.Table`
        A catalog of sources.
    tractInfo : `lsst.skymap.TractInfo`
        Tract object.
    skyMap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).

    Returns
    -------
    isTractInner : `numpy.ndarray` of `bool`
        `True` where ``skyMap.findTract`` returns the same tract ID as
        ``tractInfo``. The array always has a boolean dtype, including for an
        empty catalog, so it can be combined with other flags using ``&``.
    """
    ras, decs = catalog[col_ra].value, catalog[col_dec].value
    tractId = tractInfo.getId()
    # dtype=bool is explicit: np.array([]) would otherwise be float64 and make
    # bitwise combination with boolean flags fail for empty catalogs.
    isTractInner = np.array(
        [skyMap.findTract(SpherePoint(ra, dec, degrees)).getId() == tractId for ra, dec in zip(ras, decs)],
        dtype=bool,
    )
    return isTractInner

273 

274 

def setPrimaryFlags(
    catalog,
    skyMap,
    tractInfo,
    patches: list,
    col_ra: str = "ra",
    col_dec: str = "dec",
    isPatchInnerKey: str = "injected_isPatchInner",
    isTractInnerKey: str = "injected_isTractInner",
    isPrimaryKey: str = "injected_isPrimary",
    injectionKey: str = "injection_flag",
):
    """Set isPrimary and related flags on injected sources.

    A source is flagged as primary when all three of these hold:

    - it is in the inner region of its coadd patch;
    - it is in the inner region of the tract;
    - its injection flag equals 0.

    Parameters
    ----------
    catalog : `astropy.table.Table`
        A catalog of sources with ``col_ra``, ``col_dec``, ``patch`` and
        ``injectionKey`` columns. Modified in place: the is-patch-inner,
        is-tract-inner and is-primary flag columns are written.
    skyMap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object.
    tractInfo : `lsst.skymap.TractInfo`
        Tract object.
    patches : `list`
        Patch indices present in ``catalog["patch"]``.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    isPatchInnerKey : `str`
        Column name for the isPatchInner flag.
    isTractInnerKey : `str`
        Column name for the isTractInner flag.
    isPrimaryKey : `str`
        Column name for the isPrimary flag.
    injectionKey : `str`
        Column name for the injection flag.
    """
    # Boolean dtype is explicit so the flags stay boolean even for an empty
    # catalog (np.array([]) would be float64 and break the ``&`` below).
    isPatchInner = np.zeros(len(catalog), dtype=bool)
    for patch in patches:
        patchMask = catalog["patch"] == patch
        patchInfo = tractInfo.getPatchInfo(patch)
        isPatchInner[patchMask] = getPatchInner(catalog[patchMask], patchInfo, col_ra, col_dec)
    # Forward the column names so non-default coordinate columns work.
    isTractInner = getTractInner(catalog, tractInfo, skyMap, col_ra=col_ra, col_dec=col_dec)
    isPrimary = isTractInner & isPatchInner & (catalog[injectionKey] == 0)

    catalog[isPatchInnerKey] = isPatchInner
    catalog[isTractInnerKey] = isTractInner
    catalog[isPrimaryKey] = isPrimary

333 

334 

def _fill_masked_with_nan(table: Table) -> Table:
    """Replace the masked entries of every `MaskedColumn` in ``table`` with NaN."""
    if table.has_masked_columns:
        for name in list(table.colnames):
            if isinstance(table[name], MaskedColumn):
                table[name] = table[name].filled(np.nan)
    return table


def _make_multiband_catalog(
    bands: list,
    catalog_dict: dict,
    match_radius: float,
    col_ra: str = "ra",
    col_dec: str = "dec",
    col_mag: str = "mag",
) -> Table:
    """Combine multiple band-specific catalogs into one multiband catalog.

    The catalog for ``bands[0]`` seeds the output. Each later band is matched
    positionally against the running multiband catalog:

    - matched rows receive that band's magnitude, and adopt the new band's
      coordinates when it is strictly brighter than every magnitude recorded
      for the row so far;
    - unmatched sources of the new band are appended as new rows.

    Magnitudes for bands in which a row has no entry are NaN.

    Parameters
    ----------
    bands : `list`
        A list of string photometry bands.
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and astropy tables for
        items. The tables are copied, not modified.
    match_radius : `float`
        Matching radius passed to ``Matcher.query_radius``;
        `consolidate_injected_deepCoadd_catalogs` supplies it in degrees.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    col_mag : `str`
        Column name for magnitude.

    Returns
    -------
    multiband_catalog : `astropy.table.Table`
        A catalog with one ``{band}_{col_mag}`` column per entry of ``bands``.
    """
    # Track the per-band magnitude columns explicitly rather than matching
    # column names by substring, which could pick up unrelated columns.
    mag_cols = [f"{band}_{col_mag}" for band in bands]

    # Seed with the first band, renaming its magnitude column.
    multiband_catalog = catalog_dict[bands[0]].copy()
    multiband_catalog.add_column(col=multiband_catalog[col_mag], name=mag_cols[0])
    multiband_catalog.remove_column(col_mag)
    multiband_catalog = _fill_masked_with_nan(multiband_catalog)

    for n_previous, band in enumerate(bands[1:], start=1):
        band_mag_col = mag_cols[n_previous]
        previous_mag_cols = mag_cols[:n_previous]

        # Make a column for the new band.
        multiband_catalog.add_column([np.nan] * len(multiband_catalog), name=band_mag_col)
        catalog_next_band = catalog_dict[band].copy()
        catalog_next_band.rename_column(col_mag, band_mag_col)

        # Match this band's catalog to the existing multiband catalog.
        with Matcher(multiband_catalog[col_ra], multiband_catalog[col_dec]) as m:
            _, multiband_match_inds, next_band_match_inds, _ = m.query_radius(
                catalog_next_band[col_ra],
                catalog_next_band[col_dec],
                match_radius,
                return_indices=True,
            )

        if len(multiband_match_inds) > 0 and len(next_band_match_inds) > 0:
            # Brightest magnitude recorded so far for each matched row.
            # np.fmin ignores NaN; Python's min() returns NaN whenever the
            # first element is NaN (e.g. a row absent from the first band).
            previous_mags = np.column_stack(
                [np.asarray(multiband_catalog[col], dtype=float) for col in previous_mag_cols]
            )
            bright_mags = np.fmin.reduce(previous_mags[multiband_match_inds], axis=1)
            next_mags = np.asarray(catalog_next_band[band_mag_col], dtype=float)

            # Use the coordinates from the brightest band.
            # TODO: Once multicomponent object support is added, make some
            # logic to pick the correct source_type.
            for i, j, bright_mag in zip(multiband_match_inds, next_band_match_inds, bright_mags):
                if next_mags[j] < bright_mag:
                    multiband_catalog[col_ra][i] = catalog_next_band[col_ra][j]
                    multiband_catalog[col_dec][i] = catalog_next_band[col_dec][j]

            # Fill the new magnitude for matched rows.
            multiband_catalog[band_mag_col][multiband_match_inds] = catalog_next_band[band_mag_col][
                next_band_match_inds
            ]

            # Append the sources of this band that had no match.
            unmatched = np.ones(len(catalog_next_band), dtype=bool)
            unmatched[next_band_match_inds] = False
            multiband_catalog = vstack([multiband_catalog, catalog_next_band[unmatched]])
        else:
            # No matches: just stack the tables.
            multiband_catalog = vstack([multiband_catalog, catalog_next_band])

        # vstack masks columns missing from one side; turn them into NaNs so
        # later magnitude comparisons never see masked entries.
        multiband_catalog = _fill_masked_with_nan(multiband_catalog)

    return multiband_catalog

417 

418 

def consolidate_injected_deepCoadd_catalogs(
    catalog_dict: dict,
    skymap: BaseSkyMap,
    tract: int,
    pixel_match_radius: float = 0.1,
    get_catalogs_from_butler: bool = True,
    col_ra: str = "ra",
    col_dec: str = "dec",
    col_mag: str = "mag",
    isPatchInnerKey: str = "injected_isPatchInner",
    isTractInnerKey: str = "injected_isTractInner",
    isPrimaryKey: str = "injected_isPrimary",
    injectionKey: str = "injection_flag",
) -> Table:
    """Consolidate all tables in catalog_dict into one table.

    The processing steps are:

    1. Combine the per-band catalogs (cross-matching when there is more than
       one band).
    2. Drop sources outside the tract.
    3. Set the patch-inner, tract-inner and primary flags.
    4. Add an ``injected_id`` column.
    5. Return the table with a fixed column order.

    Parameters
    ----------
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and astropy tables for
        items. Each table must have a ``patch`` column unless
        ``get_catalogs_from_butler`` is False. In that case a ``patch``
        column is written into these tables in place.
    skymap : `lsst.skymap.BaseSkyMap`
        A base skymap.
    tract : `int`
        The tract where sources have been injected.
    pixel_match_radius : `float`
        Match radius in pixels to use for self-matching catalogs across
        different bands.
    get_catalogs_from_butler : `bool`
        Optional parameter to specify whether or not the input catalogs are
        loaded with a Butler.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    col_mag : `str`
        Column name for magnitude.
    isPatchInnerKey : `str`
        Column name for the isPatchInner flag.
    isTractInnerKey : `str`
        Column name for the isTractInner flag.
    isPrimaryKey : `str`
        Column name for the isPrimary flag.
    injectionKey : `str`
        Column name for the injection flag.

    Returns
    -------
    multiband_catalog : `astropy.table.Table`
        A single table containing all information of the separate
        tables in catalog_dict.

    Raises
    ------
    ValueError
        Raised if ``catalog_dict`` is empty.
    """
    bands = list(catalog_dict.keys())
    if not bands:
        raise ValueError("catalog_dict is empty; there are no catalogs to consolidate.")

    tractInfo = skymap.generateTract(tract)
    # If patch numbers are not loaded via dataIds from the butler, derive them
    # from source positions.
    if not get_catalogs_from_butler:
        _get_patches(catalog_dict, tractInfo, col_ra=col_ra, col_dec=col_dec)

    # Convert the pixel match radius to degrees.
    pixel_scale = tractInfo.getWcs().getPixelScale()
    match_radius = pixel_match_radius * pixel_scale.asDegrees()

    if len(bands) > 1:
        # Match the catalogs across bands.
        output_catalog = _make_multiband_catalog(
            bands,
            catalog_dict,
            match_radius,
            col_ra=col_ra,
            col_dec=col_dec,
            col_mag=col_mag,
        )
    else:
        # Copy so the caller's table is not modified by the steps below
        # (the multiband path already works on copies).
        output_catalog = catalog_dict[bands[0]].copy()
        output_catalog.rename_column(col_mag, f"{bands[0]}_{col_mag}")

    # Remove sources outside tract boundaries.
    out_of_tract_bounds = [
        index
        for index, (ra, dec) in enumerate(zip(output_catalog[col_ra], output_catalog[col_dec]))
        if not tractInfo.contains(SpherePoint(ra * degrees, dec * degrees))
    ]
    output_catalog.remove_rows(out_of_tract_bounds)

    # Assign flags.
    patches = list(set(output_catalog["patch"]))
    setPrimaryFlags(
        catalog=output_catalog,
        skyMap=skymap,
        tractInfo=tractInfo,
        patches=patches,
        col_ra=col_ra,
        col_dec=col_dec,
        isPatchInnerKey=isPatchInnerKey,
        isTractInnerKey=isTractInnerKey,
        isPrimaryKey=isPrimaryKey,
        injectionKey=injectionKey,
    )

    # Add a new injected_id column.
    output_catalog.add_column(col=list(range(len(output_catalog))), name="injected_id")
    # Remove unnecessary output columns.
    output_catalog.remove_column("patch")

    # Reorder columns, honoring any non-default column/flag names.
    mag_cols = [f"{band}_{col_mag}" for band in bands]
    new_order = [
        "injected_id",
        col_ra,
        col_dec,
        "source_type",
        *mag_cols,
        "injection_id",
        "injection_draw_size",
        injectionKey,
        isPatchInnerKey,
        isTractInnerKey,
        isPrimaryKey,
    ]
    return output_catalog[new_order]

530 

531 

class ConsolidateInjectedCatalogsTask(PipelineTask):
    """Class for combining all tables in a collection of input catalogs
    into one table.

    The work is delegated to `consolidate_injected_deepCoadd_catalogs`; this
    task gathers the butler inputs and forwards its configuration.
    """

    _DefaultName = "consolidateInjectedCatalogsTask"
    ConfigClass = ConsolidateInjectedCatalogsConfig

    def runQuantum(self, butlerQC, input_refs, output_refs):
        """Read the quantum's inputs, run the task, and write the output.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
            Butler interface for this quantum.
        input_refs : `lsst.pipe.base.connections.InputQuantizedConnection`
            Dataset references for the inputs.
        output_refs : `lsst.pipe.base.connections.OutputQuantizedConnection`
            Dataset references for the outputs.
        """
        inputs = butlerQC.get(input_refs)
        # Group catalogs by band, tagging rows with their patch ID.
        catalog_dict, tract = _get_catalogs(inputs, input_refs)
        outputs = self.run(
            catalog_dict,
            inputs["skyMap"],
            tract,
            self.config.pixel_match_radius,
            self.config.get_catalogs_from_butler,
            self.config.col_ra,
            self.config.col_dec,
            self.config.col_mag,
            self.config.isPatchInnerKey,
            self.config.isTractInnerKey,
            self.config.isPrimaryKey,
            self.config.injectionKey,
        )
        butlerQC.put(outputs, output_refs)

    def run(
        self,
        catalog_dict: dict,
        skymap: BaseSkyMap,
        tract: int,
        pixel_match_radius: float = 0.1,
        get_catalogs_from_butler: bool = True,
        col_ra: str = "ra",
        col_dec: str = "dec",
        col_mag: str = "mag",
        isPatchInnerKey: str = "injected_isPatchInner",
        isTractInnerKey: str = "injected_isTractInner",
        isPrimaryKey: str = "injected_isPrimary",
        injectionKey: str = "injection_flag",
    ) -> Struct:
        """Consolidate all tables in catalog_dict into one table.

        Parameters
        ----------
        catalog_dict : `dict`
            A dictionary with photometric bands for keys and astropy tables
            for items.
        skymap : `lsst.skymap.BaseSkyMap`
            A base skymap.
        tract : `int`
            The tract where sources have been injected.
        pixel_match_radius : `float`
            Match radius in pixels to use for self-matching catalogs across
            different bands.
        get_catalogs_from_butler : `bool`
            Whether the input catalogs were loaded with a Butler (and so
            already carry a ``patch`` column). If False, patch numbers are
            derived from source positions.
        col_ra : `str`
            Column name for right ascension (in degrees).
        col_dec : `str`
            Column name for declination (in degrees).
        col_mag : `str`
            Column name for magnitude.
        isPatchInnerKey : `str`
            Column name for the isPatchInner flag.
        isTractInnerKey : `str`
            Column name for the isTractInner flag.
        isPrimaryKey : `str`
            Column name for the isPrimary flag.
        injectionKey : `str`
            Column name for the injection flag.

        Returns
        -------
        output_struct : `lsst.pipe.base.Struct`
            Contains:

            ``output_catalog``
                A single `astropy.table.Table` containing all information of
                the separate tables in catalog_dict.
        """
        output_catalog = consolidate_injected_deepCoadd_catalogs(
            catalog_dict=catalog_dict,
            skymap=skymap,
            tract=tract,
            pixel_match_radius=pixel_match_radius,
            get_catalogs_from_butler=get_catalogs_from_butler,
            col_ra=col_ra,
            col_dec=col_dec,
            col_mag=col_mag,
            isPatchInnerKey=isPatchInnerKey,
            isTractInnerKey=isTractInnerKey,
            isPrimaryKey=isPrimaryKey,
            injectionKey=injectionKey,
        )
        return Struct(output_catalog=output_catalog)