Coverage for python / lsst / source / injection / utils / _consolidate_injected_coadd_catalogs.py: 11%
289 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:38 +0000
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "ConsolidateInjectedCatalogsConnections",
26 "ConsolidateInjectedCatalogsConfig",
27 "ConsolidateInjectedCatalogsTask",
28]
30from collections import defaultdict
32import astropy.table
33import astropy.units as u
34import numpy as np
35from astropy.table import Table, join, vstack
36from astropy.table.column import MaskedColumn
37from smatch.matcher import Matcher # type: ignore [import-not-found]
39from lsst.daf.butler import DatasetProvenance
40from lsst.geom import Box2D, SpherePoint, degrees
41from lsst.pex.config import Field, ListField
42from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
43from lsst.pipe.base.connections import InputQuantizedConnection
44from lsst.pipe.base.connectionTypes import Input, Output
45from lsst.skymap import BaseSkyMap
class ConsolidateInjectedCatalogsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "skymap", "tract"),
    defaultTemplates={"injected_prefix": "injected_"},
):
    """Butler connections for merging per-patch, per-band injected catalogs
    into a single per-tract catalog.

    The ``injected_prefix`` template is prepended to both the input and
    output dataset type names.
    """

    # One catalog per (patch, band); at least one must be present.
    input_catalogs = Input(
        name="{injected_prefix}deep_coadd_predetection_catalog",
        doc="Per-patch and per-band injected catalogs to draw inputs from.",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch", "band"),
        multiple=True,
        minimum=1,
    )
    # The sky map supplies tract/patch geometry for the inner-region flags.
    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
        storageClass="SkyMap",
        dimensions=("skymap",),
    )
    # A single multiband catalog for the whole tract.
    output_catalog = Output(
        name="{injected_prefix}deep_coadd_predetection_catalog_tract",
        doc="Per-tract multiband catalog of injected sources.",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
    )
def _get_catalogs(
    inputs: dict,
    input_refs: InputQuantizedConnection,
    skymap: BaseSkyMap,
    col_ra: str = "ra",
    col_dec: str = "dec",
    include_outer: bool = True,
) -> tuple[dict[str, astropy.table.Table], int]:
    """Organize input catalogs into a dictionary with photometry band
    keys.

    Each input catalog gets a ``patch`` column set from its data ID, has its
    provenance metadata stripped, and is then stacked with the other patches
    of the same band in ascending patch order.

    Parameters
    ----------
    inputs: `dict`
        A dictionary containing the input datasets.
    input_refs: `lsst.pipe.base.connections.InputQuantizedConnection`
        The input dataset references used by the butler.
    skymap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object
    col_ra: `str`
        Column name for right ascension (in degrees).
    col_dec: `str`
        Column name for declination (in degrees).
    include_outer: `bool`
        Whether to include objects injected into the outer (not inner) region
        of a patch.

    Returns
    -------
    `tuple[dict, int]`
        contains :
            catalog_dict: `dict`
                A dictionary with photometric bands for keys and astropy
                tables for items.
            tract: `int`
                The tract covering the catalogs in catalog_dict.

    Raises
    ------
    RuntimeError
        Raised if no input catalogs are given, or if they span more than one
        tract.
    """
    # band -> {patch id -> catalog}; patches are stacked in sorted order.
    catalogs_by_band: dict[str, dict[int, astropy.table.Table]] = defaultdict(dict)
    tracts = set()
    for ref, catalog in zip(input_refs.input_catalogs, inputs["input_catalogs"]):
        band = ref.dataId.band.name
        tract = ref.dataId.tract.id
        # Load the patch number to check for patch overlap duplicates.
        patch = ref.dataId.patch.id
        if not include_outer:
            # Keep only sources whose centroid is in this patch's inner region.
            is_inner = getPatchInner(
                skymap[tract][patch],
                catalog[col_ra],
                catalog[col_dec],
            )
            catalog = catalog[is_inner]
        catalog["patch"] = patch
        # Strip provenance from catalogs before merging to avoid the
        # provenance headers triggering warnings in the astropy naive
        # metadata merge tool.
        DatasetProvenance.strip_provenance_from_flat_dict(catalog.meta)
        catalogs_by_band[band][patch] = catalog
        tracts.add(tract)
    # Check that only catalogs covered by a single tract are loaded.
    if not tracts:
        raise RuntimeError("No input catalogs were provided; cannot determine the tract")
    if len(tracts) > 1:
        raise RuntimeError(f"Got tract={tract} with {tracts=}; there should only be one tract")
    # Stack the per-band catalogs in patch order.
    catalog_dict = {
        band: vstack([catalog_patch for _, catalog_patch in sorted(patches.items())])
        for band, patches in catalogs_by_band.items()
    }
    return catalog_dict, list(tracts)[0]
