Coverage for python/lsst/source/injection/utils/consolidate_injected_deepCoadd_catalogs.py: 21%
143 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-20 12:11 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-20 12:11 +0000
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = [
25 "ConsolidateInjectedCatalogsConnections",
26 "ConsolidateInjectedCatalogsConfig",
27 "ConsolidateInjectedCatalogsTask",
28 "consolidate_injected_deepCoadd_catalogs",
29]
31import numpy as np
32from astropy.table import Table, vstack
33from astropy.table.column import MaskedColumn
34from lsst.geom import Box2D, SpherePoint, degrees
35from lsst.pex.config import Field
36from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
37from lsst.pipe.base.connections import InputQuantizedConnection
38from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
39from lsst.skymap import BaseSkyMap
40from smatch.matcher import Matcher # type: ignore [import-not-found]
class ConsolidateInjectedCatalogsConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "skymap", "tract"),
    defaultTemplates={"injected_prefix": "injected_"},
):
    """Connections for consolidating injected coadd catalogs.

    Per-patch, per-band injected catalogs and the skymap are read, and a
    single per-tract multiband catalog is written.
    """

    # Declared as a prerequisite input; at least one catalog is required
    # and many (one per patch and band) are accepted.
    input_catalogs = PrerequisiteInput(
        name="{injected_prefix}deepCoadd_catalog",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract", "patch", "band"),
        multiple=True,
        minimum=1,
        doc="Per-patch and per-band injected catalogs to draw inputs from.",
    )
    # Sky tessellation used to look up tract and patch geometry.
    skyMap = Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="Input definition of geometry/bbox and projection/wcs for warped exposures",
    )
    # One consolidated catalog per tract.
    output_catalog = Output(
        name="{injected_prefix}deepCoadd_catalog_tract",
        storageClass="ArrowAstropy",
        dimensions=("skymap", "tract"),
        doc="Per-tract multiband catalog of injected sources.",
    )
class ConsolidateInjectedCatalogsConfig(  # type: ignore [call-arg]
    PipelineTaskConfig, pipelineConnections=ConsolidateInjectedCatalogsConnections
):
    """Configuration for consolidating injected coadd catalogs.

    The column-name and flag-name fields select which columns of the
    input catalogs are read and which flag columns are written to the
    consolidated per-tract catalog.
    """

    col_ra = Field[str](
        doc="Column name for right ascension (in degrees).",
        default="ra",
    )
    col_dec = Field[str](
        doc="Column name for declination (in degrees).",
        default="dec",
    )
    col_mag = Field[str](
        doc="Column name for magnitude.",
        default="mag",
    )
    # NOTE(review): this field is not forwarded by the task's runQuantum in
    # this module -- confirm whether it is consumed elsewhere.
    col_source_type = Field[str](
        doc="Column name for the source type used in the input catalog. Must match one of the surface "
        "brightness profiles defined by GalSim. For more information see the Galsim docs at "
        "https://galsim-developers.github.io/GalSim/_build/html/sb.html",
        default="source_type",
    )
    injectionKey = Field[str](
        doc="True if the source was successfully injected.",
        default="injection_flag",
    )
    isPatchInnerKey = Field[str](
        doc="True if source is in the inner region of a coadd patch.",
        default="injected_isPatchInner",
    )
    isTractInnerKey = Field[str](
        doc="True if source is in the inner region of a coadd tract.",
        default="injected_isTractInner",
    )
    isPrimaryKey = Field[str](
        doc="True if the source was successfully injected and is in both the inner region of a coadd patch "
        "and tract.",
        default="injected_isPrimary",
    )
    # NOTE(review): this field is not forwarded by the task's runQuantum in
    # this module -- confirm whether it is consumed elsewhere.
    remove_patch_overlap_duplicates = Field[bool](
        doc="Optional parameter to remove patch overlap duplicate sources.",
        default=False,
    )
    get_catalogs_from_butler = Field[bool](
        doc="Optional parameter to specify whether or not the input catalogs are loaded with a Butler.",
        default=True,
    )
    # The consolidation code multiplies this value by the tract pixel scale
    # to obtain an angular radius, so the unit here is pixels.
    pixel_match_radius = Field[float](
        doc="Radius (in pixels) for matching catalogs across different bands.",
        default=0.1,
    )
