Coverage for python/lsst/source/injection/inject_engine.py: 7%
226 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 05:05 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 05:05 -0700
1# This file is part of source_injection.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["generate_galsim_objects", "inject_galsim_objects_into_exposure"]
26import os
27from collections import Counter
28from collections.abc import Generator
29from typing import Any
31import galsim
32import numpy as np
33import numpy.ma as ma
34from astropy.io import fits
35from astropy.table import Table
36from galsim import GalSimFFTSizeError
37from lsst.afw.geom import SkyWcs
38from lsst.afw.image import ExposureF, PhotoCalib
39from lsst.geom import Box2I, Point2D, Point2I, SpherePoint, arcseconds, degrees
40from lsst.pex.exceptions import InvalidParameterError, LogicError
def get_object_data(source_data: dict[str, Any], object_class: galsim.GSObject) -> dict[str, Any]:
    """Collect the constructor keyword arguments for a GalSim object class.

    The class is inspected for GalSim's parameter registries
    (``_req_params``, ``_opt_params`` and ``_single_params``). Any registry
    the class lacks is treated as empty. Entries of ``source_data`` that do
    not appear in any registry are ignored.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    object_class : `galsim.gsobject.GSObject`
        Class of GSObject to match against.

    Returns
    -------
    object_data : `dict` [`str`, `Any`]
        Dictionary of source data to pass to the GSObject constructor.

    Raises
    ------
    ValueError
        Raised if a required parameter is missing, if more than one member of
        a mutually exclusive (``_single_params``) group is supplied, or if no
        member of such a group is supplied.
    """
    required = getattr(object_class, "_req_params", {})
    optional = getattr(object_class, "_opt_params", {})
    exclusive_groups = getattr(object_class, "_single_params", {})

    kwargs: dict[str, Any] = {}

    # Every required parameter must be present.
    for name in required:
        if name not in source_data:
            raise ValueError(f"Required parameter {name} not found in input catalog.")
        kwargs[name] = source_data[name]

    # Optional parameters are copied only when supplied.
    kwargs.update({name: source_data[name] for name in optional if name in source_data})

    # Each exclusive group needs exactly one of its members.
    for group in exclusive_groups:
        present = 0
        for name in group:
            if name not in source_data:
                continue
            present += 1
            if present > 1:
                raise ValueError(f"Only one of {group.keys()} allowed for type {object_class}.")
            kwargs[name] = source_data[name]
        if not present:
            raise ValueError(f"One of the args {group.keys()} is required for type {object_class}.")

    return kwargs
def get_shear_data(
    source_data: dict[str, Any],
    shear_attributes: list[str] | tuple[str, ...] = (
        "g1",
        "g2",
        "g",
        "e1",
        "e2",
        "e",
        "eta1",
        "eta2",
        "eta",
        "q",
        "beta",
        "shear",
    ),
) -> dict[str, Any]:
    """Assemble a dictionary of allowed keyword arguments and their
    corresponding values to use when constructing a Shear.

    Only keys listed in ``shear_attributes`` and present in ``source_data``
    are kept. A ``beta`` value is interpreted as degrees and converted to a
    `galsim.Angle`. The result's keys follow the order of
    ``shear_attributes``, so the output is deterministic.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    shear_attributes : `list` [`str`] or `tuple` [`str`, ...], optional
        Allowed shear attributes. The default is an immutable tuple, so it
        cannot be altered accidentally between calls.

    Returns
    -------
    shear_data : `dict` [`str`, `Any`]
        Dictionary of source data to pass to the Shear constructor.
    """
    shear_data: dict[str, Any] = {}
    # Iterate the allowed names in order (rather than a set intersection)
    # so the resulting keyword order does not depend on string hashing.
    for shear_param in shear_attributes:
        if shear_param not in source_data:
            continue
        value = source_data[shear_param]
        if shear_param == "beta":
            # Catalog position angles are in degrees; GalSim expects an Angle.
            value = value * galsim.degrees
        shear_data[shear_param] = value
    return shear_data