def _get_patches(
    catalog_dict: dict[str, astropy.table.Table],
    tractInfo,
    col_ra: str = "ra",
    col_dec: str = "dec",
    index=None,
):
    """Create a patch column and assign each row a patch number.

    The ``patch`` column is added to any catalog that lacks it, and every
    row is set to the sequential index of the patch that
    ``tractInfo.findPatch`` returns for that row's position. Catalogs are
    modified in place; nothing is returned.

    Parameters
    ----------
    catalog_dict: `dict`
        A dictionary with photometric bands for keys and astropy tables for
        items.
    tractInfo: `lsst.skymap.tractInfo.ExplicitTractInfo`
        Information for a tract specified explicitly.
    col_ra: `str`
        Column name for right ascension (in degrees).
    col_dec: `str`
        Column name for declination (in degrees).
    index: `int`, optional
        Position at which a new ``patch`` column is inserted (`None` appends
        it). Has no effect if the column already exists.
    """
    for catalog in catalog_dict.values():
        if "patch" not in catalog.colnames:
            catalog.add_column(np.empty(len(catalog), dtype=int), name="patch", index=index)

        patches = catalog["patch"]
        # Iterate the coordinate columns directly rather than building a
        # Row object for every source.
        for idx, (ra, dec) in enumerate(zip(catalog[col_ra], catalog[col_dec])):
            patchInfo = tractInfo.findPatch(SpherePoint(ra, dec, degrees))
            patches[idx] = int(patchInfo.getSequentialIndex())
def getPatchInner(
    patchInfo,
    ra: np.ndarray,
    dec: np.ndarray,
):
    """Flag the sources whose centroid falls in a patch's inner bounding box.

    Parameters
    ----------
    patchInfo : `lsst.skymap.PatchInfo`
        Information about a `SkyMap` `Patch`.
    ra: `np.ndarray`
        Right ascension values in degrees.
    dec: `np.ndarray`
        Declination values in degrees.

    Returns
    -------
    isPatchInner : array-like of `bool`
        `True` where the source's pixel position, computed with the patch
        WCS, lies inside the patch's inner bounding box.
    """
    # Sky -> pixel coordinates in the patch's WCS.
    x_pix, y_pix = patchInfo.getWcs().skyToPixelArray(ra, dec, degrees=True)
    # Containment test against the floating-point inner box.
    return Box2D(patchInfo.getInnerBBox()).contains(x_pix, y_pix)
def getTractInner(
    tractInfo,
    skyMap,
    ra: np.ndarray,
    dec: np.ndarray,
):
    """Flag the sources that the sky map assigns to the given tract.

    Parameters
    ----------
    tractInfo : `lsst.skymap.TractInfo`
        Tract object
    skyMap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object
    ra: `np.ndarray`
        Right ascension values in degrees.
    dec: `np.ndarray`
        Declination values in degrees.

    Returns
    -------
    isTractInner : array-like of `bool`
        True where ``skyMap.findTract`` returns a tract with the same ID as
        ``tractInfo``.
    """
    target_id = tractInfo.getId()

    def _owned_by_target(ra_deg, dec_deg):
        # One findTract lookup per source.
        return skyMap.findTract(SpherePoint(ra_deg, dec_deg, degrees)).getId() == target_id

    return np.fromiter(
        (_owned_by_target(ra_deg, dec_deg) for ra_deg, dec_deg in zip(ra, dec)),
        dtype=bool,
    )