def _get_catalogs(
    inputs: dict,
    input_refs: InputQuantizedConnection,
) -> tuple[dict, int]:
    """Group the input catalogs by band and stack them per band.

    Each input catalog gains a ``patch`` column holding the patch id taken
    from its dataset reference; the catalog is modified in place.

    Parameters
    ----------
    inputs : `dict`
        A dictionary containing the input datasets; the catalogs are read
        from its ``"input_catalogs"`` entry.
    input_refs : `lsst.pipe.base.connections.InputQuantizedConnection`
        The input dataset references, paired one-to-one with the
        catalogs in ``inputs``.

    Returns
    -------
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and one stacked
        astropy table per band for items.
    tract : `int`
        The single tract id shared by all input references.

    Raises
    ------
    RuntimeError
        Raised if the input references do not cover exactly one tract.
    """
    tables_by_band: dict[str, list] = {}
    tracts = set()
    for ref, table in zip(input_refs.input_catalogs, inputs["input_catalogs"]):
        data_id = ref.dataId
        band_name = data_id.band.name
        # Record the patch id so later steps can tell patches apart.
        table["patch"] = data_id.patch.id
        tables_by_band.setdefault(band_name, []).append(table)
        tracts.add(data_id.tract.id)

    # Collapse each band's list of per-patch tables into a single table.
    stacked = {band_name: vstack(tables) for band_name, tables in tables_by_band.items()}

    # All inputs must belong to one and the same tract.
    if len(tracts) != 1:
        raise RuntimeError(f"len({tracts=}) != 1")
    (tract,) = tracts
    return (stacked, tract)
def _get_patches(
    catalog_dict: dict,
    tractInfo,
    col_ra: str = "ra",
    col_dec: str = "dec",
):
    """Fill a ``patch`` column from each row's sky position.

    The tables are modified in place; nothing is returned.

    Parameters
    ----------
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and astropy tables
        for items.
    tractInfo : `lsst.skymap.tractInfo.ExplicitTractInfo`
        Tract used to look up the patch containing each position.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    """
    for table in catalog_dict.values():
        # Create the column with a placeholder value when it is absent.
        if "patch" not in table.columns:
            table.add_column(col=-9999, name="patch")
        for source in table:
            position = SpherePoint(source[col_ra], source[col_dec], degrees)
            source["patch"] = int(tractInfo.findPatch(position).getSequentialIndex())
def getPatchInner(
    catalog: Table,
    patchInfo,
    col_ra: str = "ra",
    col_dec: str = "dec",
):
    """Flag the sources that fall inside a patch's inner bounding box.

    Parameters
    ----------
    catalog : `astropy.table.Table`
        A catalog of sources.
    patchInfo : `lsst.skymap.PatchInfo`
        Patch providing the WCS and the inner bounding box.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).

    Returns
    -------
    isPatchInner : array-like of `bool`
        `True` for each source whose position lies within the inner
        region of the patch.
    """
    sky_ra, sky_dec = catalog[col_ra], catalog[col_dec]

    # Project the sky positions (given in degrees) onto the patch pixel grid.
    pix_x, pix_y = patchInfo.getWcs().skyToPixelArray(sky_ra, sky_dec, degrees=True)

    # Test the pixel positions against the floating-point inner box.
    inner_region = Box2D(patchInfo.getInnerBBox())
    return inner_region.contains(pix_x, pix_y)
