Coverage for python / lsst / source / injection / utils / _consolidate_injected_coadd_catalogs.py: 11%

289 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 09:23 +0000

1# This file is part of source_injection. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public classes exported by this module: the connections, configuration and
# pipeline task that consolidate injected catalogs.
__all__ = [
    "ConsolidateInjectedCatalogsConnections",
    "ConsolidateInjectedCatalogsConfig",
    "ConsolidateInjectedCatalogsTask",
]

29 

30from collections import defaultdict 

31 

32import astropy.table 

33import astropy.units as u 

34import numpy as np 

35from astropy.table import Table, join, vstack 

36from astropy.table.column import MaskedColumn 

37from smatch.matcher import Matcher # type: ignore [import-not-found] 

38 

39from lsst.daf.butler import DatasetProvenance 

40from lsst.geom import Box2D, SpherePoint, degrees 

41from lsst.pex.config import Field, ListField 

42from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct 

43from lsst.pipe.base.connections import InputQuantizedConnection 

44from lsst.pipe.base.connectionTypes import Input, Output 

45from lsst.skymap import BaseSkyMap 

46 

47 

class ConsolidateInjectedCatalogsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "skymap", "tract"),
    defaultTemplates={
        "injected_prefix": "injected_",
    },
):
    """Connections for `ConsolidateInjectedCatalogsTask`.

    Reads every per-patch, per-band injected catalog of one tract together
    with the skymap, and writes a single per-tract multiband catalog.
    """

    # One catalog per (patch, band) data ID; at least one must exist.
    input_catalogs = Input(
        doc="Per-patch and per-band injected catalogs to draw inputs from.",
        name="{injected_prefix}deep_coadd_predetection_catalog",
        dimensions=("skymap", "tract", "patch", "band"),
        storageClass="ArrowAstropy",
        minimum=1,
        multiple=True,
    )
    skyMap = Input(
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    output_catalog = Output(
        doc="Per-tract multiband catalog of injected sources.",
        name="{injected_prefix}deep_coadd_predetection_catalog_tract",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
    )

77 

78 

def _get_catalogs(
    inputs: dict,
    input_refs: InputQuantizedConnection,
    skymap: BaseSkyMap,
    col_ra: str = "ra",
    col_dec: str = "dec",
    include_outer: bool = True,
) -> tuple[dict, int]:
    """Group the input catalogs by photometric band and stack each group.

    Every input catalog gets a ``patch`` column holding the patch ID from its
    data ID, and has its provenance metadata stripped. Per band, the patch
    catalogs are then stacked in ascending patch order.

    Parameters
    ----------
    inputs: `dict`
        A dictionary containing the input datasets.
    input_refs: `lsst.pipe.base.connections.InputQuantizedConnection`
        The input dataset references used by the butler.
    skymap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object.
    col_ra: `str`
        Column name for right ascension (in degrees).
    col_dec: `str`
        Column name for declination (in degrees).
    include_outer: `bool`
        Whether to keep objects injected into the outer (not inner) region
        of a patch.

    Returns
    -------
    `tuple[dict, int]`
        contains :
            catalog_dict: `dict`
                A dictionary with photometric bands for keys and stacked
                astropy tables for items.
            tract: `int`
                The single tract covering all the input catalogs.

    Raises
    ------
    RuntimeError
        If the inputs span more than one tract.
    """
    per_band: dict = {}
    seen_tracts = set()
    for ref, catalog in zip(input_refs.input_catalogs, inputs["input_catalogs"]):
        data_id = ref.dataId
        band_patches = per_band.setdefault(data_id.band.name, {})
        tract_id = data_id.tract.id
        tract_info = skymap[tract_id]
        # The patch ID is recorded so duplicates from patch overlaps can be
        # traced back to their origin.
        patch_id = data_id.patch.id
        if not include_outer:
            inner = getPatchInner(tract_info[patch_id], catalog[col_ra], catalog[col_dec])
            catalog = catalog[inner]
        catalog["patch"] = patch_id
        # Provenance headers would otherwise make astropy's naive metadata
        # merge (used by vstack) emit warnings.
        DatasetProvenance.strip_provenance_from_flat_dict(catalog.meta)
        band_patches[patch_id] = catalog
        seen_tracts.add(tract_id)

    if len(seen_tracts) > 1:
        raise RuntimeError(
            f"Got tract={tract_id} with tracts={seen_tracts!r}; there should only be one tract"
        )

    stacked = {
        band: vstack([band_patches[key] for key in sorted(band_patches)])
        for band, band_patches in per_band.items()
    }
    return stacked, list(seen_tracts)[0]

147 

148 

149def _get_patches( 

150 catalog_dict: dict[str, astropy.table.Table], 

151 tractInfo, 

152 col_ra: str = "ra", 

153 col_dec: str = "dec", 

154 index=None, 

155): 

156 """Create a patch column and assign each row a patch number. 

157 

158 Parameters 

159 ---------- 

160 catalog_dict: `dict` 

161 A dictionary with photometric bands for keys and astropy tables for 

162 items. 

163 tractInfo: `lsst.skymap.tractInfo.ExplicitTractInfo` 

164 Information for a tract specified explicitly. 

165 col_ra: `str` 

166 Column name for right ascension (in degrees). 

167 col_dec: `str` 

168 Column name for declination (in degrees). 

169 """ 

170 for catalog in catalog_dict.values(): 

171 if "patch" not in catalog.colnames: 

172 patches = np.empty(len(catalog), dtype=int) 

173 catalog.add_column(patches, name="patch", index=index) 

174 

175 patches = catalog["patch"] 

176 for idx, row in enumerate(catalog): 

177 coord = SpherePoint(row[col_ra], row[col_dec], degrees) 

178 patchInfo = tractInfo.findPatch(coord) 

179 patches[idx] = int(patchInfo.getSequentialIndex()) 

180 

181 

def getPatchInner(
    patchInfo,
    ra: np.ndarray,
    dec: np.ndarray,
):
    """Flag the sources whose centroid lies in a patch's inner bounding box.

    Parameters
    ----------
    patchInfo : `lsst.skymap.PatchInfo`
        Patch whose WCS and inner bounding box are used.
    ra: `np.ndarray`
        Right ascension values in degrees.
    dec: `np.ndarray`
        Declination values in degrees.

    Returns
    -------
    isPatchInner : array-like of `bool`
        `True` where the position, projected to pixels with the patch WCS,
        falls inside the patch's inner bounding box.
    """
    pixel_x, pixel_y = patchInfo.getWcs().skyToPixelArray(ra, dec, degrees=True)
    return Box2D(patchInfo.getInnerBBox()).contains(pixel_x, pixel_y)

213 

214 

def getTractInner(
    tractInfo,
    skyMap,
    ra: np.ndarray,
    dec: np.ndarray,
):
    """Flag the sources that the skymap assigns to ``tractInfo``.

    Parameters
    ----------
    tractInfo : `lsst.skymap.TractInfo`
        Tract object.
    skyMap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object.
    ra: `np.ndarray`
        Right ascension values in degrees.
    dec: `np.ndarray`
        Declination values in degrees.

    Returns
    -------
    isTractInner : `np.ndarray` of `bool`
        `True` where ``skyMap.findTract`` returns a tract with the same ID as
        ``tractInfo``.
    """
    target_id = tractInfo.getId()
    same_tract = (
        skyMap.findTract(SpherePoint(ra_deg, dec_deg, degrees)).getId() == target_id
        for ra_deg, dec_deg in zip(ra, dec)
    )
    return np.fromiter(same_tract, dtype=bool)

247 

248 

class ConsolidateInjectedCatalogsConfig(  # type: ignore [call-arg]
    PipelineTaskConfig, pipelineConnections=ConsolidateInjectedCatalogsConnections
):
    """Configuration for `ConsolidateInjectedCatalogsTask`.

    Besides the configuration fields, this class implements the merging of
    per-band injected catalogs into one multiband table
    (`consolidate_catalogs`, `make_multiband_catalog`) and the computation of
    the patch-inner, tract-inner and primary flags (`setPrimaryFlags`).
    """

    col_ra = Field[str](
        doc="Column name for right ascension (in degrees).",
        default="ra",
    )
    col_dec = Field[str](
        doc="Column name for declination (in degrees).",
        default="dec",
    )
    col_mag = Field[str](
        doc="Column name for magnitude.",
        default="mag",
    )
    col_source_type = Field[str](
        doc="Column name for the source type used in the input catalog. Must match one of the surface "
        "brightness profiles defined by GalSim. For more information see the Galsim docs at "
        "https://galsim-developers.github.io/GalSim/_build/html/sb.html",
        default="source_type",
    )
    columns_extra = ListField[str](
        doc="Extra columns to be copied from the injection catalog (e.g. for shapes)",
        default=[],
    )
    groupIdKey = Field[str](
        doc="Key for the group id column to merge sources on, if any",
        default=None,
        optional=True,
    )
    # NOTE(review): despite its doc string, the methods below treat a value of
    # 0/False in this column as a successful injection (see setPrimaryFlags).
    injectionKey = Field[str](
        doc="True if the source was successfully injected.",
        default="injection_flag",
    )
    injectionSizeKey = Field[str](
        doc="The size of the drawn injection box",
        default="injection_draw_size",
    )
    isPatchInnerKey = Field[str](
        doc="True if source is in the inner region of a coadd patch.",
        default="injected_isPatchInner",
    )
    isTractInnerKey = Field[str](
        doc="True if source is in the inner region of a coadd tract.",
        default="injected_isTractInner",
    )
    isPrimaryKey = Field[str](
        doc="True if the source was successfully injected and is in both the inner region of a coadd patch "
        "and tract.",
        default="injected_isPrimary",
    )
    # NOTE(review): not referenced by any code in this module.
    remove_patch_overlap_duplicates = Field[bool](
        doc="Optional parameter to remove patch overlap duplicate sources.",
        default=False,
    )
    get_catalogs_from_butler = Field[bool](
        doc="Optional parameter to specify whether or not the input catalogs are loaded with a Butler.",
        default=True,
    )
    # Expressed in tract pixels; consolidate_catalogs converts it to degrees.
    # A negative value disables positional matching across bands.
    pixel_match_radius = Field[float](
        doc="Radius for matching catalogs across different bands.",
        default=0.1,
    )

    def consolidate_catalogs(
        self,
        catalog_dict: dict[str, astropy.table.Table],
        skymap: BaseSkyMap,
        tract: int,
        copy_catalogs: bool = False,
    ) -> Table:
        """Consolidate all tables in catalog_dict into one table.

        The per-band catalogs are merged (see `make_multiband_catalog`),
        sources outside the tract are dropped, patches are re-assigned from
        the merged positions, inner/primary flags are set and an
        ``injected_id`` column (row number) is added.

        Parameters
        ----------
        catalog_dict: `dict`
            A dictionary with photometric bands for keys and astropy tables for
            items.
        skymap: `lsst.skymap.BaseSkyMap`
            A base skymap.
        tract: `int`
            The tract where sources have been injected.
        copy_catalogs: `bool`
            Whether to copy the input catalogs; if False, they will be modified
            in-place.

        Returns
        -------
        multiband_catalog: `astropy.table.Table`
            A single table containing all information of the separate
            tables in catalog_dict.
        """
        tractInfo = skymap.generateTract(tract)
        # If patch numbers are not loaded via dataIds from the butler, derive
        # them from the source positions using the configured coordinate
        # columns.
        if not self.get_catalogs_from_butler:
            _get_patches(catalog_dict, tractInfo, col_ra=self.col_ra, col_dec=self.col_dec)

        # Convert the pixel match radius to degrees.
        tractWcs = tractInfo.getWcs()
        pixel_scale = tractWcs.getPixelScale()
        match_radius = self.pixel_match_radius * pixel_scale.asDegrees()
        has_groups = bool(self.groupIdKey)
        bands = list(catalog_dict.keys())
        if has_groups or (len(bands) > 1):
            # Match the catalogs across bands.
            output_catalog = self.make_multiband_catalog(
                bands,
                catalog_dict,
                match_radius,
                copy_catalogs=copy_catalogs,
            )
        else:
            output_catalog = catalog_dict[bands[0]]
            # Honour copy_catalogs: the rename below would otherwise modify
            # the caller's table.
            if copy_catalogs:
                output_catalog = output_catalog.copy()
            output_catalog.rename_column(self.col_mag, f"{bands[0]}_{self.col_mag}")
        # Remove sources outside tract boundaries.
        out_of_tract_bounds = []
        for index, (ra, dec) in enumerate(zip(output_catalog[self.col_ra], output_catalog[self.col_dec])):
            point = SpherePoint(ra * degrees, dec * degrees)
            if not tractInfo.contains(point):
                out_of_tract_bounds.append(index)
        if out_of_tract_bounds:
            output_catalog.remove_rows(out_of_tract_bounds)
        # Assign flags.
        # There may be a pre-existing patch column; however, patches can shift
        # if centroids vary per band, etc.
        # TODO: Need to check if this can cause duplicate or dropped entries
        # for objects near patch boundaries
        _get_patches(
            {"x": output_catalog},
            tractInfo=tractInfo,
            col_ra=self.col_ra,
            col_dec=self.col_dec,
        )
        patches = list(set(output_catalog["patch"]))
        self.setPrimaryFlags(
            catalog=output_catalog,
            skyMap=skymap,
            tractInfo=tractInfo,
            patches=patches,
        )
        # Add a new injected_id column.
        output_catalog.add_column(col=np.arange(len(output_catalog)), name="injected_id")
        # If using a group ID, the draw_size column is
        # not preserved, and the rest are already ordered
        if self.groupIdKey:
            return output_catalog
        # Reorder columns. The flag columns are named by the config fields so
        # that non-default names do not raise a KeyError here.
        mag_cols = [col for col in output_catalog.columns if f"_{self.col_mag}" in col]

        new_order = [
            "injected_id",
            self.col_ra,
            self.col_dec,
            self.col_source_type,
            *mag_cols,
            "patch",
            "injection_id",
            self.injectionSizeKey,
            self.injectionKey,
            self.isPatchInnerKey,
            self.isTractInnerKey,
            self.isPrimaryKey,
        ]
        for column in output_catalog.columns:
            if column not in new_order:
                new_order.append(column)

        return output_catalog[new_order]

    def make_multiband_catalog(
        self,
        bands: list,
        catalog_dict: dict[str, astropy.table.Table],
        match_radius: float,
        copy_catalogs: bool = False,
    ) -> Table:
        """Combine multiple band-specific catalogs into one multiband
        catalog.

        Without ``groupIdKey``, the per-band catalogs are cross-matched by
        position within ``match_radius`` (a negative radius disables
        matching, so the catalogs are simply stacked), and each band gets a
        ``{band}_{col_mag}`` column. With ``groupIdKey``, rows sharing a group
        ID are first collapsed to one row per group and band (component
        fluxes are summed when a group has several components), and the
        per-band tables are then outer-joined on the group ID.

        Parameters
        ----------
        bands: `list`
            A list of string photometry bands; the first is the reference.
        catalog_dict: `dict`
            A dictionary with photometric bands for keys and astropy
            tables for items. With ``groupIdKey`` set, its values are
            replaced by the collapsed per-group tables.
        match_radius: `float`
            The radius for matching catalogs across bands, in degrees.
        copy_catalogs: `bool`
            Whether to copy the input catalogs; if False, they will be modified
            in-place.

        Returns
        -------
        multiband_catalog: `astropy.table.Table`
            A catalog with sources that have magnitude information across all
            bands.
        """
        col_mag = self.col_mag
        col_ra = self.col_ra
        col_dec = self.col_dec

        if self.groupIdKey:
            # Count the groups in each band and the largest number of
            # components found in any group.
            n_comps = 0
            n_rows = {}
            for band, catalog in catalog_dict.items():
                groupIds, counts = np.unique(
                    catalog[self.groupIdKey],
                    return_counts=True,
                )
                n_rows[band] = len(groupIds)
                n_comps = max(np.max(counts), n_comps)

            # Maybe validate that this is true? Why set groupId otherwise?
            is_multicomp = n_comps > 1
            prefix_comp = "comp{idx_comp}_" if is_multicomp else ""
            # Output name templates for per-component columns; {band} and
            # {idx_comp} are filled in with str.format below.
            columns_comp = {
                column_in: f"{{band}}_{prefix_comp}{column_in}"
                for column_in in [self.col_source_type, self.injectionKey] + list(self.columns_extra)
            }

            for band, catalog in catalog_dict.items():
                n_rows_b = n_rows[band]
                # Use the configured magnitude column (not a literal "mag").
                unit_mag = catalog[col_mag].unit or u.ABmag
                counts = defaultdict(int)
                idxs_new = {}

                groupIds = np.full(n_rows_b, 0, dtype=catalog[self.groupIdKey].dtype)
                ra = np.full(n_rows_b, np.nan, dtype=catalog[col_ra].dtype)
                dec = np.full(n_rows_b, np.nan, dtype=catalog[col_dec].dtype)
                injection_flag = np.full(n_rows_b, True, dtype=bool)
                if is_multicomp:
                    flux = np.full(n_rows_b, np.nan, dtype=catalog[col_mag].dtype)
                else:
                    mag = np.full(n_rows_b, np.nan, dtype=catalog[col_mag].dtype)

                # Per-component output columns; masked arrays only when there
                # can be more than one component per group.
                values_comp = {}
                if is_multicomp:
                    for idx_comp in range(1, n_comps + 1):
                        prefix_comp = f"comp{idx_comp}_"
                        values_comp[f"{band}_{prefix_comp}flux"] = np.ma.masked_array(
                            np.full(n_rows_b, np.nan, dtype=catalog[col_mag].dtype),
                            mask=np.ones(n_rows_b, dtype=bool),
                        )
                for idx_comp in range(1, n_comps + 1):
                    for column_comp_in, column_comp_out in columns_comp.items():
                        values_col = np.full(n_rows_b, np.nan, dtype=catalog[column_comp_in].dtype)
                        column_out = column_comp_out.format(band=band, idx_comp=idx_comp)
                        values_comp[column_out] = (
                            np.ma.masked_array(values_col, mask=np.ones(n_rows_b, dtype=bool))
                            if is_multicomp
                            else values_col
                        )

                idx_new = 0
                for row in catalog:
                    groupId = row[self.groupIdKey]
                    idx_comp = counts[groupId] + 1
                    counts[groupId] = idx_comp
                    prefix_comp = f"comp{idx_comp}_" if is_multicomp else ""

                    # True when this component's injection flag is 0/False.
                    injected = row[self.injectionKey] == False  # noqa: E712
                    if idx_comp == 1:
                        # First component of a group: open a new output row.
                        ra[idx_new] = row[col_ra]
                        dec[idx_new] = row[col_dec]
                        groupIds[idx_new] = groupId
                        idxs_new[groupId] = idx_new
                        if is_multicomp:
                            flux_comp = (row[col_mag] * unit_mag).to(u.nJy).value
                            flux[idx_new] = flux_comp * injected
                        else:
                            mag[idx_new] = row[col_mag]
                        idx_old = idx_new
                        idx_new += 1
                    else:
                        idx_old = idxs_new[groupId]
                        flux_cumul = flux[idx_old]
                        flux_comp = (row[col_mag] * unit_mag).to(u.nJy).value
                        # TODO: Excluding not-injected components means
                        # objects with no injected components will have the
                        # ra,dec of their first component, not the mean of
                        # all excluded components.
                        if flux_comp * injected > 0:
                            flux_new = flux_cumul + flux_comp
                            flux[idx_old] = flux_new
                            if ra[idx_old] != row[col_ra]:
                                # Take a weighted mean
                                # TODO: deal with periodicity
                                ra[idx_old] = (flux_cumul * ra[idx_old] + flux_comp * row[col_ra]) / flux_new
                            if dec[idx_old] != row[col_dec]:
                                dec[idx_old] = (
                                    flux_cumul * dec[idx_old] + flux_comp * row[col_dec]
                                ) / flux_new

                    # One component is sufficient to say it was injected
                    injection_flag[idx_old] &= not injected

                    if is_multicomp:
                        column = values_comp[f"{band}_{prefix_comp}flux"]
                        column.mask[idx_old] = False
                        column[idx_old] = flux_comp

                    for column_comp in columns_comp:
                        column = f"{band}_{prefix_comp}{column_comp}"
                        if is_multicomp:
                            # Only the multi-component masked arrays have a
                            # mask; plain single-component arrays do not.
                            values_comp[column].mask[idx_old] = False
                        values_comp[column][idx_old] = row[column_comp]

                if is_multicomp:
                    mag = (flux * u.nJy).to(u.ABmag).value

                columns = {
                    self.groupIdKey: groupIds,
                    f"{band}_{self.injectionKey}": injection_flag,
                    f"{band}_{col_ra}": ra,
                    f"{band}_{col_dec}": dec,
                    f"{band}_{col_mag}": mag,
                }
                units = {
                    self.groupIdKey: None,
                    f"{band}_{self.injectionKey}": None,
                    f"{band}_{col_ra}": catalog[col_ra].unit,
                    f"{band}_{col_dec}": catalog[col_dec].unit,
                    f"{band}_{col_mag}": catalog[col_mag].unit,
                }

                for idx_comp in range(1, n_comps + 1):
                    for column_comp_in, column_comp_out in columns_comp.items():
                        name_column = column_comp_out.format(band=band, idx_comp=idx_comp)
                        columns[name_column] = values_comp[name_column]
                        units[name_column] = catalog[column_comp_in].unit

                catalog_new = astropy.table.Table(columns, units=units)
                catalog_dict[band] = catalog_new

            # The collapsed tables are new objects, so copying is unnecessary.
            copy_catalogs = False

        # Load the first catalog then loop to add info for the other bands.
        multiband_catalog = catalog_dict[bands[0]]
        if copy_catalogs:
            multiband_catalog = multiband_catalog.copy()
        if not self.groupIdKey:
            multiband_catalog.add_column(col=multiband_catalog[col_mag], name=f"{bands[0]}_{col_mag}")
            multiband_catalog.remove_column(col_mag)
        else:
            coords_same = True

        for band in bands[1:]:
            if not self.groupIdKey:
                # Make a column for the new band.
                multiband_catalog.add_column([np.nan] * len(multiband_catalog), name=f"{band}_{col_mag}")
            # Match the input catalog for this band to the existing
            # multiband catalog.
            catalog_next_band = catalog_dict[band]
            if copy_catalogs:
                catalog_next_band = catalog_next_band.copy()
            if not self.groupIdKey:
                catalog_next_band.rename_column(col_mag, f"{band}_{col_mag}")
            if self.groupIdKey:
                # Collapsed group tables only carry per-band coordinate
                # columns ("{band}_{col_ra}"), so positional matching on
                # col_ra/col_dec is impossible: join on the group ID instead,
                # whatever the match radius.
                multiband_catalog = join(
                    multiband_catalog,
                    catalog_next_band,
                    keys=self.groupIdKey,
                    join_type="outer",
                )
                if coords_same:
                    for coord in (col_ra, col_dec):
                        coords = multiband_catalog[f"{band}_{coord}"]
                        coords_good = coords[np.isfinite(coords)]
                        coords_good_first = multiband_catalog[f"{bands[0]}_{coord}"]
                        coords_good_first = coords_good_first[np.isfinite(coords_good_first)]
                        if (len(coords_good) != len(coords_good_first)) or (
                            np.any(coords_good != coords_good_first)
                        ):
                            coords_same = False
                multiband_match_inds, next_band_match_inds = [], []
            elif match_radius >= 0:
                with Matcher(multiband_catalog[col_ra], multiband_catalog[col_dec]) as m:
                    _, multiband_match_inds, next_band_match_inds, _ = m.query_radius(
                        catalog_next_band[col_ra],
                        catalog_next_band[col_dec],
                        match_radius,
                        return_indices=True,
                    )
            else:
                # A negative radius disables positional matching.
                multiband_match_inds, next_band_match_inds = [], []
            # If there are matches...
            if len(multiband_match_inds) > 0 and len(next_band_match_inds) > 0:
                # ...choose the coordinates in the brightest band.
                for i, j in zip(multiband_match_inds, next_band_match_inds):
                    mags = [multiband_catalog[i][col] for col in multiband_catalog.colnames if f"_{col_mag}" in col]
                    bright_mag = min(mags)
                    if catalog_next_band[f"{band}_{col_mag}"][j] < bright_mag:
                        multiband_catalog[col_ra][i] = catalog_next_band[col_ra][j]
                        multiband_catalog[col_dec][i] = catalog_next_band[col_dec][j]
                # TODO: Once multicomponent object support is added, make some
                # logic to pick the correct source_type.
                # Fill the new mag value.
                multiband_catalog[f"{band}_{col_mag}"][multiband_match_inds] = catalog_next_band[
                    f"{band}_{col_mag}"
                ][next_band_match_inds]
                # Add rows for all the sources without matches.
                not_next_band_match_inds = np.full(len(catalog_next_band), True, dtype=bool)
                not_next_band_match_inds[next_band_match_inds] = False
                multiband_catalog = vstack([multiband_catalog, catalog_next_band[not_next_band_match_inds]])
            # Otherwise just stack the tables.
            elif not self.groupIdKey:
                multiband_catalog = vstack([multiband_catalog, catalog_next_band])

        if self.groupIdKey:
            if coords_same:
                # All bands agree: keep a single, un-prefixed coordinate pair.
                for coord in (col_ra, col_dec):
                    for band in bands[1:]:
                        del multiband_catalog[f"{band}_{coord}"]
                    multiband_catalog.rename_column(f"{bands[0]}_{coord}", coord)
            else:
                fluxes = np.array(
                    [(multiband_catalog[f"{band}_{col_mag}"] * u.ABmag).to(u.nJy).value for band in bands]
                )
                # TODO: Test this better and deal with periodicity in RA
                for coord in (col_dec, col_ra):
                    coords = np.array([multiband_catalog[f"{band}_{coord}"] for band in bands])
                    # Flux-weighted mean across bands, sum(f * c) / sum(f),
                    # skipping bands where the coordinate or flux is NaN so
                    # that numerator and denominator use the same bands.
                    weights = np.where(np.isfinite(coords), fluxes, np.nan)
                    coords = np.nansum(weights * coords, axis=0) / np.nansum(weights, axis=0)
                    multiband_catalog.add_column(coords, index=1, name=coord)
            # The combined flag is set only if every band's flag equals 1.
            multiband_catalog.add_column(
                np.all(
                    np.array([multiband_catalog[f"{band}_{self.injectionKey}"] for band in bands]) == 1,
                    axis=0,
                ),
                index=1,
                name=self.injectionKey,
            )
        else:
            # Fill in per-band injection flag columns.
            # NOTE(review): this assumes row k of every band catalog maps to
            # row k of the merged catalog, i.e. all bands have the same length
            # and order; otherwise the assignments below fail on shape.
            if not copy_catalogs:
                multiband_catalog[self.injectionKey][:] = catalog_dict[bands[0]][self.injectionKey][:]
            multiband_catalog[f"{bands[0]}_{self.injectionKey}"] = multiband_catalog[self.injectionKey]
            for band in bands[1:]:
                injected_band = catalog_dict[band][self.injectionKey]
                multiband_catalog[self.injectionKey] &= injected_band
                multiband_catalog[f"{band}_{self.injectionKey}"] = injected_band

        # Fill any automatically masked values with NaNs for floating-point
        # columns of any precision (np.floating, not just float64); otherwise
        # use np.ma.maximum_fill_value (the dtype's minimum, e.g. for ints).
        if multiband_catalog.has_masked_columns:
            for colname in multiband_catalog.columns:
                column = multiband_catalog[colname]
                if isinstance(column, MaskedColumn):
                    # Set the underlying values in-place
                    column._data[column.mask] = (
                        np.nan
                        if np.issubdtype(column.dtype, np.floating)
                        else np.ma.maximum_fill_value(column)
                    )
        return multiband_catalog

    def setPrimaryFlags(
        self,
        catalog,
        skyMap,
        tractInfo,
        patches: list,
    ):
        """Set the patch-inner, tract-inner and primary flags on sources.

        A source is flagged primary when it lies in the inner region of the
        patch given by its ``patch`` column, lies in the inner region of the
        tract (``skyMap.findTract`` returns ``tractInfo`` for its position),
        and its ``injectionKey`` column equals 0.

        Parameters
        ----------
        catalog: `astropy.table.Table`
            A catalog of sources with a ``patch`` column. The columns named by
            ``isPatchInnerKey``, ``isTractInnerKey`` and ``isPrimaryKey`` are
            written (or overwritten) in place.
        skyMap : `lsst.skymap.BaseSkyMap`
            Sky tessellation object.
        tractInfo : `lsst.skymap.TractInfo`
            Tract object.
        patches : `list`
            Patch indices present in the ``patch`` column.
        """
        # Mark whether sources are contained within the inner regions of the
        # given tract/patch.
        isPatchInner = np.full(len(catalog), 0, dtype=bool)
        ra, dec = catalog[self.col_ra].data, catalog[self.col_dec].data
        for patch in patches:
            patchMask = catalog["patch"] == patch
            patchInfo = tractInfo.getPatchInfo(patch)
            isPatchInner[patchMask] = getPatchInner(patchInfo, ra[patchMask], dec[patchMask])
        isTractInner = getTractInner(tractInfo, skyMap, ra, dec)
        isPrimary = isTractInner & isPatchInner & (catalog[self.injectionKey] == 0)

        catalog[self.isPatchInnerKey] = isPatchInner
        catalog[self.isTractInnerKey] = isTractInner
        catalog[self.isPrimaryKey] = isPrimary

750 

751 

class ConsolidateInjectedCatalogsTask(PipelineTask):
    """Combine all per-patch, per-band injected catalogs of a tract into one
    multiband table.
    """

    _DefaultName = "consolidateInjectedCatalogsTask"
    ConfigClass = ConsolidateInjectedCatalogsConfig

    def runQuantum(self, butlerQC, input_refs, output_refs):
        """Load the inputs of one quantum, consolidate them and write the
        per-tract catalog.
        """
        inputs = butlerQC.get(input_refs)
        skymap = inputs["skyMap"]
        # When merging on a group ID, sources outside each patch's inner
        # region are dropped (presumably so overlapping patches do not
        # contribute the same group twice).
        catalog_dict, tract = _get_catalogs(
            inputs,
            input_refs,
            skymap,
            col_ra=self.config.col_ra,
            col_dec=self.config.col_dec,
            include_outer=not bool(self.config.groupIdKey),
        )
        outputs = self.run(catalog_dict, skymap, tract)
        butlerQC.put(outputs, output_refs)

    def run(
        self,
        catalog_dict: dict,
        skymap: BaseSkyMap,
        tract: int,
    ) -> Struct:
        """Consolidate all tables in catalog_dict into one table.

        Parameters
        ----------
        catalog_dict: `dict`
            A dictionary with photometric bands for keys and astropy tables for
            items.
        skymap: `lsst.skymap.BaseSkyMap`
            A base skymap.
        tract: `int`
            The tract where sources have been injected.

        Returns
        -------
        output_struct : `lsst.pipe.base.Struct`
            contains :
                output_catalog: `astropy.table.Table`
                    A single table containing all information of the separate
                    tables in catalog_dict.
        """
        output_catalog = self.config.consolidate_catalogs(
            catalog_dict,
            skymap,
            tract,
        )
        return Struct(output_catalog=output_catalog)
804 return output_struct