class ConsolidateInjectedCatalogsConfig(  # type: ignore [call-arg]
    PipelineTaskConfig, pipelineConnections=ConsolidateInjectedCatalogsConnections
):
    """Configuration for consolidating injected catalogs.

    Besides the configuration fields, this class carries the methods that
    merge per-band injected catalogs into a single per-tract table and set
    the patch-inner, tract-inner and primary flags on it.
    """

    col_ra = Field[str](
        doc="Column name for right ascension (in degrees).",
        default="ra",
    )
    col_dec = Field[str](
        doc="Column name for declination (in degrees).",
        default="dec",
    )
    col_mag = Field[str](
        doc="Column name for magnitude.",
        default="mag",
    )
    col_source_type = Field[str](
        doc="Column name for the source type used in the input catalog. Must match one of the surface "
        "brightness profiles defined by GalSim. For more information see the Galsim docs at "
        "https://galsim-developers.github.io/GalSim/_build/html/sb.html",
        default="source_type",
    )
    columns_extra = ListField[str](
        doc="Extra columns to be copied from the injection catalog (e.g. for shapes)",
        default=[],
    )
    groupIdKey = Field[str](
        doc="Key for the group id column to merge sources on, if any",
        default=None,
        optional=True,
    )
    # NOTE(review): the code below treats a value of 0/False in this column
    # as "successfully injected" (see setPrimaryFlags), which is the opposite
    # of what this field's doc says; confirm and fix the doc.
    injectionKey = Field[str](
        doc="True if the source was successfully injected.",
        default="injection_flag",
    )
    injectionSizeKey = Field[str](
        doc="The size of the drawn injection box",
        default="injection_draw_size",
    )
    isPatchInnerKey = Field[str](
        doc="True if source is in the inner region of a coadd patch.",
        default="injected_isPatchInner",
    )
    isTractInnerKey = Field[str](
        doc="True if source is in the inner region of a coadd tract.",
        default="injected_isTractInner",
    )
    isPrimaryKey = Field[str](
        doc="True if the source was successfully injected and is in both the inner region of a coadd patch "
        "and tract.",
        default="injected_isPrimary",
    )
    remove_patch_overlap_duplicates = Field[bool](
        doc="Optional parameter to remove patch overlap duplicate sources.",
        default=False,
    )
    get_catalogs_from_butler = Field[bool](
        doc="Optional parameter to specify whether or not the input catalogs are loaded with a Butler.",
        default=True,
    )
    pixel_match_radius = Field[float](
        doc="Radius for matching catalogs across different bands.",
        default=0.1,
    )

    def consolidate_catalogs(
        self,
        catalog_dict: dict[str, astropy.table.Table],
        skymap: BaseSkyMap,
        tract: int,
        copy_catalogs: bool = False,
    ) -> Table:
        """Consolidate all tables in catalog_dict into one table.

        Sources outside the tract are removed, patch numbers are recomputed
        from the final positions, inner/primary flags are set and an
        ``injected_id`` column is added. Without ``groupIdKey`` the columns
        are also reordered.

        Parameters
        ----------
        catalog_dict: `dict`
            A dictionary with photometric bands for keys and astropy tables for
            items.
        skymap: `lsst.skymap.BaseSkyMap`
            A base skymap.
        tract: `int`
            The tract where sources have been injected.
        copy_catalogs: `bool`
            Whether to copy the input catalogs; if False, they may be modified
            in-place. Note that when ``get_catalogs_from_butler`` is False, a
            ``patch`` column is added to the input catalogs in place
            regardless of this setting.

        Returns
        -------
        multiband_catalog: `astropy.table.Table`
            A single table containing all information of the separate
            tables in catalog_dict
        """
        tractInfo = skymap.generateTract(tract)
        # If patch numbers are not loaded via dataIds from the butler, derive
        # them from source positions using the configured coordinate columns.
        if not self.get_catalogs_from_butler:
            _get_patches(catalog_dict, tractInfo, col_ra=self.col_ra, col_dec=self.col_dec)

        # Convert the pixel match radius to degrees.
        pixel_scale = tractInfo.getWcs().getPixelScale()
        match_radius = self.pixel_match_radius * pixel_scale.asDegrees()
        has_groups = bool(self.groupIdKey)
        bands = list(catalog_dict.keys())
        if has_groups or (len(bands) > 1):
            # Match the catalogs across bands.
            output_catalog = self.make_multiband_catalog(
                bands,
                catalog_dict,
                match_radius,
                copy_catalogs=copy_catalogs,
            )
        else:
            output_catalog = catalog_dict[bands[0]]
            if copy_catalogs:
                output_catalog = output_catalog.copy()
            output_catalog.rename_column(self.col_mag, f"{bands[0]}_{self.col_mag}")

        # Remove sources outside tract boundaries.
        out_of_tract_bounds = [
            index
            for index, (ra, dec) in enumerate(zip(output_catalog[self.col_ra], output_catalog[self.col_dec]))
            if not tractInfo.contains(SpherePoint(ra * degrees, dec * degrees))
        ]
        if out_of_tract_bounds:
            output_catalog.remove_rows(out_of_tract_bounds)

        # Assign flags.
        # There may be a pre-existing patch column; however, patches can shift
        # if centroids vary per band, etc.
        # TODO: Need to check if this can cause duplicate or dropped entries
        # for objects near patch boundaries
        _get_patches(
            {"x": output_catalog},
            tractInfo=tractInfo,
            col_ra=self.col_ra,
            col_dec=self.col_dec,
        )
        patches = list(set(output_catalog["patch"]))
        self.setPrimaryFlags(
            catalog=output_catalog,
            skyMap=skymap,
            tractInfo=tractInfo,
            patches=patches,
        )
        # Add a new injected_id column.
        output_catalog.add_column(col=np.arange(len(output_catalog)), name="injected_id")
        # If using a group ID, the draw_size column is
        # not preserved, and the rest are already ordered
        if self.groupIdKey:
            return output_catalog

        # Reorder columns: identifiers and positions, per-band magnitudes,
        # then bookkeeping and flag columns (using the configured names so a
        # renamed key does not raise KeyError), then everything else.
        mag_suffix = f"_{self.col_mag}"
        mag_cols = [col for col in output_catalog.colnames if col.endswith(mag_suffix)]
        new_order = [
            "injected_id",
            self.col_ra,
            self.col_dec,
            self.col_source_type,
            *mag_cols,
            "patch",
            "injection_id",
            self.injectionSizeKey,
            self.injectionKey,
            self.isPatchInnerKey,
            self.isTractInnerKey,
            self.isPrimaryKey,
        ]
        new_order.extend(column for column in output_catalog.colnames if column not in new_order)

        return output_catalog[new_order]

    def make_multiband_catalog(
        self,
        bands: list,
        catalog_dict: dict[str, astropy.table.Table],
        match_radius: float,
        copy_catalogs: bool = False,
    ) -> Table:
        """Combine multiple band-specific catalogs into one multiband
        catalog.

        Without ``groupIdKey``, each additional band is matched by position
        to the catalog built so far (when ``match_radius >= 0``) and its
        unmatched rows are appended; with a negative radius the tables are
        simply stacked. With ``groupIdKey``, each per-band catalog is first
        collapsed to one row per group: the fluxes of injected components
        are summed and positions are flux-weighted. The per-band tables are
        then outer-joined on the group id.

        Parameters
        ----------
        bands: `list`
            A list of string photometry bands; the first is the reference
            band.
        catalog_dict: `dict`
            A dictionary with photometric bands for keys and astropy
            tables for items.
        match_radius: `float`
            The radius for matching catalogs across bands, in degrees (the
            unit used by the caller). A negative value disables positional
            matching. Ignored when ``groupIdKey`` is set.
        copy_catalogs: `bool`
            Whether to copy the input catalogs; if False, they may be
            modified in-place.

        Returns
        -------
        multiband_catalog: `astropy.table.Table`
            A catalog with sources that have magnitude information across all
            bands.
        """
        col_mag = self.col_mag
        col_ra = self.col_ra
        col_dec = self.col_dec
        mag_suffix = f"_{col_mag}"

        def _to_float_array(column) -> np.ndarray:
            # Plain float array with masked entries (from outer joins or
            # stacks of tables with different columns) replaced by NaN.
            return np.ma.filled(np.ma.asarray(column, dtype=float), np.nan)

        if self.groupIdKey:
            # Replace entries in a shallow copy so the caller's dict is left
            # untouched.
            catalog_dict = dict(catalog_dict)
            n_comps = 0
            n_rows = {}
            for band, catalog in catalog_dict.items():
                groupIds, counts = np.unique(
                    catalog[self.groupIdKey],
                    return_counts=True,
                )
                n_rows[band] = len(groupIds)
                # initial=0 keeps an empty catalog from raising.
                n_comps = max(int(counts.max(initial=0)), n_comps)

            # Maybe validate that this is true? Why set groupId otherwise?
            is_multicomp = n_comps > 1
            # Output-name template; {band} and {idx_comp} are filled in
            # later with str.format.
            comp_template = "comp{idx_comp}_" if is_multicomp else ""
            columns_comp = {
                column_in: f"{{band}}_{comp_template}{column_in}"
                for column_in in [self.col_source_type, self.injectionKey] + list(self.columns_extra)
            }

            for band, catalog in catalog_dict.items():
                n_rows_b = n_rows[band]
                unit_mag = catalog[col_mag].unit or u.ABmag
                counts = defaultdict(int)
                idxs_new = {}

                groupIds = np.full(n_rows_b, 0, dtype=catalog[self.groupIdKey].dtype)
                ra = np.full(n_rows_b, np.nan, dtype=catalog[col_ra].dtype)
                dec = np.full(n_rows_b, np.nan, dtype=catalog[col_dec].dtype)
                # Starts True; cleared as soon as any component is injected.
                injection_flag = np.full(n_rows_b, True, dtype=bool)
                if is_multicomp:
                    flux = np.full(n_rows_b, np.nan, dtype=catalog[col_mag].dtype)
                else:
                    mag = np.full(n_rows_b, np.nan, dtype=catalog[col_mag].dtype)

                values_comp = {}
                if is_multicomp:
                    for idx_comp in range(1, n_comps + 1):
                        comp_prefix = f"comp{idx_comp}_"
                        values_comp[f"{band}_{comp_prefix}flux"] = np.ma.masked_array(
                            np.full(n_rows_b, np.nan, dtype=catalog[col_mag].dtype),
                            mask=np.ones(n_rows_b, dtype=bool),
                        )
                for idx_comp in range(1, n_comps + 1):
                    for column_comp_in, column_comp_out in columns_comp.items():
                        values_col = np.full(n_rows_b, np.nan, dtype=catalog[column_comp_in].dtype)
                        column_out = column_comp_out.format(band=band, idx_comp=idx_comp)
                        values_comp[column_out] = (
                            np.ma.masked_array(values_col, mask=np.ones(n_rows_b, dtype=bool))
                            if is_multicomp
                            else values_col
                        )

                idx_new = 0
                for row in catalog:
                    groupId = row[self.groupIdKey]
                    idx_comp = counts[groupId] + 1
                    counts[groupId] = idx_comp
                    comp_prefix = f"comp{idx_comp}_" if is_multicomp else ""

                    # A falsy flag value means this component was injected.
                    injected = row[self.injectionKey] == False  # noqa: E712
                    if idx_comp == 1:
                        ra[idx_new] = row[col_ra]
                        dec[idx_new] = row[col_dec]
                        groupIds[idx_new] = groupId
                        idxs_new[groupId] = idx_new
                        if is_multicomp:
                            flux_comp = (row[col_mag] * unit_mag).to(u.nJy).value
                            flux[idx_new] = flux_comp * injected
                        else:
                            mag[idx_new] = row[col_mag]
                        idx_old = idx_new
                        idx_new += 1
                    else:
                        # Only reachable when is_multicomp, so flux exists.
                        idx_old = idxs_new[groupId]
                        flux_cumul = flux[idx_old]
                        flux_comp = (row[col_mag] * unit_mag).to(u.nJy).value
                        # TODO: Excluding not-injected components means
                        # objects with no injected components will have the
                        # ra,dec of their first component, not the mean of
                        # all excluded components.
                        if flux_comp * injected > 0:
                            flux_new = flux_cumul + flux_comp
                            flux[idx_old] = flux_new
                            if ra[idx_old] != row[col_ra]:
                                # Take a weighted mean
                                # TODO: deal with periodicity
                                ra[idx_old] = (flux_cumul * ra[idx_old] + flux_comp * row[col_ra]) / flux_new
                            if dec[idx_old] != row[col_dec]:
                                dec[idx_old] = (
                                    flux_cumul * dec[idx_old] + flux_comp * row[col_dec]
                                ) / flux_new

                    # One component is sufficient to say it was injected
                    injection_flag[idx_old] &= not injected

                    if is_multicomp:
                        column = values_comp[f"{band}_{comp_prefix}flux"]
                        column.mask[idx_old] = False
                        column[idx_old] = flux_comp

                    # Copy the per-component values; only the multicomponent
                    # arrays are masked, so unmask only those.
                    for column_comp in columns_comp:
                        values = values_comp[f"{band}_{comp_prefix}{column_comp}"]
                        if np.ma.isMaskedArray(values):
                            values.mask[idx_old] = False
                        values[idx_old] = row[column_comp]

                if is_multicomp:
                    mag = (flux * u.nJy).to(u.ABmag).value

                columns = {
                    self.groupIdKey: groupIds,
                    f"{band}_{self.injectionKey}": injection_flag,
                    f"{band}_{col_ra}": ra,
                    f"{band}_{col_dec}": dec,
                    f"{band}_{col_mag}": mag,
                }
                units = {
                    self.groupIdKey: None,
                    f"{band}_{self.injectionKey}": None,
                    f"{band}_{col_ra}": catalog[col_ra].unit,
                    f"{band}_{col_dec}": catalog[col_dec].unit,
                    f"{band}_{col_mag}": catalog[col_mag].unit,
                }

                # NOTE(review): the per-component "flux" arrays filled above
                # are not added to the output here; confirm whether they
                # should be.
                for idx_comp in range(1, n_comps + 1):
                    for column_comp_in, column_comp_out in columns_comp.items():
                        name_column = column_comp_out.format(band=band, idx_comp=idx_comp)
                        columns[name_column] = values_comp[name_column]
                        units[name_column] = catalog[column_comp_in].unit

                catalog_dict[band] = astropy.table.Table(columns, units=units)

            # The collapsed tables are new objects, so they need no copying.
            copy_catalogs = False

        # Load the first catalog then loop to add info for the other bands.
        multiband_catalog = catalog_dict[bands[0]]
        if copy_catalogs:
            multiband_catalog = multiband_catalog.copy()
        if self.groupIdKey:
            coords_same = True
        else:
            multiband_catalog.add_column(col=multiband_catalog[col_mag], name=f"{bands[0]}_{col_mag}")
            multiband_catalog.remove_column(col_mag)

        for band in bands[1:]:
            band_mag = f"{band}_{col_mag}"
            if not self.groupIdKey:
                # Make a column for the new band.
                multiband_catalog.add_column([np.nan] * len(multiband_catalog), name=band_mag)
            catalog_next_band = catalog_dict[band]
            if copy_catalogs:
                catalog_next_band = catalog_next_band.copy()
            if not self.groupIdKey:
                catalog_next_band.rename_column(col_mag, band_mag)

            multiband_match_inds, next_band_match_inds = [], []
            if self.groupIdKey:
                # Grouped tables only carry per-band coordinate columns, so
                # they are always joined on the group id, never matched by
                # position.
                multiband_catalog = join(
                    multiband_catalog,
                    catalog_next_band,
                    keys=self.groupIdKey,
                    join_type="outer",
                )
                if coords_same:
                    for coord in (col_ra, col_dec):
                        coords = _to_float_array(multiband_catalog[f"{band}_{coord}"])
                        coords_good = coords[np.isfinite(coords)]
                        coords_good_first = _to_float_array(multiband_catalog[f"{bands[0]}_{coord}"])
                        coords_good_first = coords_good_first[np.isfinite(coords_good_first)]
                        if (len(coords_good) != len(coords_good_first)) or np.any(
                            coords_good != coords_good_first
                        ):
                            coords_same = False
            elif match_radius >= 0:
                # Match the input catalog for this band to the existing
                # multiband catalog.
                with Matcher(multiband_catalog[col_ra], multiband_catalog[col_dec]) as m:
                    _, multiband_match_inds, next_band_match_inds, _ = m.query_radius(
                        catalog_next_band[col_ra],
                        catalog_next_band[col_dec],
                        match_radius,
                        return_indices=True,
                    )

            # If there are matches...
            if len(multiband_match_inds) > 0 and len(next_band_match_inds) > 0:
                # ...choose the coordinates in the brightest band. NaN/masked
                # magnitudes are ignored (fmin skips NaN) rather than making
                # the comparison depend on column order.
                mag_columns = [col for col in multiband_catalog.colnames if col.endswith(mag_suffix)]
                bright_mags = np.fmin.reduce(
                    np.array([_to_float_array(multiband_catalog[col]) for col in mag_columns]),
                    axis=0,
                )
                next_mags = _to_float_array(catalog_next_band[band_mag])
                for i, j in zip(multiband_match_inds, next_band_match_inds):
                    if next_mags[j] < bright_mags[i]:
                        multiband_catalog[col_ra][i] = catalog_next_band[col_ra][j]
                        multiband_catalog[col_dec][i] = catalog_next_band[col_dec][j]
                # TODO: Once multicomponent object support is added, make some
                # logic to pick the correct source_type.
                # Fill the new mag value.
                multiband_catalog[band_mag][multiband_match_inds] = catalog_next_band[band_mag][
                    next_band_match_inds
                ]
                # Add rows for all the sources without matches.
                not_next_band_match_inds = np.full(len(catalog_next_band), True, dtype=bool)
                not_next_band_match_inds[next_band_match_inds] = False
                multiband_catalog = vstack([multiband_catalog, catalog_next_band[not_next_band_match_inds]])
            # Otherwise just stack the tables.
            elif not self.groupIdKey:
                multiband_catalog = vstack([multiband_catalog, catalog_next_band])

        if self.groupIdKey:
            if coords_same:
                # Every band has identical positions: keep one copy.
                for coord in (col_ra, col_dec):
                    for band in bands[1:]:
                        del multiband_catalog[f"{band}_{coord}"]
                    multiband_catalog.rename_column(f"{bands[0]}_{coord}", coord)
            else:
                # Flux-weighted mean position over the bands with data.
                # TODO: Test this better and deal with periodicity in RA
                fluxes = np.array(
                    [
                        (_to_float_array(multiband_catalog[f"{band}_{col_mag}"]) * u.ABmag).to(u.nJy).value
                        for band in bands
                    ]
                )
                for coord in (col_dec, col_ra):
                    coords = np.array([_to_float_array(multiband_catalog[f"{band}_{coord}"]) for band in bands])
                    # Only bands with a finite coordinate contribute weight.
                    weights = np.where(np.isfinite(coords), fluxes, np.nan)
                    with np.errstate(invalid="ignore", divide="ignore"):
                        coords = np.nansum(weights * coords, axis=0) / np.nansum(weights, axis=0)
                    multiband_catalog.add_column(coords, index=1, name=coord)
            # The combined flag is set only if the flag is 1 in every band.
            multiband_catalog.add_column(
                np.all(
                    np.array([multiband_catalog[f"{band}_{self.injectionKey}"] for band in bands]) == 1,
                    axis=0,
                ),
                index=1,
                name=self.injectionKey,
            )
        else:
            # Fill in per-band injection flag columns.
            # NOTE(review): flags are combined by row index, which appears to
            # assume every band's catalog has the same rows, in the same
            # order, as the merged catalog; if unmatched rows were appended
            # above, the array lengths differ. Confirm the intended behavior.
            if not copy_catalogs:
                multiband_catalog[self.injectionKey][:] = catalog_dict[bands[0]][self.injectionKey][:]
            multiband_catalog[f"{bands[0]}_{self.injectionKey}"] = multiband_catalog[self.injectionKey]
            for band in bands[1:]:
                injected_band = catalog_dict[band][self.injectionKey]
                multiband_catalog[self.injectionKey] &= injected_band
                multiband_catalog[f"{band}_{self.injectionKey}"] = injected_band

        # Fill any automatically masked values with NaNs if possible (any
        # floating dtype). Otherwise, use the dtype's minimum value (for int,
        # bool, etc.), as returned by np.ma.maximum_fill_value.
        if multiband_catalog.has_masked_columns:
            for colname in multiband_catalog.colnames:
                column = multiband_catalog[colname]
                if isinstance(column, MaskedColumn):
                    # Set the underlying values in-place
                    column._data[column.mask] = (
                        np.nan if np.issubdtype(column.dtype, np.floating) else np.ma.maximum_fill_value(column)
                    )
        return multiband_catalog

    def setPrimaryFlags(
        self,
        catalog,
        skyMap,
        tractInfo,
        patches: list,
    ):
        """Set isPrimary and related flags on sources.

        A source is flagged patch-inner if its centroid lies in the inner
        bounding box of the patch given by its ``patch`` column. It is
        flagged tract-inner if ``skyMap.findTract`` assigns it to
        ``tractInfo``. It is flagged primary if it is both patch-inner and
        tract-inner and its injection flag equals 0.

        Parameters
        ----------
        catalog: `astropy.table.Table`
            A catalog of sources with a ``patch`` column.
            Writes is-patch-inner, is-tract-inner, and is-primary flags in
            place, using the configured column names.
        skyMap : `lsst.skymap.BaseSkyMap`
            Sky tessellation object
        tractInfo : `lsst.skymap.TractInfo`
            Tract object
        patches : `list`
            List of coadd patches present in ``catalog["patch"]``.
        """
        isPatchInner = np.zeros(len(catalog), dtype=bool)
        ra, dec = catalog[self.col_ra].data, catalog[self.col_dec].data
        patch_column = catalog["patch"]
        for patch in patches:
            patchMask = patch_column == patch
            patchInfo = tractInfo.getPatchInfo(patch)
            isPatchInner[patchMask] = getPatchInner(patchInfo, ra[patchMask], dec[patchMask])
        isTractInner = getTractInner(tractInfo, skyMap, ra, dec)
        isPrimary = isTractInner & isPatchInner & (catalog[self.injectionKey] == 0)

        catalog[self.isPatchInnerKey] = isPatchInner
        catalog[self.isTractInnerKey] = isTractInner
        catalog[self.isPrimaryKey] = isPrimary