def getTractInner(
    catalog,
    tractInfo,
    skyMap,
    col_ra: str = "ra",
    col_dec: str = "dec",
):
    """Flag the sources that the skyMap assigns to the given tract.

    Parameters
    ----------
    catalog : `astropy.table.Table`
        A catalog of sources.
    tractInfo : `lsst.skymap.TractInfo`
        Tract object.
    skyMap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).

    Returns
    -------
    isTractInner : `numpy.ndarray` of `bool`
        True where ``skyMap.findTract`` returns the same tract id as
        ``tractInfo``. Always boolean, including for an empty catalog.
    """
    ras, decs = catalog[col_ra].value, catalog[col_dec].value
    tractId = tractInfo.getId()
    # Force a boolean dtype: without it an empty catalog would yield a
    # float64 array, which cannot be combined with boolean masks using "&".
    return np.array(
        [skyMap.findTract(SpherePoint(ra, dec, degrees)).getId() == tractId for ra, dec in zip(ras, decs)],
        dtype=bool,
    )
def setPrimaryFlags(
    catalog,
    skyMap,
    tractInfo,
    patches: list,
    col_ra: str = "ra",
    col_dec: str = "dec",
    isPatchInnerKey: str = "injected_isPatchInner",
    isTractInnerKey: str = "injected_isTractInner",
    isPrimaryKey: str = "injected_isPrimary",
    injectionKey: str = "injection_flag",
):
    """Set the patch-inner, tract-inner and primary flags on a catalog.

    A source is flagged primary when it lies in the inner region of its
    coadd patch, lies in the inner region of the coadd tract, and its
    injection flag equals 0. The three flag columns are written to
    ``catalog`` in place.

    Parameters
    ----------
    catalog : `astropy.table.Table`
        A catalog of sources with a ``patch`` column.
        Receives the is-patch-inner, is-tract-inner, and is-primary flags.
    skyMap : `lsst.skymap.BaseSkyMap`
        Sky tessellation object.
    tractInfo : `lsst.skymap.TractInfo`
        Tract object.
    patches : `list`
        List of coadd patch ids to evaluate.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    isPatchInnerKey : `str`
        Column name for the isPatchInner flag.
    isTractInnerKey : `str`
        Column name for the isTractInner flag.
    isPrimaryKey : `str`
        Column name for the isPrimary flag.
    injectionKey : `str`
        Column name for the injection flag.
    """
    # Sources whose patch is not listed in ``patches`` stay flagged False.
    isPatchInner = np.zeros(len(catalog), dtype=bool)
    for patch in patches:
        patchMask = catalog["patch"] == patch
        patchInfo = tractInfo.getPatchInfo(patch)
        isPatchInner[patchMask] = getPatchInner(catalog[patchMask], patchInfo, col_ra, col_dec)

    # Forward the coordinate column names so that non-default names are
    # honoured for the tract test as well as the patch test.
    isTractInner = getTractInner(catalog, tractInfo, skyMap, col_ra, col_dec)

    # An injection flag of 0 marks a successfully injected source.
    isPrimary = isTractInner & isPatchInner & (catalog[injectionKey] == 0)

    catalog[isPatchInnerKey] = isPatchInner
    catalog[isTractInnerKey] = isTractInner
    catalog[isPrimaryKey] = isPrimary