def generate_galsim_objects(
    injection_catalog: Table,
    photo_calib: PhotoCalib,
    wcs: SkyWcs,
    fits_alignment: str,
    stamp_prefix: str = "",
    logger: Any | None = None,
) -> Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None]:
    """Generate GalSim objects from an injection catalog.

    Masked catalog values are dropped before use, so masked columns behave
    as if absent. Sky coordinates come from ``ra``/``dec`` (degrees) when
    present; otherwise they are derived from ``x``/``y`` through ``wcs``.
    Pixel coordinates come from ``x``/``y`` when present; otherwise they are
    derived from the sky coordinates. ``source_type`` selects the builder:
    ``"Stamp"`` reads a FITS postage stamp, ``"Trail"`` builds a trail, and
    any other value is taken as the name of a GalSim class.

    Sources whose ``mag`` cannot be converted to an instrumental flux
    (``LogicError`` from ``photo_calib``) are skipped and not yielded, and a
    warning is logged when a logger is supplied.
    NOTE(review): skipping makes the generated sequence shorter than
    ``injection_catalog``. Confirm callers do not assume a one-to-one
    correspondence with catalog rows.

    Parameters
    ----------
    injection_catalog : `astropy.table.Table`
        Table of sources to be injected.
    photo_calib : `lsst.afw.image.PhotoCalib`
        Photometric calibration used to calibrate injected sources.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used to calibrate injected sources.
    fits_alignment : `str`
        Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel".
    stamp_prefix : `str`
        Prefix to add to the stamp name.
    logger : `lsst.utils.logging.LsstLogAdapter`, optional
        Logger to use for logging messages.

    Yields
    ------
    sky_coords : `lsst.geom.SpherePoint`
        RA/Dec coordinates of the source.
    pixel_coords : `lsst.geom.Point2D`
        Pixel coordinates of the source.
    draw_size : `int`
        Size of the stamp to draw (0 when the catalog provides no
        ``draw_size``).
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    if logger:
        source_types = Counter(injection_catalog["source_type"])  # type: ignore
        grammar0 = "source" if len(injection_catalog) == 1 else "sources"
        grammar1 = "type" if len(source_types) == 1 else "types"
        logger.info(
            "Generating %d injection %s consisting of %d unique %s: %s.",
            len(injection_catalog),
            grammar0,
            len(source_types),
            grammar1,
            ", ".join(f"{k}({v})" for k, v in source_types.items()),
        )
    for source_data_full in injection_catalog:
        items = dict(source_data_full).items()  # type: ignore
        # Drop masked entries so that missing optional values raise KeyError below.
        source_data = {k: v for (k, v) in items if v is not ma.masked}
        try:
            sky_coords = SpherePoint(float(source_data["ra"]), float(source_data["dec"]), degrees)
        except KeyError:
            sky_coords = wcs.pixelToSky(float(source_data["x"]), float(source_data["y"]))
        try:
            pixel_coords = Point2D(source_data["x"], source_data["y"])
        except KeyError:
            pixel_coords = wcs.skyToPixel(sky_coords)
        try:
            inst_flux = photo_calib.magnitudeToInstFlux(source_data["mag"], pixel_coords)
        except LogicError as err:
            # Previously a silent drop; report it so missing injections are traceable.
            if logger:
                logger.warning(
                    "Cannot compute instrumental flux for source at %s; skipping.\n%s",
                    sky_coords,
                    err,
                )
            continue
        draw_size = int(source_data.get("draw_size", 0))

        source_type = source_data["source_type"]
        if source_type == "Stamp":
            stamp_file = stamp_prefix + source_data["stamp"]
            gs_object = make_galsim_stamp(stamp_file, fits_alignment, wcs, sky_coords, inst_flux)
        elif source_type == "Trail":
            gs_object = make_galsim_trail(source_data, wcs, sky_coords, inst_flux)
        else:
            gs_object = make_galsim_object(source_data, source_type, inst_flux, logger)

        yield sky_coords, pixel_coords, draw_size, gs_object
def make_galsim_object(
    source_data: dict[str, Any],
    source_type: str,
    inst_flux: float,
    logger: Any | None = None,
) -> galsim.gsobject.GSObject:
    """Build a GalSim profile of the named class from catalog data.

    The class ``galsim.<source_type>`` is constructed from the matching
    entries of ``source_data``. If any shear parameters are present, an
    area-preserving shear is applied. When ``shear`` raises `TypeError`, the
    unsheared profile is kept and a warning is logged (if a logger is
    given). The profile is finally rescaled to ``inst_flux``.

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data.
    source_type : `str`
        Name of the GalSim class to instantiate.
    inst_flux : `float`
        Instrumental flux of the source.
    logger : `~lsst.utils.logging.LsstLogAdapter`, optional
        Logger used to report a shear that could not be applied.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.
    """
    profile_class = getattr(galsim, source_type)
    profile = profile_class(**get_object_data(source_data, profile_class))

    shear_kwargs = get_shear_data(source_data)
    if shear_kwargs:
        try:
            profile = profile.shear(**shear_kwargs)
        except TypeError as err:
            # Best effort: keep the unsheared profile rather than failing.
            if logger:
                logger.warning("Cannot apply shear to GalSim object: %s", err)

    return profile.withFlux(inst_flux)
def make_galsim_trail(
    source_data: dict[str, Any],
    wcs: SkyWcs,
    sky_coords: SpherePoint,
    inst_flux: float,
    trail_thickness: float = 1e-6,
) -> galsim.gsobject.GSObject:
    """Build a linear trail profile from catalog data.

    The trail is built in four steps:
    1. Model it as a `galsim.Box` with length ``source_data["trail_length"]``
       and width ``trail_thickness``.
    2. Rotate it by ``source_data["beta"]`` degrees when that entry exists.
    3. Give it a total flux of ``inst_flux * trail_length``.
    4. Map it through the local linear approximation of ``wcs`` at
       ``sky_coords`` (in arcseconds).

    Parameters
    ----------
    source_data : `dict` [`str`, `Any`]
        Dictionary of source data; must contain ``trail_length``.
    wcs : `lsst.afw.geom.SkyWcs`
        World coordinate system.
    sky_coords : `lsst.geom.SpherePoint`
        Sky coordinates of the source, where ``wcs`` is linearized.
    inst_flux : `float`
        Instrumental flux of the source, per unit trail length.
    trail_thickness : `float`
        Thickness of the trail in pixels.

    Returns
    -------
    object : `galsim.gsobject.GSObject`
        A fully specified and transformed GalSim object.

    Raises
    ------
    KeyError
        Raised if ``source_data`` has no ``trail_length`` entry.
    """
    length = source_data["trail_length"]
    # A box far thinner than a pixel (default 1e-6) stands in for a line.
    trail = galsim.Box(length, trail_thickness)
    try:
        trail = trail.rotate(source_data["beta"] * galsim.degrees)
    except KeyError:
        pass  # No position angle supplied: leave the trail unrotated.
    trail = trail.withFlux(inst_flux * length)  # type: ignore
    # GalSim profiles are treated as sky-coordinate objects. To have the
    # trail appear as specified above in image coordinates, push it through
    # the local pixel-to-sky Jacobian.
    jacobian = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
    return trail.transform(jacobian[0, 0], jacobian[0, 1], jacobian[1, 0], jacobian[1, 1])  # type: ignore
296def make_galsim_stamp(
297 stamp_file: str,
298 fits_alignment: str,
299 wcs: SkyWcs,
300 sky_coords: SpherePoint,
301 inst_flux: float,
302) -> galsim.gsobject.GSObject:
303 """Make a postage stamp with GalSim from a FITS file.
305 Parameters
306 ----------
307 stamp_file : `str`
308 Path to the FITS file containing the postage stamp.
309 fits_alignment : `str`
310 Alignment of the FITS image to the WCS. Allowed values: "wcs", "pixel".
311 wcs : `lsst.afw.geom.SkyWcs`
312 World coordinate system.
313 sky_coords : `lsst.geom.SpherePoint`
314 Sky coordinates of the source.
315 inst_flux : `float`
316 Instrumental flux of the source.
318 Returns
319 -------
320 object: `galsim.gsobject.GSObject`
321 A fully specified and transformed GalSim object.
322 """
323 stamp_file = stamp_file.strip()
324 if os.path.exists(stamp_file):
325 with fits.open(stamp_file) as hdul:
326 hdu_images = [hdu.is_image and hdu.size > 0 for hdu in hdul] # type: ignore
327 if any(hdu_images):
328 stamp_data = galsim.fits.read(stamp_file, read_header=True, hdu=np.where(hdu_images)[0][0])
329 else:
330 raise RuntimeError(f"Cannot find image in input FITS file {stamp_file}.")
331 match fits_alignment:
332 case "wcs":
333 # galsim.fits.read will always attach a WCS to its output.
334 # If it can't find a WCS in the FITS header, then it
335 # defaults to scale = 1.0 arcsec / pix. If that's the scale
336 # then we need to check if it was explicitly set or if it's
337 # just the default. If it's just the default then we should
338 # raise an exception.
339 if is_wcs_galsim_default(stamp_data.wcs, stamp_data.header): # type: ignore
340 raise RuntimeError(f"Cannot find WCS in input FITS file {stamp_file}")
341 case "pixel":
342 # We need to set stamp_data.wcs to the local target WCS.
343 linear_wcs = wcs.linearizePixelToSky(sky_coords, arcseconds)
344 mat = linear_wcs.getMatrix()
345 stamp_data.wcs = galsim.JacobianWCS( # type: ignore
346 mat[0, 0], mat[0, 1], mat[1, 0], mat[1, 1]
347 )
348 object = galsim.InterpolatedImage(stamp_data, calculate_stepk=False)
349 object = object.withFlux(inst_flux)
350 return object # type: ignore
351 else:
352 raise RuntimeError(f"Cannot locate input FITS postage stamp {stamp_file}.")
def is_wcs_galsim_default(
    wcs: galsim.fitswcs.GSFitsWCS,
    hdr: galsim.fits.FitsHeader,
) -> bool:
    """Tell whether a unit pixel-scale WCS is GalSim's fallback or explicit.

    ``galsim.PixelScale(1.0)`` is what GalSim attaches when it finds no WCS
    in a FITS header. That value alone is therefore ambiguous, and the
    header is inspected for evidence that the WCS was actually specified.

    Parameters
    ----------
    wcs : galsim.fitswcs.GSFitsWCS
        WCS attached to the image by GalSim; possibly the fallback.
    hdr : galsim.fits.FitsHeader
        Header read alongside the image.

    Returns
    -------
    is_default : bool
        True if ``wcs`` looks like the fallback, False if the header
        explicitly defines it.
    """
    # Anything other than a unit pixel scale cannot be the fallback.
    if wcs != galsim.PixelScale(1.0):
        return False
    # Treat a GS_WCS entry as an explicit GalSim WCS.
    if hdr.get("GS_WCS") is not None:
        return False

    def lacks_scale_keywords() -> bool:
        return not any(keyword in hdr for keyword in ("CD1_1", "CDELT1"))

    # Linear (or unspecified) axes: explicit only if a scale keyword exists.
    if hdr.get("CTYPE1", "LINEAR") == "LINEAR":
        return lacks_scale_keywords()
    # Otherwise, if any GalSim FITS WCS reader can parse the header, assume
    # the WCS was given explicitly.
    for reader in galsim.fitswcs.fits_wcs_types:
        try:
            reader._readHeader(hdr)
        except Exception:
            continue
        return False
    return lacks_scale_keywords()
def _compute_psf_kernel_array(psf: Any, pixel_coords: Point2D, bbox: Box2I) -> np.ndarray | None:
    """Return the PSF kernel image array at ``pixel_coords``, retrying at the
    nearest point inside ``bbox``; return None if neither can be evaluated."""
    try:
        return psf.computeKernelImage(pixel_coords).array
    except InvalidParameterError:
        pass
    # Try mapping to the nearest point contained in bbox.
    contained_point = Point2D(
        np.clip(pixel_coords.x, bbox.minX, bbox.maxX), np.clip(pixel_coords.y, bbox.minY, bbox.maxY)
    )
    if pixel_coords == contained_point:  # No difference, so retrying is pointless.
        return None
    try:
        return psf.computeKernelImage(contained_point).array
    except InvalidParameterError:
        return None


def _bounds_to_box2i(bounds: galsim.BoundsI) -> Box2I:
    """Convert inclusive GalSim integer bounds to the equivalent `Box2I`."""
    return Box2I(Point2I(bounds.xmin, bounds.ymin), Point2I(bounds.xmax, bounds.ymax))


def inject_galsim_objects_into_exposure(
    exposure: ExposureF,
    objects: Generator[tuple[SpherePoint, Point2D, int, galsim.gsobject.GSObject], None, None],
    mask_plane_name: str = "INJECTED",
    calib_flux_radius: float = 12.0,
    draw_size_max: int = 1000,
    logger: Any | None = None,
) -> tuple[list[int], list[galsim.BoundsI], list[bool], list[bool]]:
    """Inject sources into given exposure using GalSim.

    Each object is convolved with the exposure PSF at its position. The PSF
    is normalized by its aperture flux within ``calib_flux_radius``. The
    result is added to the exposure image. The mask planes
    ``mask_plane_name`` and ``mask_plane_name + "_CORE"`` are added to the
    exposure mask. The first is set over each drawn region; the second over
    a core of at most 3x3 pixels around each object.

    Parameters
    ----------
    exposure : `lsst.afw.image.ExposureF`
        The exposure to inject synthetic sources into.
    objects : `Generator` [`tuple`, None, None]
        An iterator of tuples that contains (or generates) locations and object
        surface brightness profiles to inject. The tuples should contain the
        following elements: `lsst.geom.SpherePoint`, `lsst.geom.Point2D`,
        `int`, `galsim.gsobject.GSObject`.
    mask_plane_name : `str`
        Name of the mask plane to use for the injected sources.
    calib_flux_radius : `float`
        Radius in pixels to use for the flux calibration. This is used to
        produce the correct instrumental fluxes within the radius. The value
        should match that of the field defined in slot_CalibFlux_instFlux.
    draw_size_max : `int`
        Maximum allowed size of the drawn object. If the object is larger than
        this, the draw size will be clipped to this size. A non-positive
        value disables clipping.
    logger : `lsst.utils.logging.LsstLogAdapter`, optional
        Logger to use for logging messages.

    Returns
    -------
    draw_sizes : `list` [`int`]
        Draw sizes of the injected sources.
    common_bounds : `list` [`galsim.BoundsI`]
        Common bounds of the drawn objects.
    fft_size_errors : `list` [`bool`]
        Boolean flags indicating whether a GalSimFFTSizeError was raised.
    psf_compute_errors : `list` [`bool`]
        Boolean flags indicating whether a PSF computation error was raised.

    Notes
    -----
    Each returned list has one entry per item yielded by ``objects``, in
    order. Objects skipped before drawing keep the defaults
    (0, empty bounds, False, False) unless a flag was set.
    """
    exposure.mask.addMaskPlane(mask_plane_name)
    mask_plane_core_name = mask_plane_name + "_CORE"
    exposure.mask.addMaskPlane(mask_plane_core_name)
    if logger:
        logger.info(
            "Adding %s and %s mask planes to the exposure.",
            mask_plane_name,
            mask_plane_core_name,
        )
    # Plane bit values are fixed once the planes exist; look them up once.
    bitvalue = exposure.mask.getPlaneBitMask(mask_plane_name)
    bitvalue_core = exposure.mask.getPlaneBitMask(mask_plane_core_name)

    psf = exposure.getPsf()
    wcs = exposure.getWcs()
    bbox = exposure.getBBox()
    full_bounds = galsim.BoundsI(bbox.minX, bbox.maxX, bbox.minY, bbox.maxY)
    # Drawing into this image with add_to_image=True is how flux reaches the
    # exposure. This relies on GalSim wrapping the array without copying.
    galsim_image = galsim.Image(exposure.image.array, bounds=full_bounds)
    pixel_scale = wcs.getPixelScale(bbox.getCenter()).asArcseconds()
    # The aperture correction is evaluated at the PSF's average position, so
    # it is the same for every object. Compute it lazily, at most once.
    aperture_correction: float | None = None

    draw_sizes: list[int] = []
    common_bounds: list[galsim.BoundsI] = []
    fft_size_errors: list[bool] = []
    psf_compute_errors: list[bool] = []
    for i, (sky_coords, pixel_coords, draw_size, gs_object) in enumerate(objects):
        # Instantiate default returns in case of early exit from this loop.
        draw_sizes.append(0)
        common_bounds.append(galsim.BoundsI())
        fft_size_errors.append(False)
        psf_compute_errors.append(False)

        # Get spatial coordinates and WCS.
        posd = galsim.PositionD(pixel_coords.x, pixel_coords.y)
        posi = galsim.PositionI(pixel_coords.x // 1, pixel_coords.y // 1)
        if logger:
            logger.debug("Injecting synthetic source at %s.", pixel_coords)
        mat = wcs.linearizePixelToSky(sky_coords, arcseconds).getMatrix()
        galsim_wcs = galsim.JacobianWCS(mat[0, 0], mat[0, 1], mat[1, 0], mat[1, 1])

        # This check is here because sometimes the WCS is multivalued and
        # objects that should not be included were being included.
        galsim_pixel_scale = np.sqrt(galsim_wcs.pixelArea())
        if galsim_pixel_scale < pixel_scale / 2 or galsim_pixel_scale > pixel_scale * 2:
            if logger:
                logger.debug(
                    "Local WCS pixel scale inconsistent with exposure for object at %s; skipping.",
                    sky_coords,
                )
            continue

        # Compute the PSF at the object location.
        psf_array = _compute_psf_kernel_array(psf, pixel_coords, bbox)
        if psf_array is None:
            psf_compute_errors[i] = True
            if logger:
                logger.debug("Cannot compute PSF for object at %s; flagging and skipping.", sky_coords)
            continue

        # Compute the aperture corrected PSF interpolated image.
        if aperture_correction is None:
            aperture_correction = psf.computeApertureFlux(calib_flux_radius, psf.getAveragePosition())
        psf_array /= aperture_correction
        galsim_psf = galsim.InterpolatedImage(galsim.Image(psf_array), wcs=galsim_wcs)

        # Convolve the object with the PSF and generate draw size.
        conv = galsim.Convolve(gs_object, galsim_psf)
        if draw_size == 0:
            draw_size = conv.getGoodImageSize(galsim_wcs.minLinearScale())  # type: ignore
        injection_draw_size = int(draw_size)
        injection_core_size = 3
        if draw_size_max > 0 and injection_draw_size > draw_size_max:
            if logger:
                logger.warning(
                    "Clipping draw size for object at %s from %d to %d pixels.",
                    sky_coords,
                    injection_draw_size,
                    draw_size_max,
                )
            injection_draw_size = draw_size_max
        draw_sizes[i] = injection_draw_size
        if injection_core_size > injection_draw_size:
            if logger:
                logger.debug(
                    "Clipping core size for object at %s from %d to %d pixels.",
                    sky_coords,
                    injection_core_size,
                    injection_draw_size,
                )
            injection_core_size = injection_draw_size
        sub_bounds = galsim.BoundsI(posi).withBorder(injection_draw_size // 2)
        object_common_bounds = full_bounds & sub_bounds
        common_bounds[i] = object_common_bounds  # type: ignore

        # Inject the source only if there is any overlap.
        if object_common_bounds.area() <= 0:
            if logger:
                logger.debug("No area overlap for object at %s; flagging and skipping.", sky_coords)
            continue

        common_image = galsim_image[object_common_bounds]
        offset = posd - object_common_bounds.true_center
        # Note, for calexp injection, pixel is already part of the PSF and
        # for coadd injection, it's incorrect to include the output pixel.
        # So for both cases, we draw using method='no_pixel'.
        try:
            conv.drawImage(common_image, add_to_image=True, offset=offset, wcs=galsim_wcs, method="no_pixel")
        except GalSimFFTSizeError as err:
            fft_size_errors[i] = True
            if logger:
                logger.debug(
                    "GalSimFFTSizeError raised for object at %s; flagging and skipping.\n%s",
                    sky_coords,
                    err,
                )
            continue
        exposure[_bounds_to_box2i(object_common_bounds)].mask.array |= bitvalue
        # Add a 3 x 3 pixel mask centered on the object. The mask must be
        # large enough to always identify the core/peak of the injected
        # source, but small enough that it rarely overlaps real sources.
        sub_bounds_core = galsim.BoundsI(posi).withBorder(injection_core_size // 2)
        object_common_bounds_core = full_bounds & sub_bounds_core
        if object_common_bounds_core.area() > 0:
            exposure[_bounds_to_box2i(object_common_bounds_core)].mask.array |= bitvalue_core

    return draw_sizes, common_bounds, fft_size_errors, psf_compute_errors