class ConsolidateInjectedCatalogsTask(PipelineTask):
    """Class for combining all tables in a collection of input catalogs
    into one table.

    Per-patch, per-band injected catalogs covering one tract are merged into
    a single per-tract multiband catalog by
    ``ConsolidateInjectedCatalogsConfig.consolidate_catalogs``.
    """

    _DefaultName = "consolidateInjectedCatalogsTask"
    ConfigClass = ConsolidateInjectedCatalogsConfig

    def runQuantum(self, butlerQC, input_refs, output_refs):
        """Read the inputs for one quantum, consolidate them and write the
        per-tract output catalog.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
            Butler interface for this quantum.
        input_refs : `lsst.pipe.base.connections.InputQuantizedConnection`
            References to the input datasets.
        output_refs : `lsst.pipe.base.connections.OutputQuantizedConnection`
            References to the output datasets.
        """
        inputs = butlerQC.get(input_refs)
        skymap = inputs["skyMap"]
        catalog_dict, tract = _get_catalogs(
            inputs,
            input_refs,
            skymap,
            col_ra=self.config.col_ra,
            col_dec=self.config.col_dec,
            # With a group ID, only sources in each patch's inner region are
            # kept.
            include_outer=not bool(self.config.groupIdKey),
        )
        outputs = self.run(catalog_dict, skymap, tract)
        butlerQC.put(outputs, output_refs)

    def run(
        self,
        catalog_dict: dict,
        skymap: BaseSkyMap,
        tract: int,
    ) -> Struct:
        """Consolidate all tables in catalog_dict into one table.

        Parameters
        ----------
        catalog_dict: `dict`
            A dictionary with photometric bands for keys and astropy tables for
            items.
        skymap: `lsst.skymap.BaseSkyMap`
            A base skymap.
        tract: `int`
            The tract where sources have been injected.

        Returns
        -------
        output_struct : `lsst.pipe.base.Struct`
            Result struct with attribute:

            ``output_catalog``
                A single `astropy.table.Table` containing all information of
                the separate tables in catalog_dict.
        """
        output_catalog = self.config.consolidate_catalogs(
            catalog_dict,
            skymap,
            tract,
        )
        return Struct(output_catalog=output_catalog)