def _make_multiband_catalog(
    bands: list,
    catalog_dict: dict,
    match_radius: float,
    col_ra: str = "ra",
    col_dec: str = "dec",
    col_mag: str = "mag",
) -> Table:
    """Combine multiple band-specific catalogs into one multiband
    catalog.

    The first band seeds the output. Each further band is positionally
    matched against the output built so far: matched rows receive that
    band's magnitude, unmatched rows are appended. The input tables are
    copied and are not modified.

    Parameters
    ----------
    bands : `list`
        A list of string photometry bands.
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and astropy
        tables for items.
    match_radius : `float`
        The radius for matching catalogs across bands.
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    col_mag : `str`
        Column name for magnitude.

    Returns
    -------
    multiband_catalog : `astropy.table.Table`
        A catalog with one ``{band}_{col_mag}`` column per band; bands in
        which a source has no match hold NaN.
    """
    # Load the first catalog then loop to add info for the other bands.
    multiband_catalog = catalog_dict[bands[0]].copy()
    multiband_catalog.add_column(col=multiband_catalog[col_mag], name=f"{bands[0]}_{col_mag}")
    multiband_catalog.remove_column(col_mag)
    for band in bands[1:]:
        band_mag = f"{band}_{col_mag}"
        # Magnitude columns of the bands merged so far. Collected before the
        # new (all-NaN) column is added, and once per band rather than once
        # per matched pair.
        known_mag_cols = [name for name in multiband_catalog.colnames if f"_{col_mag}" in name]
        # Make a column for the new band.
        multiband_catalog.add_column([np.nan] * len(multiband_catalog), name=band_mag)
        # Match the input catalog for this band to the existing
        # multiband catalog.
        catalog_next_band = catalog_dict[band].copy()
        catalog_next_band.rename_column(col_mag, band_mag)
        with Matcher(multiband_catalog[col_ra], multiband_catalog[col_dec]) as m:
            _, multiband_match_inds, next_band_match_inds, _ = m.query_radius(
                catalog_next_band[col_ra],
                catalog_next_band[col_dec],
                match_radius,
                return_indices=True,
            )
        # If there are matches...
        if len(multiband_match_inds) > 0 and len(next_band_match_inds) > 0:
            # ...choose the coordinates in the brightest band.
            for i, j in zip(multiband_match_inds, next_band_match_inds):
                # Rows appended from a later band have no value (masked or
                # NaN) for earlier bands. The builtin min() is neither
                # NaN- nor mask-aware, so compare against finite values only.
                existing_mags = np.array(
                    [np.ma.filled(multiband_catalog[name][i], np.nan) for name in known_mag_cols],
                    dtype=float,
                )
                finite_mags = existing_mags[~np.isnan(existing_mags)]
                if finite_mags.size > 0 and catalog_next_band[band_mag][j] < finite_mags.min():
                    multiband_catalog[col_ra][i] = catalog_next_band[col_ra][j]
                    multiband_catalog[col_dec][i] = catalog_next_band[col_dec][j]
            # TODO: Once multicomponent object support is added, make some
            # logic to pick the correct source_type.
            # Fill the new mag value.
            multiband_catalog[band_mag][multiband_match_inds] = catalog_next_band[band_mag][
                next_band_match_inds
            ]
            # Add rows for all the sources without matches.
            not_next_band_match_inds = np.full(len(catalog_next_band), True, dtype=bool)
            not_next_band_match_inds[next_band_match_inds] = False
            multiband_catalog = vstack([multiband_catalog, catalog_next_band[not_next_band_match_inds]])
        # Otherwise just stack the tables.
        else:
            multiband_catalog = vstack([multiband_catalog, catalog_next_band])
    # Fill any automatically masked values with NaNs.
    if multiband_catalog.has_masked_columns:
        for name in multiband_catalog.colnames:
            if isinstance(multiband_catalog[name], MaskedColumn):
                multiband_catalog[name] = multiband_catalog[name].filled(np.nan)
    return multiband_catalog
def consolidate_injected_deepCoadd_catalogs(
    catalog_dict: dict,
    skymap: BaseSkyMap,
    tract: int,
    pixel_match_radius: float = 0.1,
    get_catalogs_from_butler: bool = True,
    col_ra: str = "ra",
    col_dec: str = "dec",
    col_mag: str = "mag",
    isPatchInnerKey="injected_isPatchInner",
    isTractInnerKey="injected_isTractInner",
    isPrimaryKey="injected_isPrimary",
    injectionKey="injection_flag",
) -> Table:
    """Consolidate all tables in catalog_dict into one table.

    The per-band catalogs are matched across bands, sources outside the
    tract are dropped, the inner/primary flags are set, and a running
    ``injected_id`` is assigned.

    Parameters
    ----------
    catalog_dict : `dict`
        A dictionary with photometric bands for keys and astropy tables for
        items. The tables must also provide ``source_type``,
        ``injection_id`` and ``injection_draw_size`` columns, which are
        selected by those fixed names in the output.
    skymap : `lsst.skymap.BaseSkyMap`
        A base skymap.
    tract : `int`
        The tract where sources have been injected.
    pixel_match_radius : `float`
        Match radius in pixels to use for self-matching catalogs across
        different bands.
    get_catalogs_from_butler : `bool`
        Optional parameter to specify whether or not the input catalogs are
        loaded with a Butler. When False, the ``patch`` column is derived
        from the source positions (written to the input tables in place).
    col_ra : `str`
        Column name for right ascension (in degrees).
    col_dec : `str`
        Column name for declination (in degrees).
    col_mag : `str`
        Column name for magnitude.
    isPatchInnerKey : `str`
        Column name for the isPatchInner flag.
    isTractInnerKey : `str`
        Column name for the isTractInner flag.
    isPrimaryKey : `str`
        Column name for the isPrimary flag.
    injectionKey : `str`
        Column name for the injection flag.

    Returns
    -------
    multiband_catalog : `astropy.table.Table`
        A single table containing all information of the separate
        tables in catalog_dict.
    """
    tractInfo = skymap.generateTract(tract)
    # If patch numbers are not loaded via dataIds from the butler, manually
    # load patch numbers from source positions.
    if not get_catalogs_from_butler:
        _get_patches(catalog_dict, tractInfo, col_ra, col_dec)

    # Convert the pixel match radius to degrees.
    pixel_scale = tractInfo.getWcs().getPixelScale()
    match_radius = pixel_match_radius * pixel_scale.asDegrees()
    bands = list(catalog_dict.keys())
    if len(bands) > 1:
        # Match the catalogs across bands, honouring the configured
        # column names.
        output_catalog = _make_multiband_catalog(
            bands, catalog_dict, match_radius, col_ra, col_dec, col_mag
        )
    else:
        # Work on a copy so the caller's table is not renamed, trimmed or
        # flagged in place (the multiband branch copies as well).
        output_catalog = catalog_dict[bands[0]].copy()
        output_catalog.rename_column(col_mag, f"{bands[0]}_{col_mag}")

    # Remove sources outside tract boundaries.
    out_of_tract_bounds = [
        index
        for index, (ra, dec) in enumerate(zip(output_catalog[col_ra], output_catalog[col_dec]))
        if not tractInfo.contains(SpherePoint(ra * degrees, dec * degrees))
    ]
    output_catalog.remove_rows(out_of_tract_bounds)

    # Assign flags.
    patches = list(set(output_catalog["patch"]))
    setPrimaryFlags(
        catalog=output_catalog,
        skyMap=skymap,
        tractInfo=tractInfo,
        patches=patches,
        col_ra=col_ra,
        col_dec=col_dec,
        isPatchInnerKey=isPatchInnerKey,
        isTractInnerKey=isTractInnerKey,
        isPrimaryKey=isPrimaryKey,
        injectionKey=injectionKey,
    )
    # Add a new injected_id column (integer even when the table is empty).
    output_catalog.add_column(col=np.arange(len(output_catalog)), name="injected_id")
    # Remove unnecessary output columns.
    output_catalog.remove_column("patch")
    # Reorder columns, using the configured names rather than the defaults
    # so that non-default column/flag names do not raise a KeyError.
    mag_cols = [col for col in output_catalog.columns if f"_{col_mag}" in col]
    new_order = [
        "injected_id",
        col_ra,
        col_dec,
        "source_type",
        *mag_cols,
        "injection_id",
        "injection_draw_size",
        injectionKey,
        isPatchInnerKey,
        isTractInnerKey,
        isPrimaryKey,
    ]

    return output_catalog[new_order]
class ConsolidateInjectedCatalogsTask(PipelineTask):
    """Class for combining all tables in a collection of input catalogs
    into one table.
    """

    _DefaultName = "consolidateInjectedCatalogsTask"
    ConfigClass = ConsolidateInjectedCatalogsConfig

    def runQuantum(self, butlerQC, input_refs, output_refs):
        """Fetch the inputs, run the consolidation and store the output.

        Parameters
        ----------
        butlerQC
            Quantum context used to get the inputs and put the outputs.
        input_refs
            References to the input datasets.
        output_refs
            References to the output datasets.
        """
        inputs = butlerQC.get(input_refs)
        catalog_dict, tract = _get_catalogs(inputs, input_refs)
        # Pass by keyword so a long positional list cannot drift out of
        # step with the signature of ``run``.
        outputs = self.run(
            catalog_dict=catalog_dict,
            skymap=inputs["skyMap"],
            tract=tract,
            pixel_match_radius=self.config.pixel_match_radius,
            get_catalogs_from_butler=self.config.get_catalogs_from_butler,
            col_ra=self.config.col_ra,
            col_dec=self.config.col_dec,
            col_mag=self.config.col_mag,
            isPatchInnerKey=self.config.isPatchInnerKey,
            isTractInnerKey=self.config.isTractInnerKey,
            isPrimaryKey=self.config.isPrimaryKey,
            injectionKey=self.config.injectionKey,
        )
        butlerQC.put(outputs, output_refs)

    def run(
        self,
        catalog_dict: dict,
        skymap: BaseSkyMap,
        tract: int,
        pixel_match_radius: float = 0.1,
        get_catalogs_from_butler: bool = True,
        col_ra: str = "ra",
        col_dec: str = "dec",
        col_mag: str = "mag",
        isPatchInnerKey="injected_isPatchInner",
        isTractInnerKey="injected_isTractInner",
        isPrimaryKey="injected_isPrimary",
        injectionKey="injection_flag",
    ) -> Struct:
        """Consolidate all tables in catalog_dict into one table.

        Parameters
        ----------
        catalog_dict : `dict`
            A dictionary with photometric bands for keys and astropy tables
            for items.
        skymap : `lsst.skymap.BaseSkyMap`
            A base skymap.
        tract : `int`
            The tract where sources have been injected.
        pixel_match_radius : `float`
            Match radius in pixels to use for self-matching catalogs across
            different bands.
        get_catalogs_from_butler : `bool`
            Optional parameter to specify whether or not the input catalogs
            are loaded with a Butler.
        col_ra : `str`
            Column name for right ascension (in degrees).
        col_dec : `str`
            Column name for declination (in degrees).
        col_mag : `str`
            Column name for magnitude.
        isPatchInnerKey : `str`
            Column name for the isPatchInner flag.
        isTractInnerKey : `str`
            Column name for the isTractInner flag.
        isPrimaryKey : `str`
            Column name for the isPrimary flag.
        injectionKey : `str`
            Column name for the injection flag.

        Returns
        -------
        output_struct : `lsst.pipe.base.Struct`
            contains :
                output_catalog : `astropy.table.Table`
                    A single table containing all information of the
                    separate tables in catalog_dict.
        """
        output_catalog = consolidate_injected_deepCoadd_catalogs(
            catalog_dict=catalog_dict,
            skymap=skymap,
            tract=tract,
            pixel_match_radius=pixel_match_radius,
            get_catalogs_from_butler=get_catalogs_from_butler,
            col_ra=col_ra,
            col_dec=col_dec,
            col_mag=col_mag,
            isPatchInnerKey=isPatchInnerKey,
            isTractInnerKey=isTractInnerKey,
            isPrimaryKey=isPrimaryKey,
            injectionKey=injectionKey,
        )
        return Struct(output_catalog=output_catalog)